# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This API defines FeatureColumn abstraction.

FeatureColumns provide a high level abstraction for ingesting and representing
features. FeatureColumns are also the primary way of encoding features for
canned `tf.estimator.Estimator`s.

When using FeatureColumns with `Estimators`, the type of feature column you
should choose depends on (1) the feature type and (2) the model type.

1. Feature type:

  * Continuous features can be represented by `numeric_column`.
  * Categorical features can be represented by any `categorical_column_with_*`
    column:
    - `categorical_column_with_vocabulary_list`
    - `categorical_column_with_vocabulary_file`
    - `categorical_column_with_hash_bucket`
    - `categorical_column_with_identity`
    - `weighted_categorical_column`

2. Model type:

  * Deep neural network models (`DNNClassifier`, `DNNRegressor`).

    Continuous features can be directly fed into deep neural network models.

      age_column = numeric_column("age")

    To feed sparse features into DNN models, wrap the column with
    `embedding_column` or `indicator_column`. `indicator_column` is recommended
    for features with only a few possible values. For features with many
    possible values, to reduce the size of your model, `embedding_column` is
    recommended.

      embedded_dept_column = embedding_column(
          categorical_column_with_vocabulary_list(
              "department", ["math", "philosophy", ...]), dimension=10)

  * Wide (aka linear) models (`LinearClassifier`, `LinearRegressor`).

    Sparse features can be fed directly into linear models. They behave like an
    indicator column but with an efficient implementation.

      dept_column = categorical_column_with_vocabulary_list("department",
          ["math", "philosophy", "english"])

    It is recommended that continuous features be bucketized before being
    fed into linear models.

      bucketized_age_column = bucketized_column(
          source_column=age_column,
          boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

    Sparse features can be crossed (also known as conjuncted or combined) in
    order to form non-linearities, and then fed into linear models.

      cross_dept_age_column = crossed_column(
          columns=["department", bucketized_age_column],
          hash_bucket_size=1000)

Example of building canned `Estimator`s using FeatureColumns:

  ```python
  # Define features and transformations
  deep_feature_columns = [age_column, embedded_dept_column]
  wide_feature_columns = [dept_column, bucketized_age_column,
                          cross_dept_age_column]

  # Build deep model
  estimator = DNNClassifier(
      feature_columns=deep_feature_columns,
      hidden_units=[500, 250, 50])
  estimator.train(...)

  # Or build a wide model
  estimator = LinearClassifier(
      feature_columns=wide_feature_columns)
  estimator.train(...)

  # Or build a wide and deep model!
  estimator = DNNLinearCombinedClassifier(
      linear_feature_columns=wide_feature_columns,
      dnn_feature_columns=deep_feature_columns,
      dnn_hidden_units=[500, 250, 50])
  estimator.train(...)
  ```


FeatureColumns can also be transformed into a generic input layer for
custom models using `input_layer`.

Example of building a model using FeatureColumns; this can be used in a
`model_fn`, which is given to the `tf.estimator.Estimator`:

  ```python
  # Building model via layers

  deep_feature_columns = [age_column, embedded_dept_column]
  columns_to_tensor = parse_feature_columns_from_examples(
      serialized=my_data,
      feature_columns=deep_feature_columns)
  first_layer = input_layer(
      features=columns_to_tensor,
      feature_columns=deep_feature_columns)
  second_layer = fully_connected(first_layer, ...)
  ```

NOTE: Functions prefixed with "_" indicate experimental or private parts of
the API subject to change, and should not be relied upon!
"""

import abc
import collections
import math
import re

import numpy as np
import six

from tensorflow.python.data.experimental.ops import lookup_ops as data_lookup_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import utils as fc_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base as trackable
from tensorflow.python.trackable import data_structures
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
                               'deprecated. Please use the new FeatureColumn '
                               'APIs instead.')


class StateManager(object):
  """Manages the state associated with FeatureColumns.

  Some `FeatureColumn`s create variables or resources to assist their
  computation. The `StateManager` is responsible for creating and storing these
  objects since `FeatureColumn`s are supposed to be stateless configuration
  only.
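
  As an illustration only, a minimal in-memory manager could look like the
  following sketch (hypothetical; concrete implementations such as
  `_StateManagerImpl` below also handle trackability and partitioned
  variables):

  ```python
  import tensorflow as tf

  class SimpleStateManager(StateManager):
    # Stores variables in a plain dict keyed by (feature_column, name).
    def __init__(self):
      self._vars = {}

    def create_variable(self, feature_column, name, shape, dtype=None,
                        trainable=True, use_resource=True, initializer=None):
      var = tf.compat.v1.get_variable(
          name, shape=shape, dtype=dtype, trainable=trainable,
          use_resource=use_resource, initializer=initializer)
      self._vars[(feature_column, name)] = var
      return var

    def get_variable(self, feature_column, name):
      return self._vars[(feature_column, name)]
  ```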
181 """ 182 183 def create_variable(self, 184 feature_column, 185 name, 186 shape, 187 dtype=None, 188 trainable=True, 189 use_resource=True, 190 initializer=None): 191 """Creates a new variable. 192 193 Args: 194 feature_column: A `FeatureColumn` object this variable corresponds to. 195 name: variable name. 196 shape: variable shape. 197 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 198 trainable: Whether this variable is trainable or not. 199 use_resource: If true, we use resource variables. Otherwise we use 200 RefVariable. 201 initializer: initializer instance (callable). 202 203 Returns: 204 The created variable. 205 """ 206 del feature_column, name, shape, dtype, trainable, use_resource, initializer 207 raise NotImplementedError('StateManager.create_variable') 208 209 def add_variable(self, feature_column, var): 210 """Adds an existing variable to the state. 211 212 Args: 213 feature_column: A `FeatureColumn` object to associate this variable with. 214 var: The variable. 215 """ 216 del feature_column, var 217 raise NotImplementedError('StateManager.add_variable') 218 219 def get_variable(self, feature_column, name): 220 """Returns an existing variable. 221 222 Args: 223 feature_column: A `FeatureColumn` object this variable corresponds to. 224 name: variable name. 225 """ 226 del feature_column, name 227 raise NotImplementedError('StateManager.get_var') 228 229 def add_resource(self, feature_column, name, resource): 230 """Creates a new resource. 231 232 Resources can be things such as tables, variables, trackables, etc. 233 234 Args: 235 feature_column: A `FeatureColumn` object this resource corresponds to. 236 name: Name of the resource. 237 resource: The resource. 238 239 Returns: 240 The created resource. 241 """ 242 del feature_column, name, resource 243 raise NotImplementedError('StateManager.add_resource') 244 245 def has_resource(self, feature_column, name): 246 """Returns true iff a resource with same name exists. 247 248 Resources can be things such as tables, variables, trackables, etc. 249 250 Args: 251 feature_column: A `FeatureColumn` object this variable corresponds to. 252 name: Name of the resource. 253 """ 254 del feature_column, name 255 raise NotImplementedError('StateManager.has_resource') 256 257 def get_resource(self, feature_column, name): 258 """Returns an already created resource. 259 260 Resources can be things such as tables, variables, trackables, etc. 261 262 Args: 263 feature_column: A `FeatureColumn` object this variable corresponds to. 264 name: Name of the resource. 265 """ 266 del feature_column, name 267 raise NotImplementedError('StateManager.get_resource') 268 269 270@tf_export('__internal__.feature_column.StateManager', v1=[]) 271class _StateManagerImpl(StateManager): 272 """Manages the state of DenseFeatures and LinearLayer. 273 274 Some `FeatureColumn`s create variables or resources to assist their 275 computation. The `StateManager` is responsible for creating and storing these 276 objects since `FeatureColumn`s are supposed to be stateless configuration 277 only. 278 """ 279 280 def __init__(self, layer, trainable): 281 """Creates an _StateManagerImpl object. 282 283 Args: 284 layer: The input layer this state manager is associated with. 285 trainable: Whether by default, variables created are trainable or not. 
286 """ 287 self._trainable = trainable 288 self._layer = layer 289 if self._layer is not None and not hasattr(self._layer, '_resources'): 290 self._layer._resources = data_structures.Mapping() # pylint: disable=protected-access 291 self._cols_to_vars_map = collections.defaultdict(lambda: {}) 292 self._cols_to_resources_map = collections.defaultdict(lambda: {}) 293 294 def create_variable(self, 295 feature_column, 296 name, 297 shape, 298 dtype=None, 299 trainable=True, 300 use_resource=True, 301 initializer=None): 302 """Creates a new variable. 303 304 Args: 305 feature_column: A `FeatureColumn` object this variable corresponds to. 306 name: variable name. 307 shape: variable shape. 308 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 309 trainable: Whether this variable is trainable or not. 310 use_resource: If true, we use resource variables. Otherwise we use 311 RefVariable. 312 initializer: initializer instance (callable). 313 314 Returns: 315 The created variable. 316 """ 317 if name in self._cols_to_vars_map[feature_column]: 318 raise ValueError('Variable already exists.') 319 320 # We explicitly track these variables since `name` is not guaranteed to be 321 # unique and disable manual tracking that the add_weight call does. 322 with trackable.no_manual_dependency_tracking_scope(self._layer): 323 var = self._layer.add_weight( 324 name=name, 325 shape=shape, 326 dtype=dtype, 327 initializer=initializer, 328 trainable=self._trainable and trainable, 329 use_resource=use_resource, 330 # TODO(rohanj): Get rid of this hack once we have a mechanism for 331 # specifying a default partitioner for an entire layer. In that case, 332 # the default getter for Layers should work. 333 getter=variable_scope.get_variable) 334 if isinstance(var, variables.PartitionedVariable): 335 for v in var: 336 part_name = name + '/' + str(v._get_save_slice_info().var_offset[0]) # pylint: disable=protected-access 337 self._layer._track_trackable(v, feature_column.name + '/' + part_name) # pylint: disable=protected-access 338 else: 339 if isinstance(var, trackable.Trackable): 340 self._layer._track_trackable(var, feature_column.name + '/' + name) # pylint: disable=protected-access 341 342 self._cols_to_vars_map[feature_column][name] = var 343 return var 344 345 def get_variable(self, feature_column, name): 346 """Returns an existing variable. 347 348 Args: 349 feature_column: A `FeatureColumn` object this variable corresponds to. 350 name: variable name. 351 """ 352 if name in self._cols_to_vars_map[feature_column]: 353 return self._cols_to_vars_map[feature_column][name] 354 raise ValueError('Variable does not exist.') 355 356 def add_resource(self, feature_column, resource_name, resource): 357 """Creates a new resource. 358 359 Resources can be things such as tables, variables, trackables, etc. 360 361 Args: 362 feature_column: A `FeatureColumn` object this resource corresponds to. 363 resource_name: Name of the resource. 364 resource: The resource. 365 366 Returns: 367 The created resource. 368 """ 369 self._cols_to_resources_map[feature_column][resource_name] = resource 370 # pylint: disable=protected-access 371 if self._layer is not None and isinstance(resource, trackable.Trackable): 372 # Add trackable resources to the layer for serialization. 
      if feature_column.name not in self._layer._resources:
        self._layer._resources[feature_column.name] = data_structures.Mapping()
      if resource_name not in self._layer._resources[feature_column.name]:
        self._layer._resources[feature_column.name][resource_name] = resource
    # pylint: enable=protected-access

  def has_resource(self, feature_column, resource_name):
    """Returns true iff a resource with the same name exists.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
    """
    return resource_name in self._cols_to_resources_map[feature_column]

  def get_resource(self, feature_column, resource_name):
    """Returns an already created resource.

    Resources can be things such as tables, variables, trackables, etc.

    Args:
      feature_column: A `FeatureColumn` object this resource corresponds to.
      resource_name: Name of the resource.
    """
    if (feature_column not in self._cols_to_resources_map or
        resource_name not in self._cols_to_resources_map[feature_column]):
      raise ValueError('Resource does not exist.')
    return self._cols_to_resources_map[feature_column][resource_name]


def _transform_features_v2(features, feature_columns, state_manager):
  """Returns transformed features based on the feature columns passed in.

  Note that you most likely will not need to use this function directly. Check
  `input_layer` and `linear_model` first to see whether they satisfy your use
  case.

  Example:

  ```python
  # Define features and transformations
  crosses_a_x_b = crossed_column(
      columns=["sparse_feature_a", "sparse_feature_b"], hash_bucket_size=10000)
  price_buckets = bucketized_column(
      source_column=numeric_column("price"), boundaries=[...])

  columns = [crosses_a_x_b, price_buckets]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  transformed = transform_features(features=features, feature_columns=columns)

  assertCountEqual(columns, transformed.keys())
  ```

  Args:
    features: A mapping from key to tensors. `FeatureColumn`s look up via these
      keys. For example `numeric_column('price')` will look at 'price' key in
      this dict. Values can be a `SparseTensor` or a `Tensor`, depending on the
      corresponding `FeatureColumn`.
    feature_columns: An iterable containing all the `FeatureColumn`s.
    state_manager: A StateManager object that holds the FeatureColumn state.

  Returns:
    A `dict` mapping `FeatureColumn` to `Tensor` and `SparseTensor` values.
  """
  feature_columns = _normalize_feature_columns(feature_columns)
  outputs = {}
  with ops.name_scope(
      None, default_name='transform_features', values=features.values()):
    transformation_cache = FeatureTransformationCache(features)
    for column in feature_columns:
      with ops.name_scope(
          None,
          default_name=_sanitize_column_name_for_variable_scope(column.name)):
        outputs[column] = transformation_cache.get(column, state_manager)
  return outputs


@tf_export('feature_column.make_parse_example_spec', v1=[])
def make_parse_example_spec_v2(feature_columns):
  """Creates parsing spec dictionary from input feature_columns.

  The returned dictionary can be used as arg 'features' in
  `tf.io.parse_example`.

  Typical usage example:

  ```python
  # Define features and transformations
  feature_a = tf.feature_column.categorical_column_with_vocabulary_file(...)
  feature_b = tf.feature_column.numeric_column(...)
  feature_c_bucketized = tf.feature_column.bucketized_column(
      tf.feature_column.numeric_column("feature_c"), ...)
  feature_a_x_feature_c = tf.feature_column.crossed_column(
      columns=["feature_a", feature_c_bucketized], ...)

  feature_columns = set(
      [feature_b, feature_c_bucketized, feature_a_x_feature_c])
  features = tf.io.parse_example(
      serialized=serialized_examples,
      features=tf.feature_column.make_parse_example_spec(feature_columns))
  ```

  For the above example, make_parse_example_spec would return the dict:

  ```python
  {
      "feature_a": parsing_ops.VarLenFeature(tf.string),
      "feature_b": parsing_ops.FixedLenFeature([1], dtype=tf.float32),
      "feature_c": parsing_ops.FixedLenFeature([1], dtype=tf.float32)
  }
  ```

  Args:
    feature_columns: An iterable containing all feature columns. All items
      should be instances of classes derived from `FeatureColumn`.

  Returns:
    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`
    value.

  Raises:
    ValueError: If any of the given `feature_columns` is not a `FeatureColumn`
      instance.
  """
  result = {}
  for column in feature_columns:
    if not isinstance(column, FeatureColumn):
      raise ValueError('All feature_columns must be FeatureColumn instances. '
                       'Given: {}'.format(column))
    config = column.parse_example_spec
    for key, value in six.iteritems(config):
      if key in result and value != result[key]:
        raise ValueError(
            'feature_columns contain different parse_spec for key '
            '{}. Given {} and {}'.format(key, value, result[key]))
    result.update(config)
  return result


@tf_export('feature_column.embedding_column')
def embedding_column(categorical_column,
                     dimension,
                     combiner='mean',
                     initializer=None,
                     ckpt_to_load_from=None,
                     tensor_name_in_ckpt=None,
                     max_norm=None,
                     trainable=True,
                     use_safe_embedding_lookup=True):
  """`DenseColumn` that converts from sparse, categorical input.

  Use this when your inputs are sparse, but you want to convert them to a dense
  representation (e.g., to feed to a DNN).

  Inputs must be a `CategoricalColumn` created by any of the
  `categorical_column_*` functions. Here is an example of using
  `embedding_column` with `DNNClassifier`:

  ```python
  video_id = categorical_column_with_identity(
      key='video_id', num_buckets=1000000, default_value=0)
  columns = [embedding_column(video_id, 9),...]

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `embedding_column` with model_fn:

  ```python
  def model_fn(features, ...):
    video_id = categorical_column_with_identity(
        key='video_id', num_buckets=1000000, default_value=0)
    columns = [embedding_column(video_id, 9),...]
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_column: A `CategoricalColumn` created by a
      `categorical_column_with_*` function. This column produces the sparse IDs
      that are inputs to the embedding lookup.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, embedding values are l2-normalized to this value.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    `DenseColumn` that converts from sparse input.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: If eager execution is enabled.
  """
  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified. '
                     'Embedding of column_name: {}'.format(
                         categorical_column.name))
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1 / math.sqrt(dimension))

  return EmbeddingColumn(
      categorical_column=categorical_column,
      dimension=dimension,
      combiner=combiner,
      initializer=initializer,
      ckpt_to_load_from=ckpt_to_load_from,
      tensor_name_in_ckpt=tensor_name_in_ckpt,
      max_norm=max_norm,
      trainable=trainable,
      use_safe_embedding_lookup=use_safe_embedding_lookup)


@tf_export(v1=['feature_column.shared_embedding_columns'])
def shared_embedding_columns(categorical_columns,
                             dimension,
                             combiner='mean',
                             initializer=None,
                             shared_embedding_collection_name=None,
                             ckpt_to_load_from=None,
                             tensor_name_in_ckpt=None,
                             max_norm=None,
                             trainable=True,
                             use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries in
      a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and
      standard deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional name of the collection where
      shared embedding weights are added. If not given, a reasonable name will
      be chosen based on the names of `categorical_columns`. This is also used
      in `variable_scope` when creating shared embedding weights.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from which
      to restore the column weights. Required if `ckpt_to_load_from` is not
      `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is larger
      than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures
      there are no empty rows and all weights and ids are positive at the
      expense of extra compute cost. This only applies to rank 2 (NxM) shaped
      input tensors. Defaults to true; consider turning it off if the above
      checks are not needed. Note that having empty rows will not trigger any
      error though the output result might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0._num_buckets  # pylint: disable=protected-access
  if not isinstance(c0, fc_old._CategoricalColumn):  # pylint: disable=protected-access
    raise ValueError(
        'All categorical_columns must be subclasses of _CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(
      c0, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
           fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(
        c, (fc_old._WeightedCategoricalColumn, WeightedCategoricalColumn,  # pylint: disable=protected-access
            fc_old._SequenceCategoricalColumn, SequenceCategoricalColumn)):  # pylint: disable=protected-access
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c._num_buckets:  # pylint: disable=protected-access
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c._num_buckets))  # pylint: disable=protected-access

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  result = []
  for column in categorical_columns:
    result.append(
        fc_old._SharedEmbeddingColumn(  # pylint: disable=protected-access
            categorical_column=column,
            initializer=initializer,
            dimension=dimension,
            combiner=combiner,
            shared_embedding_collection_name=shared_embedding_collection_name,
            ckpt_to_load_from=ckpt_to_load_from,
            tensor_name_in_ckpt=tensor_name_in_ckpt,
            max_norm=max_norm,
            trainable=trainable,
            use_safe_embedding_lookup=use_safe_embedding_lookup))

  return result


@tf_export('feature_column.shared_embeddings', v1=[])
def shared_embedding_columns_v2(categorical_columns,
                                dimension,
                                combiner='mean',
                                initializer=None,
                                shared_embedding_collection_name=None,
                                ckpt_to_load_from=None,
                                tensor_name_in_ckpt=None,
                                max_norm=None,
                                trainable=True,
                                use_safe_embedding_lookup=True):
  """List of dense columns that convert from sparse, categorical input.

  This is similar to `embedding_column`, except that it produces a list of
  embedding columns that share the same embedding weights.

  Use this when your inputs are sparse and of the same type (e.g. watched and
  impression video IDs that share the same vocabulary), and you want to convert
  them to a dense representation (e.g., to feed to a DNN).

  Inputs must be a list of categorical columns created by any of the
  `categorical_column_*` functions. They must all be of the same type and have
  the same arguments except `key`. E.g. they can be
  categorical_column_with_vocabulary_file with the same vocabulary_file. Some or
  all columns could also be weighted_categorical_column.

  Here is an example embedding of two features for a DNNClassifier model:

  ```python
  watched_video_id = categorical_column_with_vocabulary_file(
      'watched_video_id', video_vocabulary_file, video_vocabulary_size)
  impression_video_id = categorical_column_with_vocabulary_file(
      'impression_video_id', video_vocabulary_file, video_vocabulary_size)
  columns = shared_embedding_columns(
      [watched_video_id, impression_video_id], dimension=10)

  estimator = tf.estimator.DNNClassifier(feature_columns=columns, ...)

  label_column = ...
  def input_fn():
    features = tf.io.parse_example(
        ..., features=make_parse_example_spec(columns + [label_column]))
    labels = features.pop(label_column.name)
    return features, labels

  estimator.train(input_fn=input_fn, steps=100)
  ```

  Here is an example using `shared_embedding_columns` with model_fn:

  ```python
  def model_fn(features, ...):
    watched_video_id = categorical_column_with_vocabulary_file(
        'watched_video_id', video_vocabulary_file, video_vocabulary_size)
    impression_video_id = categorical_column_with_vocabulary_file(
        'impression_video_id', video_vocabulary_file, video_vocabulary_size)
    columns = shared_embedding_columns(
        [watched_video_id, impression_video_id], dimension=10)
    dense_tensor = input_layer(features, columns)
    # Form DNN layers, calculate loss, and return EstimatorSpec.
    ...
  ```

  Args:
    categorical_columns: List of categorical columns created by a
      `categorical_column_with_*` function. These columns produce the sparse IDs
      that are inputs to the embedding lookup. All columns must be of the same
      type and have the same arguments except `key`. E.g. they can be
      categorical_column_with_vocabulary_file with the same vocabulary_file.
      Some or all columns could also be weighted_categorical_column.
    dimension: An integer specifying dimension of the embedding, must be > 0.
    combiner: A string specifying how to reduce if there are multiple entries
      in a single row. Currently 'mean', 'sqrtn' and 'sum' are supported, with
      'mean' the default. 'sqrtn' often achieves good accuracy, in particular
      with bag-of-words columns. Each of these can be thought of as an example
      level normalization on the column. For more information, see
      `tf.embedding_lookup_sparse`.
    initializer: A variable initializer function to be used in embedding
      variable initialization. If not specified, defaults to
      `truncated_normal_initializer` with mean `0.0` and standard
      deviation `1/sqrt(dimension)`.
    shared_embedding_collection_name: Optional collective name of these columns.
      If not given, a reasonable name will be chosen based on the names of
      `categorical_columns`.
    ckpt_to_load_from: String representing checkpoint name/pattern from which to
      restore column weights. Required if `tensor_name_in_ckpt` is not `None`.
    tensor_name_in_ckpt: Name of the `Tensor` in `ckpt_to_load_from` from
      which to restore the column weights. Required if `ckpt_to_load_from` is
      not `None`.
    max_norm: If not `None`, each embedding is clipped if its l2-norm is
      larger than this value, before combining.
    trainable: Whether or not the embedding is trainable. Default is True.
    use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse
      instead of embedding_lookup_sparse.
      safe_embedding_lookup_sparse ensures there are no empty rows and all
      weights and ids are positive at the expense of extra compute cost. This
      only applies to rank 2 (NxM) shaped input tensors. Defaults to true;
      consider turning it off if the above checks are not needed. Note that
      having empty rows will not trigger any error though the output result
      might be 0 or omitted.

  Returns:
    A list of dense columns that converts from sparse input. The order of
    results follows the ordering of `categorical_columns`.

  Raises:
    ValueError: if `dimension` not > 0.
    ValueError: if any of the given `categorical_columns` is of different type
      or has different arguments than the others.
    ValueError: if exactly one of `ckpt_to_load_from` and `tensor_name_in_ckpt`
      is specified.
    ValueError: if `initializer` is specified and is not callable.
    RuntimeError: if eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError('shared_embedding_columns are not supported when eager '
                       'execution is enabled.')

  if (dimension is None) or (dimension < 1):
    raise ValueError('Invalid dimension {}.'.format(dimension))
  if (ckpt_to_load_from is None) != (tensor_name_in_ckpt is None):
    raise ValueError('Must specify both `ckpt_to_load_from` and '
                     '`tensor_name_in_ckpt` or none of them.')

  if (initializer is not None) and (not callable(initializer)):
    raise ValueError('initializer must be callable if specified.')
  if initializer is None:
    initializer = init_ops.truncated_normal_initializer(
        mean=0.0, stddev=1. / math.sqrt(dimension))

  # Sort the columns so the default collection name is deterministic even if the
  # user passes columns from an unsorted collection, such as dict.values().
  sorted_columns = sorted(categorical_columns, key=lambda x: x.name)

  c0 = sorted_columns[0]
  num_buckets = c0.num_buckets
  if not isinstance(c0, CategoricalColumn):
    raise ValueError(
        'All categorical_columns must be subclasses of CategoricalColumn. '
        'Given: {}, of type: {}'.format(c0, type(c0)))
  while isinstance(c0, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
    c0 = c0.categorical_column
  for c in sorted_columns[1:]:
    while isinstance(c, (WeightedCategoricalColumn, SequenceCategoricalColumn)):
      c = c.categorical_column
    if not isinstance(c, type(c0)):
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same type, or be weighted_categorical_column or sequence column '
          'of the same type. Given column: {} of type: {} does not match given '
          'column: {} of type: {}'.format(c0, type(c0), c, type(c)))
    if num_buckets != c.num_buckets:
      raise ValueError(
          'To use shared_embedding_column, all categorical_columns must have '
          'the same number of buckets. '
          'Given column: {} with buckets: {} does '
          'not match column: {} with buckets: {}'.format(
              c0, num_buckets, c, c.num_buckets))

  if not shared_embedding_collection_name:
    shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns)
    shared_embedding_collection_name += '_shared_embedding'

  column_creator = SharedEmbeddingColumnCreator(
      dimension, initializer, ckpt_to_load_from, tensor_name_in_ckpt,
      num_buckets, trainable, shared_embedding_collection_name,
      use_safe_embedding_lookup)

  result = []
  for column in categorical_columns:
    result.append(
        column_creator(
            categorical_column=column, combiner=combiner, max_norm=max_norm))

  return result


@tf_export('feature_column.numeric_column')
def numeric_column(key,
                   shape=(1,),
                   default_value=None,
                   dtype=dtypes.float32,
                   normalizer_fn=None):
  """Represents real valued or numerical features.

  Example:

  Assume we have data with two features `a` and `b`.

  >>> data = {'a': [15, 9, 17, 19, 21, 18, 25, 30],
  ...         'b': [5.0, 6.4, 10.5, 13.6, 15.7, 19.9, 20.3, 0.0]}

  Let us represent the features `a` and `b` as numerical features.

  >>> a = tf.feature_column.numeric_column('a')
  >>> b = tf.feature_column.numeric_column('b')

  Feature columns describe a set of transformations to the inputs.

  For example, to "bucketize" feature `a`, wrap the `a` column in a
  `feature_column.bucketized_column`.
  Providing `5` bucket boundaries, the bucketized_column api
  will bucket this feature into a total of `6` buckets.

  >>> a_buckets = tf.feature_column.bucketized_column(a,
  ...   boundaries=[10, 15, 20, 25, 30])

  Create a `DenseFeatures` layer which will apply the transformations
  described by the set of `tf.feature_column` objects:

  >>> feature_layer = tf.keras.layers.DenseFeatures([a_buckets, b])
  >>> print(feature_layer(data))
  tf.Tensor(
  [[ 0.   0.   1.   0.   0.   0.   5. ]
   [ 1.   0.   0.   0.   0.   0.   6.4]
   [ 0.   0.   1.   0.   0.   0.  10.5]
   [ 0.   0.   1.   0.   0.   0.  13.6]
   [ 0.   0.   0.   1.   0.   0.  15.7]
   [ 0.   0.   1.   0.   0.   0.  19.9]
   [ 0.   0.   0.   0.   1.   0.  20.3]
   [ 0.   0.   0.   0.   0.   1.   0. ]], shape=(8, 7), dtype=float32)

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    shape: An iterable of integers specifying the shape of the `Tensor`. An
      integer can be given which means a single dimension `Tensor` with given
      width. The `Tensor` representing the column will have the shape of
      [batch_size] + `shape`.
    default_value: A single value compatible with `dtype` or an iterable of
      values compatible with `dtype` which the column takes on during
      `tf.Example` parsing if data is missing. A default value of `None` will
      cause `tf.io.parse_example` to fail if an example does not contain this
      column. If a single value is provided, the same value will be applied as
      the default value for every item. If an iterable of values is provided,
      the shape of the `default_value` should be equal to the given `shape`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    normalizer_fn: If not `None`, a function that can be used to normalize the
      value of the tensor after `default_value` is applied for parsing.
      Normalizer function takes the input `Tensor` as its argument, and returns
      the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that
      even though the most common use case of this function is normalization,
      it can be used for any kind of TensorFlow transformation.

  Returns:
    A `NumericColumn`.

  Raises:
    TypeError: if any dimension in shape is not an int
    ValueError: if any dimension in shape is not a positive integer
    TypeError: if `default_value` is an iterable but not compatible with `shape`
    TypeError: if `default_value` is not compatible with `dtype`.
    ValueError: if `dtype` is not convertible to `tf.float32`.
  """
  shape = _check_shape(shape, key)
  if not (dtype.is_integer or dtype.is_floating):
    raise ValueError('dtype must be convertible to float. '
                     'dtype: {}, key: {}'.format(dtype, key))
  default_value = fc_utils.check_default_value(
      shape, default_value, dtype, key)

  if normalizer_fn is not None and not callable(normalizer_fn):
    raise TypeError(
        'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))

  fc_utils.assert_key_is_string(key)
  return NumericColumn(
      key,
      shape=shape,
      default_value=default_value,
      dtype=dtype,
      normalizer_fn=normalizer_fn)


@tf_export('feature_column.bucketized_column')
def bucketized_column(source_column, boundaries):
  """Represents discretized dense input bucketed by `boundaries`.

  Buckets include the left boundary, and exclude the right boundary. Namely,
  `boundaries=[0., 1., 2.]` generates buckets `(-inf, 0.)`, `[0., 1.)`,
  `[1., 2.)`, and `[2., +inf)`.

  For example, if the inputs are

  ```python
  boundaries = [0, 10, 100]
  input tensor = [[-5, 10000]
                  [150, 10]
                  [5, 100]]
  ```

  then the output will be

  ```python
  output = [[0, 3]
            [3, 2]
            [1, 3]]
  ```

  Example:

  ```python
  price = tf.feature_column.numeric_column('price')
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  columns = [bucketized_price, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  ```

  A `bucketized_column` can also be crossed with another categorical column
  using `crossed_column`:

  ```python
  price = tf.feature_column.numeric_column('price')
  # bucketized_column converts numerical feature to a categorical one.
  bucketized_price = tf.feature_column.bucketized_column(
      price, boundaries=[...])
  # 'keywords' is a string feature.
  price_x_keywords = tf.feature_column.crossed_column(
      [bucketized_price, 'keywords'], 50000)
  columns = [price_x_keywords, ...]
  features = tf.io.parse_example(
      ..., features=tf.feature_column.make_parse_example_spec(columns))
  dense_tensor = tf.keras.layers.DenseFeatures(columns)(features)
  linear_model = tf.keras.experimental.LinearModel(units=...)(dense_tensor)
  ```

  Args:
    source_column: A one-dimensional dense column which is generated with
      `numeric_column`.
    boundaries: A sorted list or tuple of floats specifying the boundaries.

  Returns:
    A `BucketizedColumn`.

  Raises:
    ValueError: If `source_column` is not a numeric column, or if it is not
      one-dimensional.
    ValueError: If `boundaries` is not a sorted list or tuple.
  """
  if not isinstance(source_column, (NumericColumn, fc_old._NumericColumn)):  # pylint: disable=protected-access
    raise ValueError(
        'source_column must be a column generated with numeric_column(). '
        'Given: {}'.format(source_column))
  if len(source_column.shape) > 1:
    raise ValueError(
        'source_column must be one-dimensional column. '
        'Given: {}'.format(source_column))
  if not boundaries:
    raise ValueError('boundaries must not be empty.')
  if not (isinstance(boundaries, list) or isinstance(boundaries, tuple)):
    raise ValueError('boundaries must be a sorted list.')
  for i in range(len(boundaries) - 1):
    if boundaries[i] >= boundaries[i + 1]:
      raise ValueError('boundaries must be a sorted list.')
  return BucketizedColumn(source_column, tuple(boundaries))


@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
                                        hash_bucket_size,
                                        dtype=dtypes.string):
  """Represents a sparse feature where ids are set by hashing.

  Use this when your sparse features are in string or integer format, and you
  want to distribute your inputs into a finite number of buckets by hashing.
  output_id = Hash(input_feature_string) % bucket_size for string type input.
  For int type input, the value is converted to its string representation first
  and then hashed by the same formula.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
  and `''` for string, which will be dropped by this feature column.

  Example:

  ```python
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                                   10000)
  columns = [keywords]
  features = {'keywords': tf.constant([['Tensorflow', 'Keras', 'RNN', 'LSTM',
  'CNN'], ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'], ['CNN', 'Tensorflow',
  'LSTM', 'Keras', 'RNN']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)

  # or
  import tensorflow as tf
  keywords = tf.feature_column.categorical_column_with_hash_bucket("keywords",
                                                                   10000)
  keywords_embedded = tf.feature_column.embedding_column(keywords, 16)
  columns = [keywords_embedded]
  features = {'keywords': tf.constant([['Tensorflow', 'Keras', 'RNN', 'LSTM',
  'CNN'], ['LSTM', 'CNN', 'Tensorflow', 'Keras', 'RNN'], ['CNN', 'Tensorflow',
  'LSTM', 'Keras', 'RNN']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    hash_bucket_size: An int > 1. The number of buckets.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `HashedCategoricalColumn`.

  Raises:
    ValueError: `hash_bucket_size` is not greater than 1.
    ValueError: `dtype` is neither string nor integer.
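
  For intuition, the bucket assignment for string inputs can be reproduced with
  `tf.strings.to_hash_bucket_fast` (a sketch of the documented formula above,
  not a statement about this column's exact internal op graph):

  ```python
  # Each string is fingerprinted and reduced modulo the bucket count.
  ids = tf.strings.to_hash_bucket_fast(
      tf.constant(['Tensorflow', 'Keras']), num_buckets=10000)
  ```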
1227 """ 1228 if hash_bucket_size is None: 1229 raise ValueError('hash_bucket_size must be set. ' 'key: {}'.format(key)) 1230 1231 if hash_bucket_size < 1: 1232 raise ValueError('hash_bucket_size must be at least 1. ' 1233 'hash_bucket_size: {}, key: {}'.format( 1234 hash_bucket_size, key)) 1235 1236 fc_utils.assert_key_is_string(key) 1237 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1238 1239 return HashedCategoricalColumn(key, hash_bucket_size, dtype) 1240 1241 1242@tf_export(v1=['feature_column.categorical_column_with_vocabulary_file']) 1243def categorical_column_with_vocabulary_file(key, 1244 vocabulary_file, 1245 vocabulary_size=None, 1246 num_oov_buckets=0, 1247 default_value=None, 1248 dtype=dtypes.string): 1249 """A `CategoricalColumn` with a vocabulary file. 1250 1251 Use this when your inputs are in string or integer format, and you have a 1252 vocabulary file that maps each value to an integer ID. By default, 1253 out-of-vocabulary values are ignored. Use either (but not both) of 1254 `num_oov_buckets` and `default_value` to specify how to include 1255 out-of-vocabulary values. 1256 1257 For input dictionary `features`, `features[key]` is either `Tensor` or 1258 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1259 and `''` for string, which will be dropped by this feature column. 1260 1261 Example with `num_oov_buckets`: 1262 File '/us/states.txt' contains 50 lines, each with a 2-character U.S. state 1263 abbreviation. All inputs with values in that file are assigned an ID 0-49, 1264 corresponding to its line number. All other values are hashed and assigned an 1265 ID 50-54. 1266 1267 ```python 1268 import tensorflow as tf 1269 states = tf.feature_column.categorical_column_with_vocabulary_file( 1270 key='states', vocabulary_file='states.txt', vocabulary_size=5, 1271 num_oov_buckets=1) 1272 columns = [states] 1273 features = {'states':tf.constant([['california', 'georgia', 'michigan', 1274 'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan', 1275 'texas']])} 1276 linear_prediction = tf.compat.v1.feature_column.linear_model(features, 1277 columns) 1278 ``` 1279 1280 Example with `default_value`: 1281 File '/us/states.txt' contains 51 lines - the first line is 'XX', and the 1282 other 50 each have a 2-character U.S. state abbreviation. Both a literal 'XX' 1283 in input, and other values missing from the file, will be assigned ID 0. All 1284 others are assigned the corresponding line number 1-50. 

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='states.txt', vocabulary_size=6,
      default_value=0)
  columns = [states]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
  'texas']])}
  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
                                                               columns)
  ```

  And to make an embedding with either:

  ```python
  import tensorflow as tf
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='states.txt', vocabulary_size=5,
      num_oov_buckets=1)
  columns = [tf.feature_column.embedding_column(states, 3)]
  features = {'states': tf.constant([['california', 'georgia', 'michigan',
  'texas', 'new york'], ['new york', 'georgia', 'california', 'michigan',
  'texas']])}
  input_layer = tf.keras.layers.DenseFeatures(columns)
  dense_tensor = input_layer(features)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if less than the length,
      later values are ignored. If None, it is set to the length of
      `vocabulary_file`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    dtype: The type of features. Only string and integer types are supported.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
  """
  return categorical_column_with_vocabulary_file_v2(
      key, vocabulary_file, vocabulary_size,
      dtype, default_value,
      num_oov_buckets)


@tf_export('feature_column.categorical_column_with_vocabulary_file', v1=[])
def categorical_column_with_vocabulary_file_v2(key,
                                               vocabulary_file,
                                               vocabulary_size=None,
                                               dtype=dtypes.string,
                                               default_value=None,
                                               num_oov_buckets=0,
                                               file_format=None):
  """A `CategoricalColumn` with a vocabulary file.

  Use this when your inputs are in string or integer format, and you have a
  vocabulary file that maps each value to an integer ID. By default,
  out-of-vocabulary values are ignored. Use either (but not both) of
  `num_oov_buckets` and `default_value` to specify how to include
  out-of-vocabulary values.

  For input dictionary `features`, `features[key]` is either `Tensor` or
  `SparseTensor`.
  If `Tensor`, missing values can be represented by `-1` for int and `''` for
  string, which will be dropped by this feature column.

  Example with `num_oov_buckets`:
  File `'/us/states.txt'` contains 50 lines, each with a 2-character U.S. state
  abbreviation. All inputs with values in that file are assigned an ID 0-49,
  corresponding to its line number. All other values are hashed and assigned an
  ID 50-54.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=50,
      num_oov_buckets=5)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  Example with `default_value`:
  File `'/us/states.txt'` contains 51 lines - the first line is `'XX'`, and the
  other 50 each have a 2-character U.S. state abbreviation. Both a literal
  `'XX'` in input, and other values missing from the file, will be assigned
  ID 0. All others are assigned the corresponding line number 1-50.

  ```python
  states = categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.txt', vocabulary_size=51,
      default_value=0)
  columns = [states, ...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  linear_prediction = linear_model(features, columns)
  ```

  And to make an embedding with either:

  ```python
  columns = [embedding_column(states, 3),...]
  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
  dense_tensor = input_layer(features, columns)
  ```

  Args:
    key: A unique string identifying the input feature. It is used as the
      column name and the dictionary key for feature parsing configs, feature
      `Tensor` objects, and feature columns.
    vocabulary_file: The vocabulary file name.
    vocabulary_size: Number of the elements in the vocabulary. This must be no
      greater than the length of `vocabulary_file`; if less than the length,
      later values are ignored. If None, it is set to the length of
      `vocabulary_file`.
    dtype: The type of features. Only string and integer types are supported.
    default_value: The integer ID value to return for out-of-vocabulary feature
      values, defaults to `-1`. This can not be specified with a positive
      `num_oov_buckets`.
    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
      `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
      the input value. A positive `num_oov_buckets` can not be specified with
      `default_value`.
    file_format: The format of the vocabulary file. The format is 'text' by
      default unless `vocabulary_file` is a string which ends in 'tfrecord.gz'.
      Accepted alternative value for `file_format` is 'tfrecord_gzip'.

  Returns:
    A `CategoricalColumn` with a vocabulary file.

  Raises:
    ValueError: `vocabulary_file` is missing or cannot be opened.
    ValueError: `vocabulary_size` is missing or < 1.
    ValueError: `num_oov_buckets` is a negative integer.
    ValueError: `num_oov_buckets` and `default_value` are both specified.
    ValueError: `dtype` is neither string nor integer.
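
  Example with `file_format` (assuming a hypothetical GZIP-compressed TFRecord
  vocabulary with one entry per record; the format is inferred here from the
  'tfrecord.gz' suffix, or can be passed explicitly as
  file_format='tfrecord_gzip'):

  ```python
  states = tf.feature_column.categorical_column_with_vocabulary_file(
      key='states', vocabulary_file='/us/states.tfrecord.gz',
      num_oov_buckets=5)
  ```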
1436 """ 1437 if not vocabulary_file: 1438 raise ValueError('Missing vocabulary_file in {}.'.format(key)) 1439 1440 if file_format is None and vocabulary_file.endswith('tfrecord.gz'): 1441 file_format = 'tfrecord_gzip' 1442 1443 if vocabulary_size is None: 1444 if not gfile.Exists(vocabulary_file): 1445 raise ValueError('vocabulary_file in {} does not exist.'.format(key)) 1446 1447 if file_format == 'tfrecord_gzip': 1448 ds = readers.TFRecordDataset(vocabulary_file, 'GZIP') 1449 vocabulary_size = ds.reduce(0, lambda x, _: x + 1) 1450 if context.executing_eagerly(): 1451 vocabulary_size = vocabulary_size.numpy() 1452 else: 1453 with gfile.GFile(vocabulary_file, mode='rb') as f: 1454 vocabulary_size = sum(1 for _ in f) 1455 logging.info( 1456 'vocabulary_size = %d in %s is inferred from the number of elements ' 1457 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) 1458 1459 # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. 1460 if not isinstance(vocabulary_size, ops.Tensor) and vocabulary_size < 1: 1461 raise ValueError('Invalid vocabulary_size in {}.'.format(key)) 1462 if num_oov_buckets: 1463 if default_value is not None: 1464 raise ValueError( 1465 'Can\'t specify both num_oov_buckets and default_value in {}.'.format( 1466 key)) 1467 if num_oov_buckets < 0: 1468 raise ValueError('Invalid num_oov_buckets {} in {}.'.format( 1469 num_oov_buckets, key)) 1470 fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) 1471 fc_utils.assert_key_is_string(key) 1472 return VocabularyFileCategoricalColumn( 1473 key=key, 1474 vocabulary_file=vocabulary_file, 1475 vocabulary_size=vocabulary_size, 1476 num_oov_buckets=0 if num_oov_buckets is None else num_oov_buckets, 1477 default_value=-1 if default_value is None else default_value, 1478 dtype=dtype, 1479 file_format=file_format) 1480 1481 1482@tf_export('feature_column.categorical_column_with_vocabulary_list') 1483def categorical_column_with_vocabulary_list(key, 1484 vocabulary_list, 1485 dtype=None, 1486 default_value=-1, 1487 num_oov_buckets=0): 1488 """A `CategoricalColumn` with in-memory vocabulary. 1489 1490 Use this when your inputs are in string or integer format, and you have an 1491 in-memory vocabulary mapping each value to an integer ID. By default, 1492 out-of-vocabulary values are ignored. Use either (but not both) of 1493 `num_oov_buckets` and `default_value` to specify how to include 1494 out-of-vocabulary values. 1495 1496 For input dictionary `features`, `features[key]` is either `Tensor` or 1497 `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int 1498 and `''` for string, which will be dropped by this feature column. 1499 1500 Example with `num_oov_buckets`: 1501 In the following example, each input in `vocabulary_list` is assigned an ID 1502 0-3 corresponding to its index (e.g., input 'B' produces output 2). All other 1503 inputs are hashed and assigned an ID 4-5. 1504 1505 ```python 1506 colors = categorical_column_with_vocabulary_list( 1507 key='colors', vocabulary_list=('R', 'G', 'B', 'Y'), 1508 num_oov_buckets=2) 1509 columns = [colors, ...] 1510 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1511 linear_prediction, _, _ = linear_model(features, columns) 1512 ``` 1513 1514 Example with `default_value`: 1515 In the following example, each input in `vocabulary_list` is assigned an ID 1516 0-4 corresponding to its index (e.g., input 'B' produces output 3). All other 1517 inputs are assigned `default_value` 0. 
1518
1519
1520  ```python
1521  colors = categorical_column_with_vocabulary_list(
1522      key='colors', vocabulary_list=('X', 'R', 'G', 'B', 'Y'), default_value=0)
1523  columns = [colors, ...]
1524  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1525  linear_prediction, _, _ = linear_model(features, columns)
1526  ```
1527
1528  And to make an embedding with either:
1529
1530  ```python
1531  columns = [embedding_column(colors, 3),...]
1532  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1533  dense_tensor = input_layer(features, columns)
1534  ```
1535
1536  Args:
1537    key: A unique string identifying the input feature. It is used as the column
1538      name and the dictionary key for feature parsing configs, feature `Tensor`
1539      objects, and feature columns.
1540    vocabulary_list: An ordered iterable defining the vocabulary. Each feature
1541      is mapped to the index of its value (if present) in `vocabulary_list`.
1542      Must be castable to `dtype`.
1543    dtype: The type of features. Only string and integer types are supported. If
1544      `None`, it will be inferred from `vocabulary_list`.
1545    default_value: The integer ID value to return for out-of-vocabulary feature
1546      values, defaults to `-1`. This cannot be specified with a positive
1547      `num_oov_buckets`.
1548    num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
1549      buckets. All out-of-vocabulary inputs will be assigned IDs in the range
1550      `[len(vocabulary_list), len(vocabulary_list)+num_oov_buckets)` based on a
1551      hash of the input value. A positive `num_oov_buckets` cannot be specified
1552      with `default_value`.
1553
1554  Returns:
1555    A `CategoricalColumn` with in-memory vocabulary.
1556
1557  Raises:
1558    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
1559    ValueError: `num_oov_buckets` is a negative integer.
1560    ValueError: `num_oov_buckets` and `default_value` are both specified.
1561    ValueError: if `dtype` is not integer or string.
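  As a further sketch (feature name and values hypothetical), an integer
  vocabulary works the same way, with `dtype` inferred from the list itself:

  ```python
  sizes = categorical_column_with_vocabulary_list(
      key='sizes', vocabulary_list=[8, 16, 32], num_oov_buckets=1)
  # Input 16 maps to ID 1; an unseen value such as 64 is hashed into the
  # single OOV bucket and gets ID 3.
  ```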
1562  """
1563  if (vocabulary_list is None) or (len(vocabulary_list) < 1):
1564    raise ValueError(
1565        'vocabulary_list {} must be non-empty, column_name: {}'.format(
1566            vocabulary_list, key))
1567  if len(set(vocabulary_list)) != len(vocabulary_list):
1568    raise ValueError(
1569        'Duplicate keys in vocabulary_list {}, column_name: {}'.format(
1570            vocabulary_list, key))
1571  vocabulary_dtype = dtypes.as_dtype(np.array(vocabulary_list).dtype)
1572  if num_oov_buckets:
1573    if default_value != -1:
1574      raise ValueError(
1575          'Can\'t specify both num_oov_buckets and default_value in {}.'.format(
1576              key))
1577    if num_oov_buckets < 0:
1578      raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
1579          num_oov_buckets, key))
1580  fc_utils.assert_string_or_int(
1581      vocabulary_dtype, prefix='column_name: {} vocabulary'.format(key))
1582  if dtype is None:
1583    dtype = vocabulary_dtype
1584  elif dtype.is_integer != vocabulary_dtype.is_integer:
1585    raise ValueError(
1586        'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
1587            dtype, vocabulary_dtype, key))
1588  fc_utils.assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
1589  fc_utils.assert_key_is_string(key)
1590
1591  return VocabularyListCategoricalColumn(
1592      key=key,
1593      vocabulary_list=tuple(vocabulary_list),
1594      dtype=dtype,
1595      default_value=default_value,
1596      num_oov_buckets=num_oov_buckets)
1597
1598
1599 @tf_export('feature_column.categorical_column_with_identity')
1600 def categorical_column_with_identity(key, num_buckets, default_value=None):
1601  """A `CategoricalColumn` that returns identity values.
1602
1603  Use this when your inputs are integers in the range `[0, num_buckets)`, and
1604  you want to use the input value itself as the categorical ID. Values outside
1605  this range will result in `default_value` if specified, otherwise an error
1606  is raised.
1607
1608  Typically, this is used for contiguous ranges of integer indexes, but
1609  it doesn't have to be. This might be inefficient, however, if many IDs
1610  are unused. Consider `categorical_column_with_hash_bucket` in that case.
1611
1612  For input dictionary `features`, `features[key]` is either `Tensor` or
1613  `SparseTensor`. If `Tensor`, missing values can be represented by `-1` for int
1614  and `''` for string, which will be dropped by this feature column.
1615
1616  In the following examples, each input in the range `[0, 1000000)` is assigned
1617  the same value. All other inputs are assigned `default_value` 0. Note that a
1618  literal 0 in inputs will result in the same default ID.
1619
1620  Linear model:
1621
1622  ```python
1623  import tensorflow as tf
1624  video_id = tf.feature_column.categorical_column_with_identity(
1625      key='video_id', num_buckets=1000000, default_value=0)
1626  columns = [video_id]
1627  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1628  [33, 78, 2, 73, 1]])}
1629  linear_prediction = tf.compat.v1.feature_column.linear_model(features,
1630  columns)
1631  ```
1632
1633  Embedding for a DNN model:
1634
1635  ```python
1636  import tensorflow as tf
1637  video_id = tf.feature_column.categorical_column_with_identity(
1638      key='video_id', num_buckets=1000000, default_value=0)
1639  columns = [tf.feature_column.embedding_column(video_id, 9)]
1640  features = {'video_id': tf.sparse.from_dense([[2, 85, 0, 0, 0],
1641  [33, 78, 2, 73, 1]])}
1642  input_layer = tf.keras.layers.DenseFeatures(columns)
1643  dense_tensor = input_layer(features)
1644  ```
1645
1646  Args:
1647    key: A unique string identifying the input feature. It is used as the
1648      column name and the dictionary key for feature parsing configs, feature
1649      `Tensor` objects, and feature columns.
1650    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
1651    default_value: If set, values outside of range `[0, num_buckets)` will
1652      be replaced with this value. If not set, values >= num_buckets will
1653      cause a failure while values < 0 will be dropped.
1654
1655  Returns:
1656    A `CategoricalColumn` that returns identity values.
1657
1658  Raises:
1659    ValueError: if `num_buckets` is less than one.
1660    ValueError: if `default_value` is not in range `[0, num_buckets)`.
1661  """
1662  if num_buckets < 1:
1663    raise ValueError(
1664        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
1665  if (default_value is not None) and (
1666      (default_value < 0) or (default_value >= num_buckets)):
1667    raise ValueError(
1668        'default_value {} not in range [0, {}), column_name {}'.format(
1669            default_value, num_buckets, key))
1670  fc_utils.assert_key_is_string(key)
1671  return IdentityCategoricalColumn(
1672      key=key, number_buckets=num_buckets, default_value=default_value)
1673
1674
1675 @tf_export('feature_column.indicator_column')
1676 def indicator_column(categorical_column):
1677  """Represents the multi-hot encoding of the given categorical column.
1678
1679  - For DNN models, `indicator_column` can be used to wrap any
1680    `categorical_column_*` (e.g., to feed to a DNN). Consider using
1681    `embedding_column` if the number of buckets/unique values is large.
1682
1683  - For wide (aka linear) models, `indicator_column` is the internal
1684    representation for a categorical column when the categorical column is
1685    passed directly (as any element in feature_columns) to `linear_model`. See
1686    `linear_model` for details.
1687
1688  ```python
1689  name = indicator_column(categorical_column_with_vocabulary_list(
1690      'name', ['bob', 'george', 'wanda']))
1691  columns = [name, ...]
1692  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1693  dense_tensor = input_layer(features, columns)
1694
1695  dense_tensor == [[1, 0, 0]]  # If "name" bytes_list is ["bob"]
1696  dense_tensor == [[1, 0, 1]]  # If "name" bytes_list is ["bob", "wanda"]
1697  dense_tensor == [[2, 0, 0]]  # If "name" bytes_list is ["bob", "bob"]
1698  ```
1699
1700  Args:
1701    categorical_column: A `CategoricalColumn` which is created by
1702      `categorical_column_with_*` or `crossed_column` functions.
1703
1704  Returns:
1705    An `IndicatorColumn`.
1706 1707 Raises: 1708 ValueError: If `categorical_column` is not CategoricalColumn type. 1709 """ 1710 if not isinstance(categorical_column, 1711 (CategoricalColumn, fc_old._CategoricalColumn)): # pylint: disable=protected-access 1712 raise ValueError( 1713 'Unsupported input type. Input must be a CategoricalColumn. ' 1714 'Given: {}'.format(categorical_column)) 1715 return IndicatorColumn(categorical_column) 1716 1717 1718@tf_export('feature_column.weighted_categorical_column') 1719def weighted_categorical_column(categorical_column, 1720 weight_feature_key, 1721 dtype=dtypes.float32): 1722 """Applies weight values to a `CategoricalColumn`. 1723 1724 Use this when each of your sparse inputs has both an ID and a value. For 1725 example, if you're representing text documents as a collection of word 1726 frequencies, you can provide 2 parallel sparse input features ('terms' and 1727 'frequencies' below). 1728 1729 Example: 1730 1731 Input `tf.Example` objects: 1732 1733 ```proto 1734 [ 1735 features { 1736 feature { 1737 key: "terms" 1738 value {bytes_list {value: "very" value: "model"}} 1739 } 1740 feature { 1741 key: "frequencies" 1742 value {float_list {value: 0.3 value: 0.1}} 1743 } 1744 }, 1745 features { 1746 feature { 1747 key: "terms" 1748 value {bytes_list {value: "when" value: "course" value: "human"}} 1749 } 1750 feature { 1751 key: "frequencies" 1752 value {float_list {value: 0.4 value: 0.1 value: 0.2}} 1753 } 1754 } 1755 ] 1756 ``` 1757 1758 ```python 1759 categorical_column = categorical_column_with_hash_bucket( 1760 column_name='terms', hash_bucket_size=1000) 1761 weighted_column = weighted_categorical_column( 1762 categorical_column=categorical_column, weight_feature_key='frequencies') 1763 columns = [weighted_column, ...] 1764 features = tf.io.parse_example(..., features=make_parse_example_spec(columns)) 1765 linear_prediction, _, _ = linear_model(features, columns) 1766 ``` 1767 1768 This assumes the input dictionary contains a `SparseTensor` for key 1769 'terms', and a `SparseTensor` for key 'frequencies'. These 2 tensors must have 1770 the same indices and dense shape. 1771 1772 Args: 1773 categorical_column: A `CategoricalColumn` created by 1774 `categorical_column_with_*` functions. 1775 weight_feature_key: String key for weight values. 1776 dtype: Type of weights, such as `tf.float32`. Only float and integer weights 1777 are supported. 1778 1779 Returns: 1780 A `CategoricalColumn` composed of two sparse features: one represents id, 1781 the other represents weight (value) of the id feature in that example. 1782 1783 Raises: 1784 ValueError: if `dtype` is not convertible to float. 1785 """ 1786 if (dtype is None) or not (dtype.is_integer or dtype.is_floating): 1787 raise ValueError('dtype {} is not convertible to float.'.format(dtype)) 1788 return WeightedCategoricalColumn( 1789 categorical_column=categorical_column, 1790 weight_feature_key=weight_feature_key, 1791 dtype=dtype) 1792 1793 1794@tf_export('feature_column.crossed_column') 1795def crossed_column(keys, hash_bucket_size, hash_key=None): 1796 """Returns a column for performing crosses of categorical features. 1797 1798 Crossed features will be hashed according to `hash_bucket_size`. 
Conceptually,
1799  the transformation can be thought of as:
1800    Hash(cartesian product of features) % `hash_bucket_size`
1801
1802  For example, if the input features are:
1803
1804  * SparseTensor referred to by the first key:
1805
1806    ```python
1807    shape = [2, 2]
1808    {
1809        [0, 0]: "a"
1810        [1, 0]: "b"
1811        [1, 1]: "c"
1812    }
1813    ```
1814
1815  * SparseTensor referred to by the second key:
1816
1817    ```python
1818    shape = [2, 1]
1819    {
1820        [0, 0]: "d"
1821        [1, 0]: "e"
1822    }
1823    ```
1824
1825  then the crossed feature will look like:
1826
1827  ```python
1828  shape = [2, 2]
1829  {
1830      [0, 0]: Hash64("d", Hash64("a")) % hash_bucket_size
1831      [1, 0]: Hash64("e", Hash64("b")) % hash_bucket_size
1832      [1, 1]: Hash64("e", Hash64("c")) % hash_bucket_size
1833  }
1834  ```
1835
1836  Here is an example of creating a linear model with crosses of string features:
1837
1838  ```python
1839  keywords_x_doc_terms = crossed_column(['keywords', 'doc_terms'], 50K)
1840  columns = [keywords_x_doc_terms, ...]
1841  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1842  linear_prediction = linear_model(features, columns)
1843  ```
1844
1845  You could also use vocabulary lookup before crossing:
1846
1847  ```python
1848  keywords = categorical_column_with_vocabulary_file(
1849      'keywords', '/path/to/vocabulary/file', vocabulary_size=1K)
1850  keywords_x_doc_terms = crossed_column([keywords, 'doc_terms'], 50K)
1851  columns = [keywords_x_doc_terms, ...]
1852  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1853  linear_prediction = linear_model(features, columns)
1854  ```
1855
1856  If an input feature is of numeric type, you can use
1857  `categorical_column_with_identity` or `bucketized_column`, as in the example:
1858
1859  ```python
1860  # vertical_id is an integer categorical feature.
1861  vertical_id = categorical_column_with_identity('vertical_id', 10K)
1862  price = numeric_column('price')
1863  # bucketized_column converts a numerical feature to a categorical one.
1864  bucketized_price = bucketized_column(price, boundaries=[...])
1865  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
1866  columns = [vertical_id_x_price, ...]
1867  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))
1868  linear_prediction = linear_model(features, columns)
1869  ```
1870
1871  To use a crossed column in a DNN model, you need to wrap it in an embedding
1872  column, as in this example:
1873
1874  ```python
1875  vertical_id_x_price = crossed_column([vertical_id, bucketized_price], 50K)
1876  vertical_id_x_price_embedded = embedding_column(vertical_id_x_price, 10)
1877  dense_tensor = input_layer(features, [vertical_id_x_price_embedded, ...])
1878  ```
1879
1880  Args:
1881    keys: An iterable identifying the features to be crossed. Each element can
1882      be either:
1883      * string: Will use the corresponding feature which must be of string type.
1884      * `CategoricalColumn`: Will use the transformed tensor produced by this
1885        column. Does not support hashed categorical columns.
1886    hash_bucket_size: An int > 1. The number of buckets.
1887    hash_key: Specify the hash_key that will be used by the `FingerprintCat64`
1888      function to combine the cross fingerprints in `SparseCrossOp` (optional).
1889
1890  Returns:
1891    A `CrossedColumn`.
1892
1893  Raises:
1894    ValueError: If `len(keys) < 2`.
1895    ValueError: If any of the keys is neither a string nor `CategoricalColumn`.
1896    ValueError: If any of the keys is `HashedCategoricalColumn`.
1897    ValueError: If `hash_bucket_size < 1`.
1898  """
1899  if not hash_bucket_size or hash_bucket_size < 1:
1900    raise ValueError('hash_bucket_size must be at least 1. '
1901                     'hash_bucket_size: {}'.format(hash_bucket_size))
1902  if not keys or len(keys) < 2:
1903    raise ValueError(
1904        'keys must be a list with length > 1. Given: {}'.format(keys))
1905  for key in keys:
1906    if (not isinstance(key, six.string_types) and
1907        not isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn))):  # pylint: disable=protected-access
1908      raise ValueError(
1909          'Unsupported key type. All keys must be either string, or '
1910          'categorical column except HashedCategoricalColumn. '
1911          'Given: {}'.format(key))
1912    if isinstance(key,
1913                  (HashedCategoricalColumn, fc_old._HashedCategoricalColumn)):  # pylint: disable=protected-access
1914      raise ValueError(
1915          'categorical_column_with_hash_bucket is not supported for crossing. '
1916          'Hashing before crossing will increase probability of collision. '
1917          'Instead, use the feature name as a string. Given: {}'.format(key))
1918  return CrossedColumn(
1919      keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
1920
1921
1922 # TODO(b/181853833): Add a tf.type for instance type checking.
1923 @tf_export('__internal__.feature_column.FeatureColumn', v1=[])
1924 @six.add_metaclass(abc.ABCMeta)
1925 class FeatureColumn(object):
1926  """Represents a feature column abstraction.
1927
1928  WARNING: Do not subclass this layer unless you know what you are doing:
1929  the API is subject to future changes.
1930
1931  To distinguish between the concept of a feature family and a specific binary
1932  feature within a family, we refer to a feature family like "country" as a
1933  feature column. For example, we can have a feature in a `tf.Example` format:
1934    {key: "country", value: [ "US" ]}
1935  In this example the value of the feature is "US" and "country" refers to the
1936  column of the feature.
1937
1938  This class is an abstract class. Users should not create instances of this.
1939  """
1940
1941  @abc.abstractproperty
1942  def name(self):
1943    """Returns string. Used for naming."""
1944    pass
1945
1946  def __lt__(self, other):
1947    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1948
1949    Feature columns need to occasionally be sortable, for example when used as
1950    keys in a features dictionary passed to a layer.
1951
1952    In CPython, `__lt__` must be defined for all objects in the
1953    sequence being sorted.
1954
1955    If any objects in the sequence being sorted do not have an `__lt__` method
1956    compatible with feature column objects (such as strings), then CPython will
1957    fall back to using the `__gt__` method below.
1958    https://docs.python.org/3/library/stdtypes.html#list.sort
1959
1960    Args:
1961      other: The other object to compare to.
1962
1963    Returns:
1964      True if the string representation of this object is lexicographically less
1965      than the string representation of `other`. For FeatureColumn objects,
1966      this looks like "<__main__.FeatureColumn object at 0xa>".
1967    """
1968    return str(self) < str(other)
1969
1970  def __gt__(self, other):
1971    """Allows feature columns to be sorted in Python 3 as they are in Python 2.
1972
1973    Feature columns need to occasionally be sortable, for example when used as
1974    keys in a features dictionary passed to a layer.
1975
1976    `__gt__` is called when the "other" object being compared during the sort
1977    does not have `__lt__` defined.
1978 Example: 1979 ``` 1980 # __lt__ only class 1981 class A(): 1982 def __lt__(self, other): return str(self) < str(other) 1983 1984 a = A() 1985 a < "b" # True 1986 "0" < a # Error 1987 1988 # __lt__ and __gt__ class 1989 class B(): 1990 def __lt__(self, other): return str(self) < str(other) 1991 def __gt__(self, other): return str(self) > str(other) 1992 1993 b = B() 1994 b < "c" # True 1995 "0" < b # True 1996 ``` 1997 1998 Args: 1999 other: The other object to compare to. 2000 2001 Returns: 2002 True if the string representation of this object is lexicographically 2003 greater than the string representation of `other`. For FeatureColumn 2004 objects, this looks like "<__main__.FeatureColumn object at 0xa>". 2005 """ 2006 return str(self) > str(other) 2007 2008 @abc.abstractmethod 2009 def transform_feature(self, transformation_cache, state_manager): 2010 """Returns intermediate representation (usually a `Tensor`). 2011 2012 Uses `transformation_cache` to create an intermediate representation 2013 (usually a `Tensor`) that other feature columns can use. 2014 2015 Example usage of `transformation_cache`: 2016 Let's say a Feature column depends on raw feature ('raw') and another 2017 `FeatureColumn` (input_fc). To access corresponding `Tensor`s, 2018 transformation_cache will be used as follows: 2019 2020 ```python 2021 raw_tensor = transformation_cache.get('raw', state_manager) 2022 fc_tensor = transformation_cache.get(input_fc, state_manager) 2023 ``` 2024 2025 Args: 2026 transformation_cache: A `FeatureTransformationCache` object to access 2027 features. 2028 state_manager: A `StateManager` to create / access resources such as 2029 lookup tables. 2030 2031 Returns: 2032 Transformed feature `Tensor`. 2033 """ 2034 pass 2035 2036 @abc.abstractproperty 2037 def parse_example_spec(self): 2038 """Returns a `tf.Example` parsing spec as dict. 2039 2040 It is used for get_parsing_spec for `tf.io.parse_example`. Returned spec is 2041 a dict from keys ('string') to `VarLenFeature`, `FixedLenFeature`, and other 2042 supported objects. Please check documentation of `tf.io.parse_example` for 2043 all supported spec objects. 2044 2045 Let's say a Feature column depends on raw feature ('raw') and another 2046 `FeatureColumn` (input_fc). One possible implementation of 2047 parse_example_spec is as follows: 2048 2049 ```python 2050 spec = {'raw': tf.io.FixedLenFeature(...)} 2051 spec.update(input_fc.parse_example_spec) 2052 return spec 2053 ``` 2054 """ 2055 pass 2056 2057 def create_state(self, state_manager): 2058 """Uses the `state_manager` to create state for the FeatureColumn. 2059 2060 Args: 2061 state_manager: A `StateManager` to create / access resources such as 2062 lookup tables and variables. 2063 """ 2064 pass 2065 2066 @abc.abstractproperty 2067 def _is_v2_column(self): 2068 """Returns whether this FeatureColumn is fully conformant to the new API. 2069 2070 This is needed for composition type cases where an EmbeddingColumn etc. 2071 might take in old categorical columns as input and then we want to use the 2072 old API. 2073 """ 2074 pass 2075 2076 @abc.abstractproperty 2077 def parents(self): 2078 """Returns a list of immediate raw feature and FeatureColumn dependencies. 2079 2080 For example: 2081 # For the following feature columns 2082 a = numeric_column('f1') 2083 c = crossed_column(a, 'f2') 2084 # The expected parents are: 2085 a.parents = ['f1'] 2086 c.parents = [a, 'f2'] 2087 """ 2088 pass 2089 2090 def get_config(self): 2091 """Returns the config of the feature column. 
2092 2093 A FeatureColumn config is a Python dictionary (serializable) containing the 2094 configuration of a FeatureColumn. The same FeatureColumn can be 2095 reinstantiated later from this configuration. 2096 2097 The config of a feature column does not include information about feature 2098 columns depending on it nor the FeatureColumn class name. 2099 2100 Example with (de)serialization practices followed in this file: 2101 ```python 2102 class SerializationExampleFeatureColumn( 2103 FeatureColumn, collections.namedtuple( 2104 'SerializationExampleFeatureColumn', 2105 ('dimension', 'parent', 'dtype', 'normalizer_fn'))): 2106 2107 def get_config(self): 2108 # Create a dict from the namedtuple. 2109 # Python attribute literals can be directly copied from / to the config. 2110 # For example 'dimension', assuming it is an integer literal. 2111 config = dict(zip(self._fields, self)) 2112 2113 # (De)serialization of parent FeatureColumns should use the provided 2114 # (de)serialize_feature_column() methods that take care of de-duping. 2115 config['parent'] = serialize_feature_column(self.parent) 2116 2117 # Many objects provide custom (de)serialization e.g: for tf.DType 2118 # tf.DType.name, tf.as_dtype() can be used. 2119 config['dtype'] = self.dtype.name 2120 2121 # Non-trivial dependencies should be Keras-(de)serializable. 2122 config['normalizer_fn'] = generic_utils.serialize_keras_object( 2123 self.normalizer_fn) 2124 2125 return config 2126 2127 @classmethod 2128 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2129 # This should do the inverse transform from `get_config` and construct 2130 # the namedtuple. 2131 kwargs = config.copy() 2132 kwargs['parent'] = deserialize_feature_column( 2133 config['parent'], custom_objects, columns_by_name) 2134 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2135 kwargs['normalizer_fn'] = generic_utils.deserialize_keras_object( 2136 config['normalizer_fn'], custom_objects=custom_objects) 2137 return cls(**kwargs) 2138 2139 ``` 2140 Returns: 2141 A serializable Dict that can be used to deserialize the object with 2142 from_config. 2143 """ 2144 return self._get_config() 2145 2146 def _get_config(self): 2147 raise NotImplementedError('Must be implemented in subclasses.') 2148 2149 @classmethod 2150 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2151 """Creates a FeatureColumn from its config. 2152 2153 This method should be the reverse of `get_config`, capable of instantiating 2154 the same FeatureColumn from the config dictionary. See `get_config` for an 2155 example of common (de)serialization practices followed in this file. 2156 2157 TODO(b/118939620): This is a private method until consensus is reached on 2158 supporting object deserialization deduping within Keras. 2159 2160 Args: 2161 config: A Dict config acquired with `get_config`. 2162 custom_objects: Optional dictionary mapping names (strings) to custom 2163 classes or functions to be considered during deserialization. 2164 columns_by_name: A Dict[String, FeatureColumn] of existing columns in 2165 order to avoid duplication. Should be passed to any calls to 2166 deserialize_feature_column(). 2167 2168 Returns: 2169 A FeatureColumn for the input config. 
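    A minimal sketch of the round trip (assuming `col` is a feature column
    built by one of the functions in this module):

    ```python
    config = col.get_config()
    restored = type(col).from_config(config)
    # Feature columns are namedtuple-based, so a faithful round trip yields
    # a column that compares equal to the original.
    assert restored == col
    ```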
2170    """
2171    return cls._from_config(config, custom_objects, columns_by_name)
2172
2173  @classmethod
2174  def _from_config(cls, config, custom_objects=None, columns_by_name=None):
2175    raise NotImplementedError('Must be implemented in subclasses.')
2176
2177
2178 # TODO(b/181853833): Add a tf.type for instance type checking.
2179 @tf_export('__internal__.feature_column.DenseColumn', v1=[])
2180 class DenseColumn(FeatureColumn):
2181  """Represents a column which can be represented as a `Tensor`.
2182
2183  Some examples of this type are: numeric_column, embedding_column,
2184  indicator_column.
2185  """
2186
2187  @abc.abstractproperty
2188  def variable_shape(self):
2189    """`TensorShape` of `get_dense_tensor`, without batch dimension."""
2190    pass
2191
2192  @abc.abstractmethod
2193  def get_dense_tensor(self, transformation_cache, state_manager):
2194    """Returns a `Tensor`.
2195
2196    The output of this function will be used by model-building functions. For
2197    example, the pseudo code of `input_layer` will be like:
2198
2199    ```python
2200    def input_layer(features, feature_columns, ...):
2201      outputs = [fc.get_dense_tensor(...) for fc in feature_columns]
2202      return tf.concat(outputs)
2203    ```
2204
2205    Args:
2206      transformation_cache: A `FeatureTransformationCache` object to access
2207        features.
2208      state_manager: A `StateManager` to create / access resources such as
2209        lookup tables.
2210
2211    Returns:
2212      `Tensor` of shape [batch_size] + `variable_shape`.
2213    """
2214    pass
2215
2216
2217 def is_feature_column_v2(feature_columns):
2218  """Returns True if all feature columns are V2."""
2219  for feature_column in feature_columns:
2220    if not isinstance(feature_column, FeatureColumn):
2221      return False
2222    if not feature_column._is_v2_column:  # pylint: disable=protected-access
2223      return False
2224  return True
2225
2226
2227 def _create_weighted_sum(column, transformation_cache, state_manager,
2228                          sparse_combiner, weight_var):
2229  """Creates a weighted sum for a dense/categorical column for linear_model."""
2230  if isinstance(column, CategoricalColumn):
2231    return _create_categorical_column_weighted_sum(
2232        column=column,
2233        transformation_cache=transformation_cache,
2234        state_manager=state_manager,
2235        sparse_combiner=sparse_combiner,
2236        weight_var=weight_var)
2237  else:
2238    return _create_dense_column_weighted_sum(
2239        column=column,
2240        transformation_cache=transformation_cache,
2241        state_manager=state_manager,
2242        weight_var=weight_var)
2243
2244
2245 def _create_dense_column_weighted_sum(column, transformation_cache,
2246                                       state_manager, weight_var):
2247  """Create a weighted sum of a dense column for linear_model."""
2248  tensor = column.get_dense_tensor(transformation_cache, state_manager)
2249  num_elements = column.variable_shape.num_elements()
2250  batch_size = array_ops.shape(tensor)[0]
2251  tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
2252  return math_ops.matmul(tensor, weight_var, name='weighted_sum')
2253
2254
2255 class CategoricalColumn(FeatureColumn):
2256  """Represents a categorical feature.
2257
2258  A categorical feature is typically handled with a `tf.sparse.SparseTensor` of
2259  IDs.
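  For intuition, here is a sketch (values hypothetical) of a batch of two
  examples whose looked-up IDs are [3] and [1, 4]:

  ```python
  ids = tf.sparse.SparseTensor(
      indices=[[0, 0], [1, 0], [1, 1]],  # (example, position) pairs
      values=[3, 1, 4],                  # integer category IDs
      dense_shape=[2, 2])
  ```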
2260  """
2261
2262  IdWeightPair = collections.namedtuple(  # pylint: disable=invalid-name
2263      'IdWeightPair', ('id_tensor', 'weight_tensor'))
2264
2265  @abc.abstractproperty
2266  def num_buckets(self):
2267    """Returns number of buckets in this sparse feature."""
2268    pass
2269
2270  @abc.abstractmethod
2271  def get_sparse_tensors(self, transformation_cache, state_manager):
2272    """Returns an IdWeightPair.
2273
2274    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
2275    weights.
2276
2277    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
2278    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
2279    `SparseTensor` of `float` or `None` to indicate all weights should be
2280    taken to be 1. If specified, `weight_tensor` must have exactly the same
2281    shape and indices as `id_tensor`. The expected `SparseTensor` is the same
2282    as the parsing output of a `VarLenFeature`, i.e. a ragged matrix.
2283
2284    Args:
2285      transformation_cache: A `FeatureTransformationCache` object to access
2286        features.
2287      state_manager: A `StateManager` to create / access resources such as
2288        lookup tables.
2289    """
2290    pass
2291
2292
2293 def _create_categorical_column_weighted_sum(
2294    column, transformation_cache, state_manager, sparse_combiner, weight_var):
2295  # pylint: disable=g-doc-return-or-yield,g-doc-args
2296  """Create a weighted sum of a categorical column for linear_model.
2297
2298  Note to maintainers: as an implementation detail, the weighted sum is
2299  implemented via embedding_lookup_sparse for efficiency. Mathematically,
2300  the two are the same.
2301
2302  Conceptually, a categorical column can be treated as a multi-hot
2303  vector. Say:
2304
2305  ```python
2306    x = [0 0 1]  # categorical column input
2307    w = [a b c]  # weights
2308  ```
2309  The weighted sum is `c` in this case, which is the same as `w[2]`.
2310
2311  Another example is
2312
2313  ```python
2314    x = [0 1 1]  # categorical column input
2315    w = [a b c]  # weights
2316  ```
2317  The weighted sum is `b + c` in this case, which is the same as `w[1] + w[2]`.
2318
2319  In both cases, we can implement the weighted sum via embedding_lookup with
2320  sparse_combiner = "sum".
2321  """
2322
2323  sparse_tensors = column.get_sparse_tensors(transformation_cache,
2324                                             state_manager)
2325  id_tensor = sparse_ops.sparse_reshape(sparse_tensors.id_tensor, [
2326      array_ops.shape(sparse_tensors.id_tensor)[0], -1
2327  ])
2328  weight_tensor = sparse_tensors.weight_tensor
2329  if weight_tensor is not None:
2330    weight_tensor = sparse_ops.sparse_reshape(
2331        weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
2332
2333  return embedding_ops.safe_embedding_lookup_sparse(
2334      weight_var,
2335      id_tensor,
2336      sparse_weights=weight_tensor,
2337      combiner=sparse_combiner,
2338      name='weighted_sum')
2339
2340
2341 # TODO(b/181853833): Add a tf.type for instance type checking.
2342 @tf_export('__internal__.feature_column.SequenceDenseColumn', v1=[])
2343 class SequenceDenseColumn(FeatureColumn):
2344  """Represents dense sequence data."""
2345
2346  TensorSequenceLengthPair = collections.namedtuple(  # pylint: disable=invalid-name
2347      'TensorSequenceLengthPair', ('dense_tensor', 'sequence_length'))
2348
2349  @abc.abstractmethod
2350  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
2351    """Returns a `TensorSequenceLengthPair`.
2352
2353    Args:
2354      transformation_cache: A `FeatureTransformationCache` object to access
2355        features.
2356      state_manager: A `StateManager` to create / access resources such as
2357        lookup tables.
2358    """
2359    pass
2360
2361
2362 @tf_export('__internal__.feature_column.FeatureTransformationCache', v1=[])
2363 class FeatureTransformationCache(object):
2364  """Handles caching of transformations while building the model.
2365
2366  `FeatureColumn` specifies how to digest an input column to the network. Some
2367  feature columns require data transformations. This class caches those
2368  transformations.
2369
2370  Some features may be used in more than one place. For example, one can use a
2371  bucketized feature by itself and also in a cross with it. In that case we
2372  should create only one bucketization op instead of creating ops for each
2373  feature column separately. To handle re-use of transformed columns,
2374  `FeatureTransformationCache` caches all previously transformed columns.
2375
2376  Example:
2377  We're trying to use the following `FeatureColumn`s:
2378
2379  ```python
2380  bucketized_age = fc.bucketized_column(fc.numeric_column("age"), ...)
2381  keywords = fc.categorical_column_with_hash_bucket("keywords", ...)
2382  age_X_keywords = fc.crossed_column([bucketized_age, "keywords"])
2383  ... = linear_model(features,
2384                     [bucketized_age, keywords, age_X_keywords])
2385  ```
2386
2387  If we transform each column independently, then we'll get duplication of
2388  bucketization (one for the cross, one for bucketization itself).
2389  The `FeatureTransformationCache` eliminates this duplication.
2390  """
2391
2392  def __init__(self, features):
2393    """Creates a `FeatureTransformationCache`.
2394
2395    Args:
2396      features: A mapping from feature column to objects that are `Tensor` or
2397        `SparseTensor`, or can be converted to same via
2398        `sparse_tensor.convert_to_tensor_or_sparse_tensor`. A `string` key
2399        signifies a base feature (not-transformed). A `FeatureColumn` key
2400        means that this `Tensor` is the output of an existing `FeatureColumn`
2401        which can be reused.
2402    """
2403    self._features = features.copy()
2404    self._feature_tensors = {}
2405
2406  def get(self, key, state_manager, training=None):
2407    """Returns a `Tensor` for the given key.
2408
2409    A `str` key is used to access a base feature (not-transformed). When a
2410    `FeatureColumn` is passed, the transformed feature is returned if it
2411    already exists, otherwise the given `FeatureColumn` is asked to provide its
2412    transformed output, which is then cached.
2413
2414    Args:
2415      key: a `str` or a `FeatureColumn`.
2416      state_manager: A StateManager object that holds the FeatureColumn state.
2417      training: Boolean indicating whether the column is being used in
2418        training mode. This argument is passed to the transform_feature method
2419        of any `FeatureColumn` that takes a `training` argument. For example, if
2420        a `FeatureColumn` performed dropout, it could expose a `training`
2421        argument to control whether the dropout should be applied.
2422
2423    Returns:
2424      The transformed `Tensor` corresponding to the `key`.
2425
2426    Raises:
2427      ValueError: if key is not found or a transformed `Tensor` cannot be
2428        computed.
2429    """
2430    if key in self._feature_tensors:
2431      # FeatureColumn is already transformed or converted.
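      # For example, if an embedding column and an indicator column wrap the
      # same categorical column, the second lookup is served from this cache
      # instead of re-running the vocabulary lookup.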
2432      return self._feature_tensors[key]
2433
2434    if key in self._features:
2435      feature_tensor = self._get_raw_feature_as_tensor(key)
2436      self._feature_tensors[key] = feature_tensor
2437      return feature_tensor
2438
2439    if isinstance(key, six.string_types):
2440      raise ValueError('Feature {} is not in features dictionary.'.format(key))
2441
2442    if not isinstance(key, FeatureColumn):
2443      raise TypeError('"key" must be either a "str" or "FeatureColumn". '
2444                      'Provided: {}'.format(key))
2445
2446    column = key
2447    logging.debug('Transforming feature_column %s.', column)
2448
2449    # Some columns may need information about whether the transformation is
2450    # happening in training or prediction mode, but not all columns expose this
2451    # argument.
2452    try:
2453      transformed = column.transform_feature(
2454          self, state_manager, training=training)
2455    except TypeError:
2456      transformed = column.transform_feature(self, state_manager)
2457    if transformed is None:
2458      raise ValueError('Column {} is not supported.'.format(column.name))
2459    self._feature_tensors[column] = transformed
2460    return transformed
2461
2462  def _get_raw_feature_as_tensor(self, key):
2463    """Gets the raw feature (keyed by `key`) as a `Tensor`.
2464
2465    The raw feature is converted to a (sparse) tensor and, if needed, its rank
2466    is expanded.
2467    For both `Tensor` and `SparseTensor`, the rank will be expanded (to 2) if
2468    the rank is 1. Dynamic rank is supported as well. A rank-0 raw feature
2469    raises an error, as it is not supported.
2470
2471    Args:
2472      key: A `str` key to access the raw feature.
2473
2474    Returns:
2475      A `Tensor` or `SparseTensor`.
2476
2477    Raises:
2478      ValueError: if the raw feature has rank 0.
2479    """
2480    raw_feature = self._features[key]
2481    feature_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2482        raw_feature)
2483
2484    def expand_dims(input_tensor):
2485      # Input_tensor must have rank 1.
2486      if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2487        return sparse_ops.sparse_reshape(
2488            input_tensor, [array_ops.shape(input_tensor)[0], 1])
2489      else:
2490        return array_ops.expand_dims(input_tensor, -1)
2491
2492    rank = feature_tensor.get_shape().ndims
2493    if rank is not None:
2494      if rank == 0:
2495        raise ValueError(
2496            'Feature (key: {}) cannot have rank 0. Given: {}'.format(
2497                key, feature_tensor))
2498      return feature_tensor if rank != 1 else expand_dims(feature_tensor)
2499
2500    # Handle dynamic rank.
2501    with ops.control_dependencies([
2502        check_ops.assert_positive(
2503            array_ops.rank(feature_tensor),
2504            message='Feature (key: {}) cannot have rank 0. Given: {}'.format(
2505                key, feature_tensor))]):
2506      return control_flow_ops.cond(
2507          math_ops.equal(1, array_ops.rank(feature_tensor)),
2508          lambda: expand_dims(feature_tensor),
2509          lambda: feature_tensor)
2510
2511
2512 # TODO(ptucker): Move to third_party/tensorflow/python/ops/sparse_ops.py
2513 def _to_sparse_input_and_drop_ignore_values(input_tensor, ignore_value=None):
2514  """Converts a `Tensor` to a `SparseTensor`, dropping ignore_value cells.
2515
2516  If `input_tensor` is already a `SparseTensor`, just return it.
2517
2518  Args:
2519    input_tensor: A string or integer `Tensor`.
2520    ignore_value: Entries in `input_tensor` equal to this value will be
2521      absent from the resulting `SparseTensor`. If `None`, the default value of
2522      `input_tensor`'s dtype will be used ('' for `str`, -1 for `int`).
2523
2524  Returns:
2525    A `SparseTensor` with the same shape as `input_tensor`.
2526
2527  Raises:
2528    ValueError: when `input_tensor`'s rank is `None`.
2529  """
2530  input_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
2531      input_tensor)
2532  if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
2533    return input_tensor
2534  with ops.name_scope(None, 'to_sparse_input', (input_tensor, ignore_value,)):
2535    if ignore_value is None:
2536      if input_tensor.dtype == dtypes.string:
2537        # TF strings are converted to numpy objects by default, so use ''.
2538        ignore_value = ''
2539      elif input_tensor.dtype.is_integer:
2540        ignore_value = -1  # -1 has a special meaning of missing feature
2541      else:
2542        # NOTE: `as_numpy_dtype` is a property, so with the parentheses this is
2543        # constructing a new numpy object of the given type, which yields the
2544        # default value for that type.
2545        ignore_value = input_tensor.dtype.as_numpy_dtype()
2546    ignore_value = math_ops.cast(
2547        ignore_value, input_tensor.dtype, name='ignore_value')
2548    indices = array_ops.where_v2(
2549        math_ops.not_equal(input_tensor, ignore_value), name='indices')
2550    return sparse_tensor_lib.SparseTensor(
2551        indices=indices,
2552        values=array_ops.gather_nd(input_tensor, indices, name='values'),
2553        dense_shape=array_ops.shape(
2554            input_tensor, out_type=dtypes.int64, name='dense_shape'))
2555
2556
2557 def _normalize_feature_columns(feature_columns):
2558  """Normalizes the `feature_columns` input.
2559
2560  This method converts `feature_columns` to a list as best it can. In
2561  addition, it verifies the type and other properties of `feature_columns`
2562  required by downstream libraries.
2563
2564  Args:
2565    feature_columns: The raw feature columns, usually passed by users.
2566
2567  Returns:
2568    The normalized feature column list.
2569
2570  Raises:
2571    ValueError: for any invalid inputs, such as empty, duplicated names, etc.
2572  """
2573  if isinstance(feature_columns, FeatureColumn):
2574    feature_columns = [feature_columns]
2575
2576  if isinstance(feature_columns, collections_abc.Iterator):
2577    feature_columns = list(feature_columns)
2578
2579  if isinstance(feature_columns, dict):
2580    raise ValueError('Expected feature_columns to be iterable, found dict.')
2581
2582  for column in feature_columns:
2583    if not isinstance(column, FeatureColumn):
2584      raise ValueError('Items of feature_columns must be a FeatureColumn. '
2585                       'Given (type {}): {}.'.format(type(column), column))
2586  if not feature_columns:
2587    raise ValueError('feature_columns must not be empty.')
2588  name_to_column = {}
2589  for column in feature_columns:
2590    if column.name in name_to_column:
2591      raise ValueError('Duplicate feature column name found for columns: {} '
2592                       'and {}. This usually means that these columns refer to '
2593                       'the same base feature. 
Either one must be discarded or a ' 2594 'duplicated but renamed item must be inserted in ' 2595 'features dict.'.format(column, 2596 name_to_column[column.name])) 2597 name_to_column[column.name] = column 2598 2599 return sorted(feature_columns, key=lambda x: x.name) 2600 2601 2602class NumericColumn( 2603 DenseColumn, 2604 fc_old._DenseColumn, # pylint: disable=protected-access 2605 collections.namedtuple( 2606 'NumericColumn', 2607 ('key', 'shape', 'default_value', 'dtype', 'normalizer_fn'))): 2608 """see `numeric_column`.""" 2609 2610 @property 2611 def _is_v2_column(self): 2612 return True 2613 2614 @property 2615 def name(self): 2616 """See `FeatureColumn` base class.""" 2617 return self.key 2618 2619 @property 2620 def parse_example_spec(self): 2621 """See `FeatureColumn` base class.""" 2622 return { 2623 self.key: 2624 parsing_ops.FixedLenFeature(self.shape, self.dtype, 2625 self.default_value) 2626 } 2627 2628 @property 2629 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2630 _FEATURE_COLUMN_DEPRECATION) 2631 def _parse_example_spec(self): 2632 return self.parse_example_spec 2633 2634 def _transform_input_tensor(self, input_tensor): 2635 if isinstance(input_tensor, sparse_tensor_lib.SparseTensor): 2636 raise ValueError( 2637 'The corresponding Tensor of numerical column must be a Tensor. ' 2638 'SparseTensor is not supported. key: {}'.format(self.key)) 2639 if self.normalizer_fn is not None: 2640 input_tensor = self.normalizer_fn(input_tensor) 2641 return math_ops.cast(input_tensor, dtypes.float32) 2642 2643 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2644 _FEATURE_COLUMN_DEPRECATION) 2645 def _transform_feature(self, inputs): 2646 input_tensor = inputs.get(self.key) 2647 return self._transform_input_tensor(input_tensor) 2648 2649 def transform_feature(self, transformation_cache, state_manager): 2650 """See `FeatureColumn` base class. 2651 2652 In this case, we apply the `normalizer_fn` to the input tensor. 2653 2654 Args: 2655 transformation_cache: A `FeatureTransformationCache` object to access 2656 features. 2657 state_manager: A `StateManager` to create / access resources such as 2658 lookup tables. 2659 2660 Returns: 2661 Normalized input tensor. 2662 Raises: 2663 ValueError: If a SparseTensor is passed in. 2664 """ 2665 input_tensor = transformation_cache.get(self.key, state_manager) 2666 return self._transform_input_tensor(input_tensor) 2667 2668 @property 2669 def variable_shape(self): 2670 """See `DenseColumn` base class.""" 2671 return tensor_shape.TensorShape(self.shape) 2672 2673 @property 2674 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2675 _FEATURE_COLUMN_DEPRECATION) 2676 def _variable_shape(self): 2677 return self.variable_shape 2678 2679 def get_dense_tensor(self, transformation_cache, state_manager): 2680 """Returns dense `Tensor` representing numeric feature. 2681 2682 Args: 2683 transformation_cache: A `FeatureTransformationCache` object to access 2684 features. 2685 state_manager: A `StateManager` to create / access resources such as 2686 lookup tables. 2687 2688 Returns: 2689 Dense `Tensor` created within `transform_feature`. 2690 """ 2691 # Feature has been already transformed. Return the intermediate 2692 # representation created by _transform_feature. 
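    # For NumericColumn, that cached value is the output of transform_feature:
    # the raw input after the optional normalizer_fn and a cast to float32.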
2693 return transformation_cache.get(self, state_manager) 2694 2695 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2696 _FEATURE_COLUMN_DEPRECATION) 2697 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2698 del weight_collections 2699 del trainable 2700 return inputs.get(self) 2701 2702 @property 2703 def parents(self): 2704 """See 'FeatureColumn` base class.""" 2705 return [self.key] 2706 2707 def get_config(self): 2708 """See 'FeatureColumn` base class.""" 2709 config = dict(zip(self._fields, self)) 2710 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 2711 config['normalizer_fn'] = serialization._serialize_keras_object( # pylint: disable=protected-access 2712 self.normalizer_fn) 2713 config['dtype'] = self.dtype.name 2714 return config 2715 2716 @classmethod 2717 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2718 """See 'FeatureColumn` base class.""" 2719 _check_config_keys(config, cls._fields) 2720 from tensorflow.python.feature_column import serialization # pylint: disable=g-import-not-at-top 2721 kwargs = _standardize_and_copy_config(config) 2722 kwargs['normalizer_fn'] = serialization._deserialize_keras_object( # pylint: disable=protected-access 2723 config['normalizer_fn'], custom_objects=custom_objects) 2724 kwargs['dtype'] = dtypes.as_dtype(config['dtype']) 2725 2726 return cls(**kwargs) 2727 2728 2729class BucketizedColumn( 2730 DenseColumn, 2731 CategoricalColumn, 2732 fc_old._DenseColumn, # pylint: disable=protected-access 2733 fc_old._CategoricalColumn, # pylint: disable=protected-access 2734 collections.namedtuple('BucketizedColumn', 2735 ('source_column', 'boundaries'))): 2736 """See `bucketized_column`.""" 2737 2738 @property 2739 def _is_v2_column(self): 2740 return (isinstance(self.source_column, FeatureColumn) and 2741 self.source_column._is_v2_column) # pylint: disable=protected-access 2742 2743 @property 2744 def name(self): 2745 """See `FeatureColumn` base class.""" 2746 return '{}_bucketized'.format(self.source_column.name) 2747 2748 @property 2749 def parse_example_spec(self): 2750 """See `FeatureColumn` base class.""" 2751 return self.source_column.parse_example_spec 2752 2753 @property 2754 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2755 _FEATURE_COLUMN_DEPRECATION) 2756 def _parse_example_spec(self): 2757 return self.source_column._parse_example_spec # pylint: disable=protected-access 2758 2759 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2760 _FEATURE_COLUMN_DEPRECATION) 2761 def _transform_feature(self, inputs): 2762 """Returns bucketized categorical `source_column` tensor.""" 2763 source_tensor = inputs.get(self.source_column) 2764 return math_ops._bucketize( # pylint: disable=protected-access 2765 source_tensor, 2766 boundaries=self.boundaries) 2767 2768 def transform_feature(self, transformation_cache, state_manager): 2769 """Returns bucketized categorical `source_column` tensor.""" 2770 source_tensor = transformation_cache.get(self.source_column, state_manager) 2771 return math_ops._bucketize( # pylint: disable=protected-access 2772 source_tensor, 2773 boundaries=self.boundaries) 2774 2775 @property 2776 def variable_shape(self): 2777 """See `DenseColumn` base class.""" 2778 return tensor_shape.TensorShape( 2779 tuple(self.source_column.shape) + (len(self.boundaries) + 1,)) 2780 2781 @property 2782 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2783 _FEATURE_COLUMN_DEPRECATION) 2784 def 
_variable_shape(self): 2785 return self.variable_shape 2786 2787 def _get_dense_tensor_for_input_tensor(self, input_tensor): 2788 return array_ops.one_hot( 2789 indices=math_ops.cast(input_tensor, dtypes.int64), 2790 depth=len(self.boundaries) + 1, 2791 on_value=1., 2792 off_value=0.) 2793 2794 def get_dense_tensor(self, transformation_cache, state_manager): 2795 """Returns one hot encoded dense `Tensor`.""" 2796 input_tensor = transformation_cache.get(self, state_manager) 2797 return self._get_dense_tensor_for_input_tensor(input_tensor) 2798 2799 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2800 _FEATURE_COLUMN_DEPRECATION) 2801 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 2802 del weight_collections 2803 del trainable 2804 input_tensor = inputs.get(self) 2805 return self._get_dense_tensor_for_input_tensor(input_tensor) 2806 2807 @property 2808 def num_buckets(self): 2809 """See `CategoricalColumn` base class.""" 2810 # By construction, source_column is always one-dimensional. 2811 return (len(self.boundaries) + 1) * self.source_column.shape[0] 2812 2813 @property 2814 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2815 _FEATURE_COLUMN_DEPRECATION) 2816 def _num_buckets(self): 2817 return self.num_buckets 2818 2819 def _get_sparse_tensors_for_input_tensor(self, input_tensor): 2820 batch_size = array_ops.shape(input_tensor)[0] 2821 # By construction, source_column is always one-dimensional. 2822 source_dimension = self.source_column.shape[0] 2823 2824 i1 = array_ops.reshape( 2825 array_ops.tile( 2826 array_ops.expand_dims(math_ops.range(0, batch_size), 1), 2827 [1, source_dimension]), 2828 (-1,)) 2829 i2 = array_ops.tile(math_ops.range(0, source_dimension), [batch_size]) 2830 # Flatten the bucket indices and unique them across dimensions 2831 # E.g. 
2nd dimension indices will range from k to 2*k-1 with k buckets 2832 bucket_indices = ( 2833 array_ops.reshape(input_tensor, (-1,)) + 2834 (len(self.boundaries) + 1) * i2) 2835 2836 indices = math_ops.cast( 2837 array_ops.transpose(array_ops.stack((i1, i2))), dtypes.int64) 2838 dense_shape = math_ops.cast( 2839 array_ops.stack([batch_size, source_dimension]), dtypes.int64) 2840 sparse_tensor = sparse_tensor_lib.SparseTensor( 2841 indices=indices, 2842 values=bucket_indices, 2843 dense_shape=dense_shape) 2844 return CategoricalColumn.IdWeightPair(sparse_tensor, None) 2845 2846 def get_sparse_tensors(self, transformation_cache, state_manager): 2847 """Converts dense inputs to SparseTensor so downstream code can use it.""" 2848 input_tensor = transformation_cache.get(self, state_manager) 2849 return self._get_sparse_tensors_for_input_tensor(input_tensor) 2850 2851 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 2852 _FEATURE_COLUMN_DEPRECATION) 2853 def _get_sparse_tensors(self, inputs, weight_collections=None, 2854 trainable=None): 2855 """Converts dense inputs to SparseTensor so downstream code can use it.""" 2856 del weight_collections 2857 del trainable 2858 input_tensor = inputs.get(self) 2859 return self._get_sparse_tensors_for_input_tensor(input_tensor) 2860 2861 @property 2862 def parents(self): 2863 """See 'FeatureColumn` base class.""" 2864 return [self.source_column] 2865 2866 def get_config(self): 2867 """See 'FeatureColumn` base class.""" 2868 from tensorflow.python.feature_column.serialization import serialize_feature_column # pylint: disable=g-import-not-at-top 2869 config = dict(zip(self._fields, self)) 2870 config['source_column'] = serialize_feature_column(self.source_column) 2871 return config 2872 2873 @classmethod 2874 def from_config(cls, config, custom_objects=None, columns_by_name=None): 2875 """See 'FeatureColumn` base class.""" 2876 from tensorflow.python.feature_column.serialization import deserialize_feature_column # pylint: disable=g-import-not-at-top 2877 _check_config_keys(config, cls._fields) 2878 kwargs = _standardize_and_copy_config(config) 2879 kwargs['source_column'] = deserialize_feature_column( 2880 config['source_column'], custom_objects, columns_by_name) 2881 return cls(**kwargs) 2882 2883 2884class EmbeddingColumn( 2885 DenseColumn, 2886 SequenceDenseColumn, 2887 fc_old._DenseColumn, # pylint: disable=protected-access 2888 fc_old._SequenceDenseColumn, # pylint: disable=protected-access 2889 collections.namedtuple( 2890 'EmbeddingColumn', 2891 ('categorical_column', 'dimension', 'combiner', 'initializer', 2892 'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable', 2893 'use_safe_embedding_lookup'))): 2894 """See `embedding_column`.""" 2895 2896 def __new__(cls, 2897 categorical_column, 2898 dimension, 2899 combiner, 2900 initializer, 2901 ckpt_to_load_from, 2902 tensor_name_in_ckpt, 2903 max_norm, 2904 trainable, 2905 use_safe_embedding_lookup=True): 2906 return super(EmbeddingColumn, cls).__new__( 2907 cls, 2908 categorical_column=categorical_column, 2909 dimension=dimension, 2910 combiner=combiner, 2911 initializer=initializer, 2912 ckpt_to_load_from=ckpt_to_load_from, 2913 tensor_name_in_ckpt=tensor_name_in_ckpt, 2914 max_norm=max_norm, 2915 trainable=trainable, 2916 use_safe_embedding_lookup=use_safe_embedding_lookup) 2917 2918 @property 2919 def _is_v2_column(self): 2920 return (isinstance(self.categorical_column, FeatureColumn) and 2921 self.categorical_column._is_v2_column) # pylint: disable=protected-access 2922 2923 


class EmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'EmbeddingColumn',
        ('categorical_column', 'dimension', 'combiner', 'initializer',
         'ckpt_to_load_from', 'tensor_name_in_ckpt', 'max_norm', 'trainable',
         'use_safe_embedding_lookup'))):
  """See `embedding_column`."""

  def __new__(cls,
              categorical_column,
              dimension,
              combiner,
              initializer,
              ckpt_to_load_from,
              tensor_name_in_ckpt,
              max_norm,
              trainable,
              use_safe_embedding_lookup=True):
    return super(EmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        dimension=dimension,
        combiner=combiner,
        initializer=initializer,
        ckpt_to_load_from=ckpt_to_load_from,
        tensor_name_in_ckpt=tensor_name_in_ckpt,
        max_norm=max_norm,
        trainable=trainable,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """Transforms underlying `categorical_column`."""
    return transformation_cache.get(self.categorical_column, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return inputs.get(self.categorical_column)

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape([self.dimension])

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _variable_shape(self):
    return self.variable_shape

  def create_state(self, state_manager):
    """Creates the embedding lookup variable."""
    default_num_buckets = (self.categorical_column.num_buckets
                           if self._is_v2_column
                           else self.categorical_column._num_buckets)  # pylint: disable=protected-access
    num_buckets = getattr(self.categorical_column, 'num_buckets',
                          default_num_buckets)
    embedding_shape = (num_buckets, self.dimension)
    state_manager.create_variable(
        self,
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        trainable=self.trainable,
        use_resource=True,
        initializer=self.initializer)

  def _get_dense_tensor_internal_helper(self, sparse_tensors,
                                        embedding_weights):
    sparse_ids = sparse_tensors.id_tensor
    sparse_weights = sparse_tensors.weight_tensor

    if self.ckpt_to_load_from is not None:
      to_restore = embedding_weights
      if isinstance(to_restore, variables.PartitionedVariable):
        to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
      checkpoint_utils.init_from_checkpoint(self.ckpt_to_load_from, {
          self.tensor_name_in_ckpt: to_restore
      })

    sparse_id_rank = tensor_shape.dimension_value(
        sparse_ids.dense_shape.get_shape()[0])
    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
    if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
        sparse_id_rank <= 2):
      embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
    # Return embedding lookup result.
    return embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=self.combiner,
        name='%s_weights' % self.name,
        max_norm=self.max_norm)

  def _get_dense_tensor_internal(self, sparse_tensors, state_manager):
    """Private method that follows the signature of get_dense_tensor."""
    embedding_weights = state_manager.get_variable(
        self, name='embedding_weights')
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def _old_get_dense_tensor_internal(self, sparse_tensors, weight_collections,
                                     trainable):
    """Private method that follows the signature of _get_dense_tensor."""
    embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
    if (weight_collections and
        ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections):
      weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
    embedding_weights = variable_scope.get_variable(
        name='embedding_weights',
        shape=embedding_shape,
        dtype=dtypes.float32,
        initializer=self.initializer,
        trainable=self.trainable and trainable,
        collections=weight_collections)
    return self._get_dense_tensor_internal_helper(sparse_tensors,
                                                  embedding_weights)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns tensor after doing the embedding lookup.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      Embedding lookup tensor.

    Raises:
      ValueError: `categorical_column` is SequenceCategoricalColumn.
    """
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    # Get sparse IDs and weights.
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_dense_tensor_internal(sparse_tensors, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    if isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type _SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
        inputs, weight_collections, trainable)
    return self._old_get_dense_tensor_internal(sparse_tensors,
                                               weight_collections, trainable)

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
                                                   state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    dense_tensor = self._old_get_dense_tensor_internal(
        sparse_tensors,
        weight_collections=weight_collections,
        trainable=trainable)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialization.serialize_feature_column(
        self.categorical_column)
    config['initializer'] = serialization._serialize_keras_object(  # pylint: disable=protected-access
        self.initializer)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    if 'use_safe_embedding_lookup' not in config:
      config['use_safe_embedding_lookup'] = True
    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = serialization.deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    all_initializers = dict(tf_inspect.getmembers(init_ops, tf_inspect.isclass))
    kwargs['initializer'] = serialization._deserialize_keras_object(  # pylint: disable=protected-access
        config['initializer'],
        module_objects=all_initializers,
        custom_objects=custom_objects)
    return cls(**kwargs)
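

# Illustrative sketch (not part of this module): what `EmbeddingColumn`
# builds. The feature name and sizes are hypothetical:
#
#   video_id = categorical_column_with_identity('video_id', num_buckets=1000)
#   video_emb = embedding_column(video_id, dimension=8, combiner='mean')
#
#   # create_state() creates an 'embedding_weights' variable of shape
#   # (num_buckets, dimension) == (1000, 8); get_dense_tensor() then returns
#   # a float32 Tensor of shape (batch_size, 8), combining the embeddings of
#   # all ids present in an example with the 'mean' combiner.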


def _raise_shared_embedding_column_error():
  raise ValueError('SharedEmbeddingColumns are not supported in '
                   '`linear_model` or `input_layer`. Please use '
                   '`DenseFeatures` or `LinearModel` instead.')


class SharedEmbeddingColumnCreator(autotrackable.AutoTrackable):
  """Class that creates a `SharedEmbeddingColumn`."""

  def __init__(self,
               dimension,
               initializer,
               ckpt_to_load_from,
               tensor_name_in_ckpt,
               num_buckets,
               trainable,
               name='shared_embedding_column_creator',
               use_safe_embedding_lookup=True):
    self._dimension = dimension
    self._initializer = initializer
    self._ckpt_to_load_from = ckpt_to_load_from
    self._tensor_name_in_ckpt = tensor_name_in_ckpt
    self._num_buckets = num_buckets
    self._trainable = trainable
    self._name = name
    self._use_safe_embedding_lookup = use_safe_embedding_lookup
    # Map from graph keys to embedding_weight variables.
    self._embedding_weights = {}

  def __call__(self, categorical_column, combiner, max_norm):
    return SharedEmbeddingColumn(categorical_column, self, combiner, max_norm,
                                 self._use_safe_embedding_lookup)

  @property
  def embedding_weights(self):
    key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    if key not in self._embedding_weights:
      embedding_shape = (self._num_buckets, self._dimension)
      var = variable_scope.get_variable(
          name=self._name,
          shape=embedding_shape,
          dtype=dtypes.float32,
          initializer=self._initializer,
          trainable=self._trainable)

      if self._ckpt_to_load_from is not None:
        to_restore = var
        if isinstance(to_restore, variables.PartitionedVariable):
          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
        checkpoint_utils.init_from_checkpoint(
            self._ckpt_to_load_from, {self._tensor_name_in_ckpt: to_restore})
      self._embedding_weights[key] = var
    return self._embedding_weights[key]

  @property
  def dimension(self):
    return self._dimension


class SharedEmbeddingColumn(
    DenseColumn,
    SequenceDenseColumn,
    fc_old._DenseColumn,  # pylint: disable=protected-access
    fc_old._SequenceDenseColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'SharedEmbeddingColumn',
        ('categorical_column', 'shared_embedding_column_creator', 'combiner',
         'max_norm', 'use_safe_embedding_lookup'))):
  """See `embedding_column`."""

  def __new__(cls,
              categorical_column,
              shared_embedding_column_creator,
              combiner,
              max_norm,
              use_safe_embedding_lookup=True):
    return super(SharedEmbeddingColumn, cls).__new__(
        cls,
        categorical_column=categorical_column,
        shared_embedding_column_creator=shared_embedding_column_creator,
        combiner=combiner,
        max_norm=max_norm,
        use_safe_embedding_lookup=use_safe_embedding_lookup)

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_shared_embedding'.format(self.categorical_column.name)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  def _parse_example_spec(self):
    return _raise_shared_embedding_column_error()

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return transformation_cache.get(self.categorical_column, state_manager)

  def _transform_feature(self, inputs):
    return _raise_shared_embedding_column_error()

  @property
  def variable_shape(self):
    """See `DenseColumn` base class."""
    return tensor_shape.TensorShape(
        [self.shared_embedding_column_creator.dimension])

  @property
  def _variable_shape(self):
    return _raise_shared_embedding_column_error()

  def _get_dense_tensor_internal(self, transformation_cache, state_manager):
    """Private method that follows the signature of _get_dense_tensor."""
    # This method is called from a variable_scope with name _var_scope_name,
    # which is shared among all shared embeddings. Open a name_scope here, so
    # that the ops for different columns have distinct names.
    with ops.name_scope(None, default_name=self.name):
      # Get sparse IDs and weights.
      sparse_tensors = self.categorical_column.get_sparse_tensors(
          transformation_cache, state_manager)
      sparse_ids = sparse_tensors.id_tensor
      sparse_weights = sparse_tensors.weight_tensor

      embedding_weights = self.shared_embedding_column_creator.embedding_weights

      sparse_id_rank = tensor_shape.dimension_value(
          sparse_ids.dense_shape.get_shape()[0])
      embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
      if (not self.use_safe_embedding_lookup and sparse_id_rank is not None and
          sparse_id_rank <= 2):
        embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
      # Return embedding lookup result.
      return embedding_lookup_sparse(
          embedding_weights,
          sparse_ids,
          sparse_weights,
          combiner=self.combiner,
          name='%s_weights' % self.name,
          max_norm=self.max_norm)

  def get_dense_tensor(self, transformation_cache, state_manager):
    """Returns the embedding lookup result."""
    if isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must not be of type SequenceCategoricalColumn. '
          'Suggested fix A: If you wish to use DenseFeatures, use a '
          'non-sequence categorical_column_with_*. '
          'Suggested fix B: If you wish to create sequence input, use '
          'SequenceFeatures instead of DenseFeatures. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    return self._get_dense_tensor_internal(transformation_cache, state_manager)

  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
    return _raise_shared_embedding_column_error()

  def get_sequence_dense_tensor(self, transformation_cache, state_manager):
    """See `SequenceDenseColumn` base class."""
    if not isinstance(self.categorical_column, SequenceCategoricalColumn):
      raise ValueError(
          'In embedding_column: {}. '
          'categorical_column must be of type SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name,
                                       type(self.categorical_column),
                                       self.categorical_column))
    dense_tensor = self._get_dense_tensor_internal(transformation_cache,
                                                   state_manager)
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    return _raise_shared_embedding_column_error()

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]
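

# Illustrative sketch (not part of this module): one creator shared by two
# columns, so both lookups read the same embedding variable. The categorical
# columns below are hypothetical, and the creator is normally constructed by
# the shared-embedding helper functions rather than instantiated directly:
#
#   creator = SharedEmbeddingColumnCreator(
#       dimension=16, initializer=None, ckpt_to_load_from=None,
#       tensor_name_in_ckpt=None, num_buckets=10000, trainable=True)
#   query_emb = creator(query_id_column, combiner='mean', max_norm=None)
#   impression_emb = creator(impression_id_column, combiner='mean',
#                            max_norm=None)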


def _check_shape(shape, key):
  """Returns shape if it's valid, raises error otherwise."""
  assert shape is not None
  if not nest.is_nested(shape):
    shape = [shape]
  shape = tuple(shape)
  for dimension in shape:
    if not isinstance(dimension, int):
      raise TypeError('shape dimensions must be integer. '
                      'shape: {}, key: {}'.format(shape, key))
    if dimension < 1:
      raise ValueError('shape dimensions must be greater than 0. '
                       'shape: {}, key: {}'.format(shape, key))
  return shape


class HashedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('HashedCategoricalColumn',
                           ('key', 'hash_bucket_size', 'dtype'))):
  """See `categorical_column_with_hash_bucket`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Hashes the values in the feature_column."""
    if not isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      raise ValueError('SparseColumn input must be a SparseTensor.')

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    if self.dtype == dtypes.string:
      sparse_values = input_tensor.values
    else:
      sparse_values = string_ops.as_string(input_tensor.values)

    sparse_id_values = string_ops.string_to_hash_bucket_fast(
        sparse_values, self.hash_bucket_size, name='lookup')
    return sparse_tensor_lib.SparseTensor(
        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Hashes the values in the feature_column."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
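

# Illustrative sketch (not part of this module): hashing maps each value to
# an id in [0, hash_bucket_size) via `string_to_hash_bucket_fast`, so
# distinct values may collide. The feature name is hypothetical:
#
#   user_agent = categorical_column_with_hash_bucket(
#       'user_agent', hash_bucket_size=500)
#   # num_buckets == hash_bucket_size == 500; integer inputs are first
#   # converted to strings, then hashed.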


class VocabularyFileCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyFileCategoricalColumn',
        ('key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets',
         'dtype', 'default_value', 'file_format'))):
  """See `categorical_column_with_vocabulary_file`."""

  def __new__(cls,
              key,
              vocabulary_file,
              vocabulary_size,
              num_oov_buckets,
              dtype,
              default_value,
              file_format=None):
    return super(VocabularyFileCategoricalColumn, cls).__new__(
        cls,
        key=key,
        vocabulary_file=vocabulary_file,
        vocabulary_size=vocabulary_size,
        num_oov_buckets=num_oov_buckets,
        dtype=dtype,
        default_value=default_value,
        file_format=file_format)

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _make_table_from_tfrecord_gzip_file(self, key_dtype, name):
    dataset = readers.TFRecordDataset(
        self.vocabulary_file, compression_type='GZIP')

    def key_dtype_fn(key):
      return key if key_dtype is dtypes.string else string_ops.string_to_number(
          key, out_type=key_dtype)

    return data_lookup_ops.index_table_from_dataset(
        dataset.map(key_dtype_fn),
        num_oov_buckets=self.num_oov_buckets,
        vocab_size=self.vocabulary_size,
        default_value=self.default_value,
        key_dtype=key_dtype,
        name=name)

  def _make_table(self, key_dtype, state_manager):
    name = '{}_lookup'.format(self.key)
    if state_manager is None or not state_manager.has_resource(self, name):
      with ops.init_scope():
        if self.file_format == 'tfrecord_gzip':
          table = self._make_table_from_tfrecord_gzip_file(key_dtype, name)
        else:
          table = lookup_ops.index_table_from_file(
              vocabulary_file=self.vocabulary_file,
              num_oov_buckets=self.num_oov_buckets,
              vocab_size=self.vocabulary_size,
              default_value=self.default_value,
              key_dtype=key_dtype,
              name=name)
      if state_manager is not None:
        state_manager.add_resource(self, name, table)
    else:
      # Reuse the table from the previous run.
      table = state_manager.get_resource(self, name)
    return table

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Creates a lookup table for the vocabulary."""
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    key_dtype = self.dtype
    if input_tensor.dtype.is_integer:
      # `index_table_from_file` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)
    return self._make_table(key_dtype, state_manager).lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.vocabulary_size + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
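

# Illustrative sketch (not part of this module): file-backed vocabulary with
# OOV buckets. The path and sizes are hypothetical:
#
#   states = categorical_column_with_vocabulary_file(
#       key='state', vocabulary_file='/vocab/states.txt', vocabulary_size=50,
#       num_oov_buckets=2)
#   # In-vocabulary values map to ids [0, 50); out-of-vocabulary values hash
#   # into [50, 52), so num_buckets == vocabulary_size + num_oov_buckets.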


class VocabularyListCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'VocabularyListCategoricalColumn',
        ('key', 'vocabulary_list', 'dtype', 'default_value', 'num_oov_buckets'))
):
  """See `categorical_column_with_vocabulary_list`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(self.dtype)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor, state_manager=None):
    """Creates a lookup table for the vocabulary list."""
    if self.dtype.is_integer != input_tensor.dtype.is_integer:
      raise ValueError(
          'Column dtype and SparseTensors dtype must be compatible. '
          'key: {}, column dtype: {}, tensor dtype: {}'.format(
              self.key, self.dtype, input_tensor.dtype))

    fc_utils.assert_string_or_int(
        input_tensor.dtype,
        prefix='column_name: {} input_tensor'.format(self.key))

    key_dtype = self.dtype
    if input_tensor.dtype.is_integer:
      # `index_table_from_tensor` requires 64-bit integer keys.
      key_dtype = dtypes.int64
      input_tensor = math_ops.cast(input_tensor, dtypes.int64)

    name = '{}_lookup'.format(self.key)
    if state_manager is None or not state_manager.has_resource(self, name):
      with ops.init_scope():
        table = lookup_ops.index_table_from_tensor(
            vocabulary_list=tuple(self.vocabulary_list),
            default_value=self.default_value,
            num_oov_buckets=self.num_oov_buckets,
            dtype=key_dtype,
            name=name)
      if state_manager is not None:
        state_manager.add_resource(self, name, table)
    else:
      # Reuse the table from the previous run.
      table = state_manager.get_resource(self, name)
    return table.lookup(input_tensor)

  def transform_feature(self, transformation_cache, state_manager):
    """Creates a lookup table for the vocabulary list."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor, state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return len(self.vocabulary_list) + self.num_oov_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    config = dict(zip(self._fields, self))
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
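

# Illustrative sketch (not part of this module): in-memory vocabulary with a
# default id for unknown values. The feature name is hypothetical:
#
#   sizes = categorical_column_with_vocabulary_list(
#       'shirt_size', vocabulary_list=['S', 'M', 'L'], default_value=0)
#   # num_buckets == len(vocabulary_list) + num_oov_buckets == 3 + 0; values
#   # outside the list map to default_value rather than an OOV bucket.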


class IdentityCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('IdentityCategoricalColumn',
                           ('key', 'number_buckets', 'default_value'))):
  """See `categorical_column_with_identity`."""

  @property
  def _is_v2_column(self):
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.key

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def _transform_input_tensor(self, input_tensor):
    """Returns a SparseTensor with identity values."""
    if not input_tensor.dtype.is_integer:
      raise ValueError(
          'Invalid input, not integer. key: {} dtype: {}'.format(
              self.key, input_tensor.dtype))
    values = input_tensor.values
    if input_tensor.values.dtype != dtypes.int64:
      values = math_ops.cast(values, dtypes.int64, name='values')
    if self.default_value is not None:
      values = math_ops.cast(input_tensor.values, dtypes.int64, name='values')
      num_buckets = math_ops.cast(
          self.num_buckets, dtypes.int64, name='num_buckets')
      zero = math_ops.cast(0, dtypes.int64, name='zero')
      # Assign default for out-of-range values.
      values = array_ops.where_v2(
          math_ops.logical_or(
              values < zero, values >= num_buckets, name='out_of_range'),
          array_ops.fill(
              dims=array_ops.shape(values),
              value=math_ops.cast(self.default_value, dtypes.int64),
              name='default_values'), values)

    return sparse_tensor_lib.SparseTensor(
        indices=input_tensor.indices,
        values=values,
        dense_shape=input_tensor.dense_shape)

  def transform_feature(self, transformation_cache, state_manager):
    """Returns a SparseTensor with identity values."""
    input_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.key, state_manager))
    return self._transform_input_tensor(input_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    input_tensor = _to_sparse_input_and_drop_ignore_values(inputs.get(self.key))
    return self._transform_input_tensor(input_tensor)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.number_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    return dict(zip(self._fields, self))

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    return cls(**kwargs)
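

# Illustrative sketch (not part of this module): identity mapping with a
# default for out-of-range ids. The feature name is hypothetical:
#
#   hour = categorical_column_with_identity(
#       'hour', num_buckets=24, default_value=0)
#   # Integer inputs in [0, 24) are used as ids unchanged; anything outside
#   # that range is replaced with default_value (here 0).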


class WeightedCategoricalColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple(
        'WeightedCategoricalColumn',
        ('categorical_column', 'weight_feature_key', 'dtype'))):
  """See `weighted_categorical_column`."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return '{}_weighted_by_{}'.format(
        self.categorical_column.name, self.weight_feature_key)

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = self.categorical_column.parse_example_spec
    if self.weight_feature_key in config:
      raise ValueError('Parse config {} already exists for {}.'.format(
          config[self.weight_feature_key], self.weight_feature_key))
    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    config = self.categorical_column._parse_example_spec  # pylint: disable=protected-access
    if self.weight_feature_key in config:
      raise ValueError('Parse config {} already exists for {}.'.format(
          config[self.weight_feature_key], self.weight_feature_key))
    config[self.weight_feature_key] = parsing_ops.VarLenFeature(self.dtype)
    return config

  @property
  def num_buckets(self):
    """See `CategoricalColumn` base class."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _transform_weight_tensor(self, weight_tensor):
    if weight_tensor is None:
      raise ValueError('Missing weights {}.'.format(self.weight_feature_key))
    weight_tensor = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
        weight_tensor)
    if self.dtype != weight_tensor.dtype.base_dtype:
      raise ValueError('Bad dtype, expected {}, but got {}.'.format(
          self.dtype, weight_tensor.dtype))
    if not isinstance(weight_tensor, sparse_tensor_lib.SparseTensor):
      # The weight tensor can be a regular Tensor. In this case, sparsify it.
      weight_tensor = _to_sparse_input_and_drop_ignore_values(
          weight_tensor, ignore_value=0.0)
    if not weight_tensor.dtype.is_floating:
      weight_tensor = math_ops.cast(weight_tensor, dtypes.float32)
    return weight_tensor

  def transform_feature(self, transformation_cache, state_manager):
    """Applies weights to the tensor generated from `categorical_column`."""
    weight_tensor = transformation_cache.get(self.weight_feature_key,
                                             state_manager)
    sparse_weight_tensor = self._transform_weight_tensor(weight_tensor)
    sparse_categorical_tensor = _to_sparse_input_and_drop_ignore_values(
        transformation_cache.get(self.categorical_column, state_manager))
    return (sparse_categorical_tensor, sparse_weight_tensor)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Applies weights to the tensor generated from `categorical_column`."""
    weight_tensor = inputs.get(self.weight_feature_key)
    weight_tensor = self._transform_weight_tensor(weight_tensor)
    return (inputs.get(self.categorical_column), weight_tensor)

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    tensors = transformation_cache.get(self, state_manager)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    del weight_collections
    del trainable
    tensors = inputs.get(self)
    return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column, self.weight_feature_key]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    config['dtype'] = self.dtype.name
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    kwargs['dtype'] = dtypes.as_dtype(config['dtype'])
    return cls(**kwargs)
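

# Illustrative sketch (not part of this module): attaching a weight feature
# to a categorical column. The feature names are hypothetical:
#
#   terms = categorical_column_with_vocabulary_list(
#       'terms', ['sports', 'politics', 'tech'])
#   weighted_terms = weighted_categorical_column(terms, 'term_frequencies')
#   # get_sparse_tensors() now returns IdWeightPair(ids, weights), with the
#   # weights read from the 'term_frequencies' feature instead of all 1s.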


class CrossedColumn(
    CategoricalColumn,
    fc_old._CategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('CrossedColumn',
                           ('keys', 'hash_bucket_size', 'hash_key'))):
  """See `crossed_column`."""

  @property
  def _is_v2_column(self):
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        continue
      if not isinstance(key, FeatureColumn):
        return False
      if not key._is_v2_column:  # pylint: disable=protected-access
        return False
    return True

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    feature_names = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, (FeatureColumn, fc_old._FeatureColumn)):  # pylint: disable=protected-access
        feature_names.append(key.name)
      else:  # key must be a string
        feature_names.append(key)
    return '_X_'.join(sorted(feature_names))

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    config = {}
    for key in self.keys:
      if isinstance(key, FeatureColumn):
        config.update(key.parse_example_spec)
      elif isinstance(key, fc_old._FeatureColumn):  # pylint: disable=protected-access
        config.update(key._parse_example_spec)  # pylint: disable=protected-access
      else:  # key must be a string
        config.update({key: parsing_ops.VarLenFeature(dtypes.string)})
    return config

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.parse_example_spec

  def transform_feature(self, transformation_cache, state_manager):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(transformation_cache.get(key, state_manager))
      elif isinstance(key, (fc_old._CategoricalColumn, CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key.get_sparse_tensors(transformation_cache,
                                                 state_manager)
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    """Generates a hashed sparse cross from the input tensors."""
    feature_tensors = []
    for key in _collect_leaf_level_keys(self):
      if isinstance(key, six.string_types):
        feature_tensors.append(inputs.get(key))
      elif isinstance(key, (CategoricalColumn, fc_old._CategoricalColumn)):  # pylint: disable=protected-access
        ids_and_weights = key._get_sparse_tensors(inputs)  # pylint: disable=protected-access
        if ids_and_weights.weight_tensor is not None:
          raise ValueError(
              'crossed_column does not support weight_tensor, but the given '
              'column populates weight_tensor. '
              'Given column: {}'.format(key.name))
        feature_tensors.append(ids_and_weights.id_tensor)
      else:
        raise ValueError('Unsupported column type. Given: {}'.format(key))
    return sparse_ops.sparse_cross_hashed(
        inputs=feature_tensors,
        num_buckets=self.hash_bucket_size,
        hash_key=self.hash_key)

  @property
  def num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.hash_bucket_size

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.num_buckets

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """See `CategoricalColumn` base class."""
    return CategoricalColumn.IdWeightPair(
        transformation_cache.get(self, state_manager), None)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    """See `CategoricalColumn` base class."""
    del weight_collections
    del trainable
    return CategoricalColumn.IdWeightPair(inputs.get(self), None)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return list(self.keys)

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['keys'] = tuple([serialize_feature_column(fc) for fc in self.keys])
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['keys'] = tuple([
        deserialize_feature_column(c, custom_objects, columns_by_name)
        for c in config['keys']
    ])
    return cls(**kwargs)
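

# Illustrative sketch (not part of this module): crosses may be nested, and
# `_collect_leaf_level_keys` below expands them recursively. The keys are
# hypothetical:
#
#   ab = crossed_column(['a', 'b'], hash_bucket_size=100)
#   abc = crossed_column([ab, 'c'], hash_bucket_size=1000)
#   # _collect_leaf_level_keys(abc) == ['a', 'b', 'c'], and the cross of all
#   # three features is hashed into num_buckets == 1000 ids.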


def _collect_leaf_level_keys(cross):
  """Collects base keys by expanding all nested crosses.

  Args:
    cross: A `CrossedColumn`.

  Returns:
    A list of strings or `CategoricalColumn` instances.
  """
  leaf_level_keys = []
  for k in cross.keys:
    if isinstance(k, CrossedColumn):
      leaf_level_keys.extend(_collect_leaf_level_keys(k))
    else:
      leaf_level_keys.append(k)
  return leaf_level_keys


def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Prune invalid IDs (< 0) from the input ids and weights."""
  is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is not None:
    is_id_valid = math_ops.logical_and(
        is_id_valid,
        array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
  sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
  if sparse_weights is not None:
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
  return sparse_ids, sparse_weights


def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Prune invalid weights (< 0) from the input ids and weights."""
  if sparse_weights is not None:
    is_weights_valid = math_ops.greater(sparse_weights.values, 0)
    sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
    sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
  return sparse_ids, sparse_weights
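

# Illustrative worked example (not part of this module), with hypothetical
# inputs: given sparse ids [3, -1, 7] and weights [1.0, 2.0, 0.0],
# _prune_invalid_ids drops the entry with id -1 from both tensors, while
# _prune_invalid_weights drops entries whose weight is not positive (here
# the 0.0), so downstream embedding lookups never see invalid rows.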
4227 state_manager: A `StateManager` to create / access resources such as 4228 lookup tables. 4229 4230 Returns: 4231 Transformed feature `Tensor`. 4232 4233 Raises: 4234 ValueError: if input rank is not known at graph building time. 4235 """ 4236 id_weight_pair = self.categorical_column.get_sparse_tensors( 4237 transformation_cache, state_manager) 4238 return self._transform_id_weight_pair(id_weight_pair, 4239 self.variable_shape[-1]) 4240 4241 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4242 _FEATURE_COLUMN_DEPRECATION) 4243 def _transform_feature(self, inputs): 4244 id_weight_pair = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access 4245 return self._transform_id_weight_pair(id_weight_pair, 4246 self._variable_shape[-1]) 4247 4248 @property 4249 def parse_example_spec(self): 4250 """See `FeatureColumn` base class.""" 4251 return self.categorical_column.parse_example_spec 4252 4253 @property 4254 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4255 _FEATURE_COLUMN_DEPRECATION) 4256 def _parse_example_spec(self): 4257 return self.categorical_column._parse_example_spec # pylint: disable=protected-access 4258 4259 @property 4260 def variable_shape(self): 4261 """Returns a `TensorShape` representing the shape of the dense `Tensor`.""" 4262 if isinstance(self.categorical_column, FeatureColumn): 4263 return tensor_shape.TensorShape([1, self.categorical_column.num_buckets]) 4264 else: 4265 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4266 4267 @property 4268 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4269 _FEATURE_COLUMN_DEPRECATION) 4270 def _variable_shape(self): 4271 return tensor_shape.TensorShape([1, self.categorical_column._num_buckets]) # pylint: disable=protected-access 4272 4273 def get_dense_tensor(self, transformation_cache, state_manager): 4274 """Returns dense `Tensor` representing feature. 4275 4276 Args: 4277 transformation_cache: A `FeatureTransformationCache` object to access 4278 features. 4279 state_manager: A `StateManager` to create / access resources such as 4280 lookup tables. 4281 4282 Returns: 4283 Dense `Tensor` created within `transform_feature`. 4284 4285 Raises: 4286 ValueError: If `categorical_column` is a `SequenceCategoricalColumn`. 4287 """ 4288 if isinstance(self.categorical_column, SequenceCategoricalColumn): 4289 raise ValueError( 4290 'In indicator_column: {}. ' 4291 'categorical_column must not be of type SequenceCategoricalColumn. ' 4292 'Suggested fix A: If you wish to use DenseFeatures, use a ' 4293 'non-sequence categorical_column_with_*. ' 4294 'Suggested fix B: If you wish to create sequence input, use ' 4295 'SequenceFeatures instead of DenseFeatures. ' 4296 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4297 self.categorical_column)) 4298 # Feature has been already transformed. Return the intermediate 4299 # representation created by transform_feature. 4300 return transformation_cache.get(self, state_manager) 4301 4302 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4303 _FEATURE_COLUMN_DEPRECATION) 4304 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 4305 del weight_collections 4306 del trainable 4307 if isinstance( 4308 self.categorical_column, 4309 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 4310 raise ValueError( 4311 'In indicator_column: {}. 
' 4312 'categorical_column must not be of type _SequenceCategoricalColumn. ' 4313 'Suggested fix A: If you wish to use DenseFeatures, use a ' 4314 'non-sequence categorical_column_with_*. ' 4315 'Suggested fix B: If you wish to create sequence input, use ' 4316 'SequenceFeatures instead of DenseFeatures. ' 4317 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4318 self.categorical_column)) 4319 # Feature has been already transformed. Return the intermediate 4320 # representation created by transform_feature. 4321 return inputs.get(self) 4322 4323 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 4324 """See `SequenceDenseColumn` base class.""" 4325 if not isinstance(self.categorical_column, SequenceCategoricalColumn): 4326 raise ValueError( 4327 'In indicator_column: {}. ' 4328 'categorical_column must be of type SequenceCategoricalColumn ' 4329 'to use SequenceFeatures. ' 4330 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 4331 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4332 self.categorical_column)) 4333 # Feature has been already transformed. Return the intermediate 4334 # representation created by transform_feature. 4335 dense_tensor = transformation_cache.get(self, state_manager) 4336 sparse_tensors = self.categorical_column.get_sparse_tensors( 4337 transformation_cache, state_manager) 4338 sequence_length = fc_utils.sequence_length_from_sparse_tensor( 4339 sparse_tensors.id_tensor) 4340 return SequenceDenseColumn.TensorSequenceLengthPair( 4341 dense_tensor=dense_tensor, sequence_length=sequence_length) 4342 4343 @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE, 4344 _FEATURE_COLUMN_DEPRECATION) 4345 def _get_sequence_dense_tensor(self, 4346 inputs, 4347 weight_collections=None, 4348 trainable=None): 4349 # Do nothing with weight_collections and trainable since no variables are 4350 # created in this function. 4351 del weight_collections 4352 del trainable 4353 if not isinstance( 4354 self.categorical_column, 4355 (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)): # pylint: disable=protected-access 4356 raise ValueError( 4357 'In indicator_column: {}. ' 4358 'categorical_column must be of type _SequenceCategoricalColumn ' 4359 'to use SequenceFeatures. ' 4360 'Suggested fix: Use one of sequence_categorical_column_with_*. ' 4361 'Given (type {}): {}'.format(self.name, type(self.categorical_column), 4362 self.categorical_column)) 4363 # Feature has been already transformed. Return the intermediate 4364 # representation created by _transform_feature. 

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sequence_dense_tensor(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    # Do nothing with weight_collections and trainable since no variables are
    # created in this function.
    del weight_collections
    del trainable
    if not isinstance(
        self.categorical_column,
        (SequenceCategoricalColumn, fc_old._SequenceCategoricalColumn)):  # pylint: disable=protected-access
      raise ValueError(
          'In indicator_column: {}. '
          'categorical_column must be of type _SequenceCategoricalColumn '
          'to use SequenceFeatures. '
          'Suggested fix: Use one of sequence_categorical_column_with_*. '
          'Given (type {}): {}'.format(self.name, type(self.categorical_column),
                                       self.categorical_column))
    # The feature has already been transformed. Return the intermediate
    # representation created by _transform_feature.
    dense_tensor = inputs.get(self)
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    sequence_length = fc_utils.sequence_length_from_sparse_tensor(
        sparse_tensors.id_tensor)
    return SequenceDenseColumn.TensorSequenceLengthPair(
        dense_tensor=dense_tensor, sequence_length=sequence_length)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)
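

# Illustrative round-trip sketch (hypothetical `column`, not executed here):
# `get_config` in the indicator column class above serializes the wrapped
# categorical column recursively, so a column should survive a config round
# trip:
#
#   config = column.get_config()
#   revived = type(column).from_config(config)
#   assert revived == column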


def _verify_static_batch_size_equality(tensors, columns):
  """Verifies equality between static batch sizes.

  Args:
    tensors: iterable of input tensors.
    columns: Corresponding feature columns.

  Raises:
    ValueError: in case of mismatched batch sizes.
  """
  # batch_size is a Dimension object.
  expected_batch_size = None
  for i in range(len(tensors)):
    batch_size = tensor_shape.Dimension(tensor_shape.dimension_value(
        tensors[i].shape[0]))
    if batch_size.value is not None:
      if expected_batch_size is None:
        batch_size_column_index = i
        expected_batch_size = batch_size
      elif not expected_batch_size.is_compatible_with(batch_size):
        raise ValueError(
            'Batch size (first dimension) of each feature must be the same. '
            'Batch size of columns ({}, {}): ({}, {})'.format(
                columns[batch_size_column_index].name, columns[i].name,
                expected_batch_size, batch_size))


class SequenceCategoricalColumn(
    CategoricalColumn,
    fc_old._SequenceCategoricalColumn,  # pylint: disable=protected-access
    collections.namedtuple('SequenceCategoricalColumn',
                           ('categorical_column',))):
  """Represents sequences of categorical data."""

  @property
  def _is_v2_column(self):
    return (isinstance(self.categorical_column, FeatureColumn) and
            self.categorical_column._is_v2_column)  # pylint: disable=protected-access

  @property
  def name(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.name

  @property
  def parse_example_spec(self):
    """See `FeatureColumn` base class."""
    return self.categorical_column.parse_example_spec

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _parse_example_spec(self):
    return self.categorical_column._parse_example_spec  # pylint: disable=protected-access

  def transform_feature(self, transformation_cache, state_manager):
    """See `FeatureColumn` base class."""
    return self.categorical_column.transform_feature(transformation_cache,
                                                     state_manager)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _transform_feature(self, inputs):
    return self.categorical_column._transform_feature(inputs)  # pylint: disable=protected-access

  @property
  def num_buckets(self):
    """Returns the number of buckets in this sparse feature."""
    return self.categorical_column.num_buckets

  @property
  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _num_buckets(self):
    return self.categorical_column._num_buckets  # pylint: disable=protected-access

  def _get_sparse_tensors_helper(self, sparse_tensors):
    id_tensor = sparse_tensors.id_tensor
    weight_tensor = sparse_tensors.weight_tensor
    # Expands the third dimension, if necessary, so that embeddings are not
    # combined during the embedding lookup. If the tensor is already 3-D, it
    # is left as-is.
    shape = array_ops.shape(id_tensor)
    # Compute the third dimension explicitly instead of setting it to -1, as
    # that doesn't work for dynamically shaped tensors with 0-length at
    # runtime. This happens for empty sequences.
    target_shape = [shape[0], shape[1], math_ops.reduce_prod(shape[2:])]
    id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
    if weight_tensor is not None:
      weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
    return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
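
  # Shape sketch for the helper above (illustrative): for a rank-2 id tensor
  # with dense shape [batch, seq_len], shape[2:] is empty and
  # reduce_prod(shape[2:]) is 1, so the tensor is reshaped to
  # [batch, seq_len, 1]; a rank-3 [batch, seq_len, d] tensor is unchanged,
  # since reduce_prod([d]) == d.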

  def get_sparse_tensors(self, transformation_cache, state_manager):
    """Returns an `IdWeightPair`.

    `IdWeightPair` is a pair of `SparseTensor`s which represents ids and
    weights.

    `IdWeightPair.id_tensor` is typically a `batch_size` x `num_buckets`
    `SparseTensor` of `int64`. `IdWeightPair.weight_tensor` is either a
    `SparseTensor` of `float` or `None` to indicate that all weights should
    be taken to be 1. If specified, `weight_tensor` must have exactly the
    same shape and indices as `id_tensor`. The expected `SparseTensor` is the
    same as the parsing output of a `VarLenFeature`, which is a ragged matrix.

    Args:
      transformation_cache: A `FeatureTransformationCache` object to access
        features.
      state_manager: A `StateManager` to create / access resources such as
        lookup tables.

    Returns:
      An `IdWeightPair` whose tensors have been reshaped to rank 3.
    """
    sparse_tensors = self.categorical_column.get_sparse_tensors(
        transformation_cache, state_manager)
    return self._get_sparse_tensors_helper(sparse_tensors)

  @deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
                          _FEATURE_COLUMN_DEPRECATION)
  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)  # pylint: disable=protected-access
    return self._get_sparse_tensors_helper(sparse_tensors)

  @property
  def parents(self):
    """See `FeatureColumn` base class."""
    return [self.categorical_column]

  def get_config(self):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import serialize_feature_column  # pylint: disable=g-import-not-at-top
    config = dict(zip(self._fields, self))
    config['categorical_column'] = serialize_feature_column(
        self.categorical_column)
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None, columns_by_name=None):
    """See `FeatureColumn` base class."""
    from tensorflow.python.feature_column.serialization import deserialize_feature_column  # pylint: disable=g-import-not-at-top
    _check_config_keys(config, cls._fields)
    kwargs = _standardize_and_copy_config(config)
    kwargs['categorical_column'] = deserialize_feature_column(
        config['categorical_column'], custom_objects, columns_by_name)
    return cls(**kwargs)


def _check_config_keys(config, expected_keys):
  """Checks that a config has all expected_keys."""
  if set(config.keys()) != set(expected_keys):
    raise ValueError('Invalid config: {}, expected keys: {}'.format(
        config, expected_keys))


def _standardize_and_copy_config(config):
  """Returns a shallow copy of config with lists turned into tuples.

  Keras serialization uses nest to listify everything.
  This causes problems with the NumericColumn shape, which becomes
  unhashable. We could try to solve this on the Keras side, but that
  would require lots of tracking to avoid changing existing behavior.
  Instead, we ensure here that we revive correctly.

  Args:
    config: dict that will be used to revive a Feature Column.

  Returns:
    Shallow copy of config with lists turned into tuples.
  """
  kwargs = config.copy()
  for k, v in kwargs.items():
    if isinstance(v, list):
      kwargs[k] = tuple(v)

  return kwargs


def _sanitize_column_name_for_variable_scope(name):
  """Sanitizes user-provided feature names for use as variable scopes."""
  invalid_char = re.compile('[^A-Za-z0-9_.\\-]')
  return invalid_char.sub('_', name)
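

# Illustrative sketches (hypothetical inputs, shown as comments so nothing
# runs at import time) for the two helpers above:
#
#   _standardize_and_copy_config({'shape': [10, 2], 'key': 'age'})
#   # -> {'shape': (10, 2), 'key': 'age'}   (lists become hashable tuples)
#
#   _sanitize_column_name_for_variable_scope('user country!')
#   # -> 'user_country_'   (characters outside [A-Za-z0-9_.-] become '_')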