1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15"""TPU Feature Column Library.""" 16import copy 17import math 18 19import enum 20 21from tensorflow.python.feature_column import feature_column as fc 22from tensorflow.python.feature_column import feature_column_lib as fc_lib 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import embedding_ops 27from tensorflow.python.ops import init_ops 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import sparse_ops 30from tensorflow.python.ops import variable_scope 31from tensorflow.python.tpu import tpu 32from tensorflow.python.tpu.feature_column import _is_running_on_cpu 33from tensorflow.python.tpu.feature_column import _record_variable_scope_and_name 34from tensorflow.python.tpu.feature_column import _SUPPORTED_CATEGORICAL_COLUMNS_V2 35from tensorflow.python.tpu.feature_column import _SUPPORTED_SEQUENCE_COLUMNS 36from tensorflow.python.tpu.feature_column import _TPUBaseEmbeddingColumn 37from tensorflow.python.util.tf_export import tf_export 38# pylint: disable=protected-access 39 40_ALLOWED_DEVICES = ['cpu', 'tpu_tensor_core', 'tpu_embedding_core'] 41_TENSOR_CORE_MASK_KEY_SUFFIX = '__TENSOR_CORE_MASK' 42 43 44class EmbeddingDevice(enum.Enum): 45 CPU = 1 46 TPU_TENSOR_CORE = 2 47 TPU_EMBEDDING_CORE = 3 48 49 50@tf_export(v1=['tpu.experimental.embedding_column']) 51def embedding_column_v2(categorical_column, 52 dimension, 53 combiner='mean', 54 initializer=None, 55 max_sequence_length=0, 56 learning_rate_fn=None, 57 embedding_lookup_device=None, 58 tensor_core_shape=None, 59 use_safe_embedding_lookup=True): 60 """TPU version of `tf.compat.v1.feature_column.embedding_column`. 61 62 Note that the interface for `tf.tpu.experimental.embedding_column` is 63 different from that of `tf.compat.v1.feature_column.embedding_column`: The 64 following arguments are NOT supported: `ckpt_to_load_from`, 65 `tensor_name_in_ckpt`, `max_norm` and `trainable`. 66 67 Use this function in place of `tf.compat.v1.feature_column.embedding_column` 68 when you want to use the TPU to accelerate your embedding lookups via TPU 69 embeddings. 70 71 ``` 72 column = tf.feature_column.categorical_column_with_identity(...) 73 tpu_column = tf.tpu.experimental.embedding_column(column, 10) 74 ... 75 def model_fn(features): 76 dense_feature = tf.keras.layers.DenseFeature(tpu_column) 77 embedded_feature = dense_feature(features) 78 ... 79 80 estimator = tf.estimator.tpu.TPUEstimator( 81 model_fn=model_fn, 82 ... 83 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( 84 column=[tpu_column], 85 ...)) 86 ``` 87 88 Args: 89 categorical_column: A categorical column returned from 90 `categorical_column_with_identity`, `weighted_categorical_column`, 91 `categorical_column_with_vocabulary_file`, 92 `categorical_column_with_vocabulary_list`, 93 `sequence_categorical_column_with_identity`, 94 `sequence_categorical_column_with_vocabulary_file`, 95 `sequence_categorical_column_with_vocabulary_list` 96 dimension: An integer specifying dimension of the embedding, must be > 0. 97 combiner: A string specifying how to reduce if there are multiple entries 98 in a single row for a non-sequence column. For more information, see 99 `tf.feature_column.embedding_column`. 100 initializer: A variable initializer function to be used in embedding 101 variable initialization. If not specified, defaults to 102 `tf.compat.v1.truncated_normal_initializer` with mean `0.0` and 103 standard deviation `1/sqrt(dimension)`. 104 max_sequence_length: An non-negative integer specifying the max sequence 105 length. Any sequence shorter then this will be padded with 0 embeddings 106 and any sequence longer will be truncated. This must be positive for 107 sequence features and 0 for non-sequence features. 108 learning_rate_fn: A function that takes global step and returns learning 109 rate for the embedding table. If you intend to use the same learning rate 110 for multiple embedding tables, please ensure that you pass the exact same 111 python function to all calls of embedding_column, otherwise performence 112 may suffer. 113 embedding_lookup_device: The device on which to run the embedding lookup. 114 Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". 115 If specifying "tpu_tensor_core", a tensor_core_shape must be supplied. 116 If not specified, the default behavior is embedding lookup on 117 "tpu_embedding_core" for training and "cpu" for inference. 118 Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"] 119 Valid options for serving : ["cpu", "tpu_tensor_core"] 120 For training, tpu_embedding_core is good for large embedding vocab (>1M), 121 otherwise, tpu_tensor_core is often sufficient. 122 For serving, doing embedding lookup on tpu_tensor_core during serving is 123 a way to reduce host cpu usage in cases where that is a bottleneck. 124 tensor_core_shape: If supplied, a list of integers which specifies 125 the intended dense shape to run embedding lookup for this feature on 126 TensorCore. The batch dimension can be left None or -1 to indicate 127 a dynamic shape. Only rank 2 shapes currently supported. 128 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 129 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 130 there are no empty rows and all weights and ids are positive at the 131 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 132 input tensors. Defaults to true, consider turning off if the above checks 133 are not needed. Note that having empty rows will not trigger any error 134 though the output result might be 0 or omitted. 135 136 Returns: 137 A `_TPUEmbeddingColumnV2`. 138 139 Raises: 140 ValueError: if `dimension` not > 0. 141 ValueError: if `initializer` is specified but not callable. 142 """ 143 144 if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2): 145 raise TypeError( 146 'categorical_column for tpu ' 147 'embedding_column must be type {}, got {}.'.format(' or '.join([ 148 cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2 149 ]), type(categorical_column))) 150 if (dimension is None) or (dimension < 1): 151 raise ValueError('Invalid dimension {}.'.format(dimension)) 152 if tensor_core_shape and len(tensor_core_shape) != 2: 153 raise ValueError( 154 'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape)) 155 156 if (initializer is not None) and (not callable(initializer)): 157 raise ValueError('initializer must be callable if specified. ' 158 'Embedding of column_name: {}'.format( 159 categorical_column.name)) 160 if initializer is None: 161 initializer = init_ops.truncated_normal_initializer( 162 mean=0.0, stddev=1 / math.sqrt(dimension)) 163 164 if (embedding_lookup_device and 165 embedding_lookup_device not in _ALLOWED_DEVICES): 166 raise ValueError( 167 f'If set, embedding_lookup_device must be in {_ALLOWED_DEVICES}') 168 169 if embedding_lookup_device == 'cpu': 170 embedding_lookup_device = EmbeddingDevice.CPU 171 elif embedding_lookup_device == 'tpu_tensor_core': 172 embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE 173 elif embedding_lookup_device == 'tpu_embedding_core': 174 embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE 175 176 if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE: 177 if not tensor_core_shape: 178 raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires ' 179 'tensor_core_shape to be set.') 180 if isinstance(categorical_column, _SUPPORTED_SEQUENCE_COLUMNS): 181 raise ValueError('embedding_lookup_device=tpu_tensor_core currently does ' 182 'not support sequence columns.') 183 184 if not embedding_lookup_device: 185 return _TPUEmbeddingColumnV2( 186 categorical_column=categorical_column, 187 dimension=dimension, 188 combiner=combiner, 189 initializer=initializer, 190 max_sequence_length=max_sequence_length, 191 learning_rate_fn=learning_rate_fn, 192 use_safe_embedding_lookup=use_safe_embedding_lookup) 193 else: 194 return _TPUDeviceSpecificEmbeddingColumnV2( 195 categorical_column=categorical_column, 196 dimension=dimension, 197 combiner=combiner, 198 initializer=initializer, 199 max_sequence_length=max_sequence_length, 200 learning_rate_fn=learning_rate_fn, 201 embedding_lookup_device=embedding_lookup_device, 202 tensor_core_shape=tensor_core_shape, 203 use_safe_embedding_lookup=use_safe_embedding_lookup) 204 205 206@tf_export(v1=['tpu.experimental.shared_embedding_columns']) 207def shared_embedding_columns_v2(categorical_columns, 208 dimension, 209 combiner='mean', 210 initializer=None, 211 shared_embedding_collection_name=None, 212 max_sequence_lengths=None, 213 learning_rate_fn=None, 214 embedding_lookup_device=None, 215 tensor_core_shape=None, 216 use_safe_embedding_lookup=True): 217 """TPU version of `tf.compat.v1.feature_column.shared_embedding_columns`. 218 219 Note that the interface for `tf.tpu.experimental.shared_embedding_columns` is 220 different from that of `tf.compat.v1.feature_column.shared_embedding_columns`: 221 The following arguments are NOT supported: `ckpt_to_load_from`, 222 `tensor_name_in_ckpt`, `max_norm` and `trainable`. 223 224 Use this function in place of 225 tf.compat.v1.feature_column.shared_embedding_columns` when you want to use the 226 TPU to accelerate your embedding lookups via TPU embeddings. 227 228 ``` 229 column_a = tf.feature_column.categorical_column_with_identity(...) 230 column_b = tf.feature_column.categorical_column_with_identity(...) 231 tpu_columns = tf.tpu.experimental.shared_embedding_columns( 232 [column_a, column_b], 10) 233 ... 234 def model_fn(features): 235 dense_feature = tf.keras.layers.DenseFeature(tpu_columns) 236 embedded_feature = dense_feature(features) 237 ... 238 239 estimator = tf.estimator.tpu.TPUEstimator( 240 model_fn=model_fn, 241 ... 242 embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec( 243 column=tpu_columns, 244 ...)) 245 ``` 246 247 Args: 248 categorical_columns: A list of categorical columns returned from 249 `categorical_column_with_identity`, `weighted_categorical_column`, 250 `categorical_column_with_vocabulary_file`, 251 `categorical_column_with_vocabulary_list`, 252 `sequence_categorical_column_with_identity`, 253 `sequence_categorical_column_with_vocabulary_file`, 254 `sequence_categorical_column_with_vocabulary_list` 255 dimension: An integer specifying dimension of the embedding, must be > 0. 256 combiner: A string specifying how to reduce if there are multiple entries in 257 a single row for a non-sequence column. For more information, see 258 `tf.feature_column.embedding_column`. 259 initializer: A variable initializer function to be used in embedding 260 variable initialization. If not specified, defaults to 261 `tf.truncated_normal_initializer` with mean `0.0` and standard deviation 262 `1/sqrt(dimension)`. 263 shared_embedding_collection_name: Optional name of the collection where 264 shared embedding weights are added. If not given, a reasonable name will 265 be chosen based on the names of `categorical_columns`. This is also used 266 in `variable_scope` when creating shared embedding weights. 267 max_sequence_lengths: An list of non-negative integers, either None or empty 268 or the same length as the argument categorical_columns. Entries 269 corresponding to non-sequence columns must be 0 and entries corresponding 270 to sequence columns specify the max sequence length for the column. Any 271 sequence shorter then this will be padded with 0 embeddings and any 272 sequence longer will be truncated. 273 learning_rate_fn: A function that takes global step and returns learning 274 rate for the embedding table. If you intend to use the same learning rate 275 for multiple embedding tables, please ensure that you pass the exact same 276 python function to all calls of shared_embedding_columns, otherwise 277 performence may suffer. 278 embedding_lookup_device: The device on which to run the embedding lookup. 279 Valid options are "cpu", "tpu_tensor_core", and "tpu_embedding_core". If 280 specifying "tpu_tensor_core", a tensor_core_shape must be supplied. 281 Defaults to "cpu". If not specified, the default behavior is embedding 282 lookup on "tpu_embedding_core" for training and "cpu" for inference. 283 Valid options for training : ["tpu_embedding_core", "tpu_tensor_core"] 284 Valid options for serving : ["cpu", "tpu_tensor_core"] 285 For training, tpu_embedding_core is good for large embedding vocab (>1M), 286 otherwise, tpu_tensor_core is often sufficient. 287 For serving, doing embedding lookup on tpu_tensor_core during serving is 288 a way to reduce host cpu usage in cases where that is a bottleneck. 289 tensor_core_shape: If supplied, a list of integers which specifies the 290 intended dense shape to run embedding lookup for this feature on 291 TensorCore. The batch dimension can be left None or -1 to indicate a 292 dynamic shape. Only rank 2 shapes currently supported. 293 use_safe_embedding_lookup: If true, uses safe_embedding_lookup_sparse 294 instead of embedding_lookup_sparse. safe_embedding_lookup_sparse ensures 295 there are no empty rows and all weights and ids are positive at the 296 expense of extra compute cost. This only applies to rank 2 (NxM) shaped 297 input tensors. Defaults to true, consider turning off if the above checks 298 are not needed. Note that having empty rows will not trigger any error 299 though the output result might be 0 or omitted. 300 301 Returns: 302 A list of `_TPUSharedEmbeddingColumnV2`. 303 304 Raises: 305 ValueError: if `dimension` not > 0. 306 ValueError: if `initializer` is specified but not callable. 307 ValueError: if `max_sequence_lengths` is specified and not the same length 308 as `categorical_columns`. 309 ValueError: if `max_sequence_lengths` is positive for a non sequence column 310 or 0 for a sequence column. 311 """ 312 313 for categorical_column in categorical_columns: 314 if not isinstance(categorical_column, _SUPPORTED_CATEGORICAL_COLUMNS_V2): 315 raise TypeError( 316 'categorical_column for tpu ' 317 ' shared_embedding_columns must be type {}, got {}.'.format( 318 ' or '.join( 319 [cc.__name__ for cc in _SUPPORTED_CATEGORICAL_COLUMNS_V2]), 320 type(categorical_column))) 321 322 if not max_sequence_lengths: 323 max_sequence_lengths = [0] * len(categorical_columns) 324 if len(max_sequence_lengths) != len(categorical_columns): 325 raise ValueError('max_sequence_lengths and categorical_columns must be of ' 326 'the same length. len(max_sequence_lengths)={} ' 327 'len(categorical_columns)={}.'.format( 328 len(max_sequence_lengths), len(categorical_columns))) 329 330 if (dimension is None) or (dimension < 1): 331 raise ValueError('Invalid dimension {}.'.format(dimension)) 332 if tensor_core_shape and len(tensor_core_shape) != 2: 333 raise ValueError( 334 'tensor_core_shape must be size 2. Got {}.'.format(tensor_core_shape)) 335 336 if (initializer is not None) and (not callable(initializer)): 337 raise ValueError('initializer must be callable if specified. ') 338 if initializer is None: 339 initializer = init_ops.truncated_normal_initializer( 340 mean=0.0, stddev=1 / math.sqrt(dimension)) 341 342 # Sort the columns so the default collection name is deterministic even if the 343 # user passes columns from an unsorted collection, such as dict.values(). 344 sorted_columns = sorted(categorical_columns, key=lambda x: x.name) 345 num_buckets = sorted_columns[0]._num_buckets # pylint: disable=protected-access 346 347 for c in sorted_columns[1:]: 348 if num_buckets != c._num_buckets: # pylint: disable=protected-access 349 raise ValueError( 350 'To use shared_embedding_column, all categorical_columns must have ' 351 'the same number of buckets. Given column: {} with buckets: {} does ' 352 'not match column: {} with buckets: {}'.format( 353 sorted_columns[0], num_buckets, c, c._num_buckets)) # pylint: disable=protected-access 354 355 if not shared_embedding_collection_name: 356 shared_embedding_collection_name = '_'.join(c.name for c in sorted_columns) 357 shared_embedding_collection_name += '_shared_embedding' 358 359 tpu_columns = [] 360 361 column_creator = fc_lib.SharedEmbeddingColumnCreator( 362 dimension=dimension, initializer=initializer, ckpt_to_load_from=None, 363 tensor_name_in_ckpt=None, num_buckets=num_buckets, trainable=None, 364 name=shared_embedding_collection_name) 365 366 if (embedding_lookup_device and 367 embedding_lookup_device not in _ALLOWED_DEVICES): 368 raise ValueError( 369 f'If set, embedding_lookup_device must be in {_ALLOWED_DEVICES}') 370 371 if embedding_lookup_device == 'cpu': 372 embedding_lookup_device = EmbeddingDevice.CPU 373 elif embedding_lookup_device == 'tpu_tensor_core': 374 embedding_lookup_device = EmbeddingDevice.TPU_TENSOR_CORE 375 elif embedding_lookup_device == 'tpu_embedding_core': 376 embedding_lookup_device = EmbeddingDevice.TPU_EMBEDDING_CORE 377 378 if embedding_lookup_device == EmbeddingDevice.TPU_TENSOR_CORE: 379 if not tensor_core_shape: 380 raise ValueError('Using embedding_lookup_device=tpu_tensor_core requires ' 381 'tensor_core_shape to be set.') 382 for c in sorted_columns: 383 if isinstance(c, _SUPPORTED_SEQUENCE_COLUMNS): 384 raise ValueError('embedding_lookup_device=tpu_tensor_core currently ' 385 'does not support sequence columns.') 386 387 # Create the state (_SharedEmbeddingColumnLayer) here. 388 for categorical_column, max_sequence_length in zip( 389 categorical_columns, max_sequence_lengths): 390 if not embedding_lookup_device: 391 column = _TPUSharedEmbeddingColumnV2( 392 categorical_column=categorical_column, 393 shared_embedding_column_creator=column_creator, 394 combiner=combiner, 395 initializer=initializer, 396 shared_embedding_collection_name=shared_embedding_collection_name, 397 max_sequence_length=max_sequence_length, 398 learning_rate_fn=learning_rate_fn, 399 use_safe_embedding_lookup=use_safe_embedding_lookup) 400 else: 401 column = _TPUSharedDeviceSpecificEmbeddingColumnV2( 402 categorical_column=categorical_column, 403 shared_embedding_column_creator=column_creator, 404 combiner=combiner, 405 initializer=initializer, 406 shared_embedding_collection_name=shared_embedding_collection_name, 407 max_sequence_length=max_sequence_length, 408 learning_rate_fn=learning_rate_fn, 409 embedding_lookup_device=embedding_lookup_device, 410 tensor_core_shape=tensor_core_shape, 411 use_safe_embedding_lookup=use_safe_embedding_lookup) 412 tpu_columns.append(column) 413 414 return tpu_columns 415 416 417class _TPUEmbeddingColumnV2(_TPUBaseEmbeddingColumn, fc_lib.EmbeddingColumn): 418 """Core Embedding Column.""" 419 420 def __new__(cls, 421 categorical_column, 422 dimension, 423 combiner='mean', 424 initializer=None, 425 max_sequence_length=0, 426 learning_rate_fn=None, 427 use_safe_embedding_lookup=True, 428 bypass_scope_validation=False): 429 del bypass_scope_validation 430 # pylint: disable=redundant-keyword-arg 431 return fc_lib.EmbeddingColumn.__new__( 432 cls, 433 categorical_column, 434 dimension, 435 combiner=combiner, 436 initializer=initializer, 437 ckpt_to_load_from=None, 438 tensor_name_in_ckpt=None, 439 max_norm=None, 440 trainable=True, 441 use_safe_embedding_lookup=use_safe_embedding_lookup) 442 443 def __getnewargs__(self): 444 return (self._tpu_categorical_column, self.dimension, self.combiner, 445 self.initializer, self._max_sequence_length, self._learning_rate_fn, 446 self.use_safe_embedding_lookup, self._bypass_scope_validation) 447 448 def __deepcopy__(self, memo): 449 return _TPUEmbeddingColumnV2( 450 *(copy.deepcopy(a, memo) for a in self.__getnewargs__())) 451 452 def __init__(self, 453 categorical_column, 454 dimension, 455 combiner='mean', 456 initializer=None, 457 max_sequence_length=0, 458 learning_rate_fn=None, 459 use_safe_embedding_lookup=True, 460 bypass_scope_validation=False): 461 _TPUBaseEmbeddingColumn.__init__( 462 self, 463 categorical_column, 464 max_sequence_length=max_sequence_length, 465 learning_rate_fn=learning_rate_fn) 466 self._key = None 467 # If true, scope validation is skipped to allow the same column to be used 468 # in multiple variable scopes. By default, this is False, and we expect a 469 # 1:1 mapping between feature columns and scopes. 470 self._bypass_scope_validation = bypass_scope_validation 471 472 def get_combiner(self): 473 return self.combiner 474 475 def get_embedding_table_size(self): 476 """Returns num_ids and width.""" 477 return (self.categorical_column._num_buckets, self.dimension) 478 479 def get_feature_key_name(self): 480 """get_feature_key_name.""" 481 if self.is_categorical_column_weighted(): 482 return self.categorical_column.categorical_column.name 483 return self.categorical_column.name 484 485 def get_weight_key_name(self): 486 """get_weight_key_name.""" 487 if self.is_categorical_column_weighted(): 488 return self.categorical_column.weight_feature_key 489 return None 490 491 def get_embedding_var_name(self): 492 """get_embedding_var_name.""" 493 return self.categorical_column.name 494 495 def get_initializer(self): 496 return self.initializer 497 498 def is_categorical_column_weighted(self): 499 """Check if the categorical column of the embedding column is weighted.""" 500 if isinstance( 501 self.categorical_column, 502 ( 503 fc._WeightedCategoricalColumn, # pylint: disable=protected-access 504 fc_lib.WeightedCategoricalColumn)): 505 return True 506 return False 507 508 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 509 if tpu.under_tpu_inference_context(): 510 def host_computation(): 511 return fc_lib.EmbeddingColumn._get_dense_tensor( 512 self, inputs, weight_collections, trainable) 513 return tpu.outside_compilation(host_computation) 514 515 if _is_running_on_cpu(): 516 return fc_lib.EmbeddingColumn._get_dense_tensor( 517 self, inputs, weight_collections, trainable) 518 519 # TPU mode 520 # Get the embeddings from the LazyBuilder. 521 tensor = inputs.get(self.get_feature_key_name()) 522 523 # Add to collection for _create_tpu_embedding_variables_and_ops 524 _record_variable_scope_and_name( 525 self.get_embedding_var_name(), 526 'embedding_weights', 527 bypass_scope_validation=self._bypass_scope_validation) 528 529 return tensor 530 531 def create_state(self, state_manager): 532 if _is_running_on_cpu(): 533 return fc_lib.EmbeddingColumn.create_state( 534 self, state_manager) 535 536 # Create state is called for the EmbeddingColumn to create its embedding 537 # variables under feature column V2, if we are on TPU so record the scope 538 # here. 539 _record_variable_scope_and_name( 540 self.get_embedding_var_name(), 541 'embedding_weights', 542 bypass_scope_validation=self._bypass_scope_validation) 543 544 def get_dense_tensor(self, transformation_cache, state_manager): 545 if tpu.under_tpu_inference_context(): 546 def host_computation(): 547 return fc_lib.EmbeddingColumn.get_dense_tensor( 548 self, transformation_cache, state_manager) 549 return tpu.outside_compilation(host_computation) 550 551 if _is_running_on_cpu(): 552 return fc_lib.EmbeddingColumn.get_dense_tensor( 553 self, transformation_cache, state_manager) 554 555 # TPU mode 556 # Get the embeddings from the FeatureTransformationCache. 557 tensor = transformation_cache.get(self.get_feature_key_name(), 558 state_manager) 559 560 return tensor 561 562 def _get_sequence_dense_tensor( 563 self, inputs, weight_collections=None, trainable=None): 564 if tpu.under_tpu_inference_context(): 565 def host_computation(): 566 return fc_lib.EmbeddingColumn._get_sequence_dense_tensor( 567 self, inputs, weight_collections, trainable) 568 return tpu.outside_compilation(host_computation) 569 570 if _is_running_on_cpu(): 571 return fc_lib.EmbeddingColumn._get_sequence_dense_tensor( 572 self, inputs, weight_collections, trainable) 573 574 tensor = inputs.get(self.get_feature_key_name()) 575 tensor_lengths = inputs.get(self.get_sequence_length_feature_key_name()) 576 577 # inputs is a _LazyBuilder and for rank 1 tensors, it calls expand_dims(-1). 578 # We need to undo this to match the standard CPU sequence embedding. 579 tensor_lengths = array_ops.squeeze(tensor_lengths, -1) 580 581 # Add to collection for _create_tpu_embedding_variables_and_ops 582 _record_variable_scope_and_name( 583 self.get_embedding_var_name(), 584 'embedding_weights', 585 bypass_scope_validation=self._bypass_scope_validation) 586 587 return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair( 588 dense_tensor=tensor, sequence_length=tensor_lengths) 589 590 def get_sequence_dense_tensor(self, transformation_cache, state_manager): 591 if tpu.under_tpu_inference_context(): 592 def host_computation(): 593 return fc_lib.EmbeddingColumn.get_sequence_dense_tensor( 594 self, transformation_cache, state_manager) 595 return tpu.outside_compilation(host_computation) 596 597 if _is_running_on_cpu(): 598 return fc_lib.EmbeddingColumn.get_sequence_dense_tensor( 599 self, transformation_cache, state_manager) 600 601 tensor = transformation_cache.get(self.get_feature_key_name(), 602 state_manager) 603 tensor_lengths = transformation_cache.get( 604 self.get_sequence_length_feature_key_name(), 605 state_manager) 606 607 # FeatureTransformationCache expands rank 1 tensors (like sequence length) 608 # to rank 2. We need to undo this to match the standard CPU sequence 609 # embedding. 610 tensor_lengths = array_ops.squeeze(tensor_lengths, -1) 611 612 return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair( 613 dense_tensor=tensor, sequence_length=tensor_lengths) 614 615 616class _TPUSharedEmbeddingColumnV2(_TPUBaseEmbeddingColumn, 617 fc_lib.SharedEmbeddingColumn): 618 """Core Shared Embedding Column.""" 619 620 def __new__(cls, 621 categorical_column, 622 shared_embedding_column_creator, 623 combiner='mean', 624 initializer=None, 625 shared_embedding_collection_name=None, 626 max_sequence_length=0, 627 learning_rate_fn=None, 628 use_safe_embedding_lookup=True): 629 # pylint: disable=redundant-keyword-arg 630 return fc_lib.SharedEmbeddingColumn.__new__( 631 cls, 632 categorical_column, 633 combiner=combiner, 634 shared_embedding_column_creator=shared_embedding_column_creator, 635 max_norm=None, 636 use_safe_embedding_lookup=use_safe_embedding_lookup) 637 638 def __getnewargs__(self): 639 return (self._tpu_categorical_column, self.shared_embedding_column_creator, 640 self.combiner, self._initializer, 641 self._shared_embedding_collection_name, self._max_sequence_length, 642 self._learning_rate_fn) 643 644 def __deepcopy__(self, memo): 645 return _TPUSharedEmbeddingColumnV2( 646 *(copy.deepcopy(a, memo) for a in self.__getnewargs__())) 647 648 def __init__(self, 649 categorical_column, 650 shared_embedding_column_creator, 651 combiner='mean', 652 initializer=None, 653 shared_embedding_collection_name=None, 654 max_sequence_length=0, 655 learning_rate_fn=None, 656 use_safe_embedding_lookup=True): 657 658 _TPUBaseEmbeddingColumn.__init__( 659 self, 660 categorical_column, 661 max_sequence_length=max_sequence_length, 662 learning_rate_fn=learning_rate_fn) 663 self._initializer = initializer 664 self._shared_embedding_collection_name = shared_embedding_collection_name 665 666 def get_combiner(self): 667 return self.combiner 668 669 def get_embedding_table_size(self): 670 """Returns num_ids and width.""" 671 return (self.categorical_column._num_buckets, 672 self.shared_embedding_column_creator.dimension) 673 674 def get_feature_key_name(self): 675 """get_feature_key_name.""" 676 if self.is_categorical_column_weighted(): 677 return self.categorical_column.categorical_column.name 678 return self.categorical_column.name 679 680 def get_weight_key_name(self): 681 """get_weight_key_name.""" 682 if self.is_categorical_column_weighted(): 683 return self.categorical_column.weight_feature_key 684 return None 685 686 def get_embedding_var_name(self): 687 """get_embedding_var_name.""" 688 return self._shared_embedding_collection_name 689 690 def get_initializer(self): 691 return self._initializer 692 693 def is_categorical_column_weighted(self): 694 """Check if the categorical column of the embedding column is weighted.""" 695 if isinstance( 696 self.categorical_column, 697 ( 698 fc._WeightedCategoricalColumn, # pylint: disable=protected-access 699 fc_lib.WeightedCategoricalColumn)): 700 return True 701 return False 702 703 def _get_dense_tensor_internal( 704 self, transformation_cache, state_manager): 705 if tpu.under_tpu_inference_context(): 706 def host_computation(): 707 return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal( 708 self, transformation_cache, state_manager) 709 return tpu.outside_compilation(host_computation) 710 711 if _is_running_on_cpu(): 712 return fc_lib.SharedEmbeddingColumn._get_dense_tensor_internal( 713 self, transformation_cache, state_manager) 714 715 # TPU mode 716 # Get the embeddings from the FeatureTransformationCache. 717 tensor = transformation_cache.get(self.get_feature_key_name(), 718 state_manager) 719 720 # Add to collection for _create_tpu_embedding_variables_and_ops 721 # Note that in Feature Column V2, shared embeddings have no scope. 722 _record_variable_scope_and_name( 723 self.get_embedding_var_name(), 724 self.shared_embedding_column_creator._name, 725 is_shared_embedding=True) 726 return tensor 727 728 def get_sequence_dense_tensor( 729 self, transformation_cache, state_manager): 730 if tpu.under_tpu_inference_context(): 731 def host_computation(): 732 return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor( 733 self, transformation_cache, state_manager) 734 return tpu.outside_compilation(host_computation) 735 736 if _is_running_on_cpu(): 737 return fc_lib.SharedEmbeddingColumn.get_sequence_dense_tensor( 738 self, transformation_cache, state_manager) 739 740 tensor = self._get_dense_tensor_internal( 741 transformation_cache, state_manager) 742 tensor_lengths = transformation_cache.get( 743 self.get_sequence_length_feature_key_name(), 744 state_manager) 745 746 # FeatureTransformationCache expands rank 1 tensors (like sequence length) 747 # to rank 2. We need to undo this to match the standard CPU sequence 748 # embedding. 749 tensor_lengths = array_ops.squeeze(tensor_lengths, -1) 750 751 return fc_lib.SequenceDenseColumn.TensorSequenceLengthPair( 752 dense_tensor=tensor, sequence_length=tensor_lengths) 753 754 755def split_sequence_columns_v2(feature_columns): 756 """Split a list of _TPUEmbeddingColumn into sequence and non-sequence columns. 757 758 For use in a TPUEstimator model_fn function. E.g. 759 760 def model_fn(features): 761 sequence_columns, feature_columns = ( 762 tf.tpu.feature_column.split_sequence_columns(feature_columns)) 763 input = tf.feature_column.input_layer( 764 features=features, feature_columns=feature_columns) 765 sequence_features, sequence_lengths = ( 766 tf.contrib.feature_column.sequence_input_layer( 767 features=features, feature_columns=sequence_columns)) 768 769 Args: 770 feature_columns: A list of _TPUEmbeddingColumns to split. 771 772 Returns: 773 Two lists of _TPUEmbeddingColumns, the first is the sequence columns and the 774 second is the non-sequence columns. 775 """ 776 sequence_columns = [] 777 non_sequence_columns = [] 778 for column in feature_columns: 779 if not isinstance(column, (_TPUEmbeddingColumnV2, 780 _TPUSharedEmbeddingColumnV2)): 781 raise TypeError( 782 'column must be a _TPUEmbeddingColumnV2 or ' 783 f'_TPUSharedEmbeddingColumnV2 but got {type(column)} instead.') 784 if column.is_sequence_column(): 785 sequence_columns.append(column) 786 else: 787 non_sequence_columns.append(column) 788 return sequence_columns, non_sequence_columns 789 790 791def sparse_embedding_aggregate_slice(params, 792 values_and_values_mask, 793 combiner='mean', 794 name='sparse_embedding_aggregate_slice'): 795 """Uses XLA's dynamic slice operations to perform embedding lookups. 796 797 From third_party/cloud_tpu/models/movielens/tpu_embedding.py 798 799 Args: 800 params: Tensor of embedding table. Rank 2 (table_size x embedding dim) 801 values_and_values_mask: is a two-tuple that contains: values - Tensor of 802 embedding indices. Rank 2 (batch x n_indices) values_mask - Tensor of mask 803 / weights. Rank 2 (batch x n_indices) 804 combiner: The combiner to use for the embedding lookup. Currently supports 805 'sum' and 'mean'. 806 name: Optional name scope for created ops 807 808 Returns: 809 Rank 2 tensor of aggregated (per batch element) embedding vectors. 810 811 Raises: 812 ValueError: Combiner is not supported. 813 """ 814 values, values_mask = values_and_values_mask # unpack the two-tuple 815 with ops.name_scope(name): 816 _, embedding_dimension = params.get_shape().as_list() 817 n_batch, n_indices_padded = values.get_shape().as_list() 818 if not n_batch: 819 n_batch = -1 820 821 emb_lookup = array_ops.reshape( 822 embedding_ops.embedding_lookup( 823 params, array_ops.reshape(values, [n_batch, n_indices_padded])), 824 [n_batch, n_indices_padded, embedding_dimension]) 825 826 values_mask_broadcast = array_ops.reshape(values_mask, 827 [n_batch, n_indices_padded, 1]) 828 aggregate_emb = math_ops.reduce_sum( 829 emb_lookup * values_mask_broadcast, axis=1) 830 if combiner == 'sum': 831 return aggregate_emb 832 elif combiner == 'mean': 833 # In the case we have an empty row, both aggregate_emb and 834 # math_ops.reduce_sum(values_mask_broadcast, axis=1) will be 0. Thus, 835 # we can take max it with a non-zero value to prevent NaNs. Note that 836 # math_ops.reduce_sum(values_mask_broadcast, axis=1) will have integer 837 # values so 1.0 is the smallest value. 838 return aggregate_emb / math_ops.maximum( 839 math_ops.reduce_sum(values_mask_broadcast, axis=1), 1.0) 840 else: 841 raise ValueError('Dense TPU Embedding does not support combiner ' 842 'other than sum and mean.') 843 844 845def pad_sparse_embedding_lookup_indices(sparse_indices, padded_size): 846 """Creates statically-sized Tensors containing indices and weights. 847 848 From third_party/cloud_tpu/models/movielens/tpu_embedding.py 849 850 Also computes sparse_indices.values % embedding_table_size, for equivalent 851 functionality to sparse_column_with_integerized_feature. The returned 852 padded weight Tensor also doubles as a mask indicating which values in 853 the returned padded indices Tensor are indices versus padded zeros. 854 855 Args: 856 sparse_indices: SparseTensor of embedding lookup indices. 857 padded_size: Number of columns of the returned Tensors. Indices which fall 858 out of bounds will be truncated to the padded size. 859 860 Returns: 861 (sparse_indices.values padded to the specified size, 862 a mask the same size as the returned padded values in which 0s 863 indicate padded locations and 1s (or values from sparse_weights) 864 indicate actual values) 865 """ 866 batch_size = sparse_indices.dense_shape[0] 867 sparse_indices = sparse_ops.sparse_slice(sparse_indices, [0, 0], 868 [batch_size, padded_size]) 869 indices, values = sparse_indices.indices, sparse_indices.values 870 871 padded_values = array_ops.scatter_nd( 872 indices, 873 math_ops.cast(values, dtypes.int32), 874 shape=(batch_size, padded_size)) 875 876 weights = array_ops.ones_like(values, dtype=dtypes.float32) 877 padded_mask = array_ops.scatter_nd( 878 indices, weights, shape=(batch_size, padded_size)) 879 880 return padded_values, padded_mask 881 882 883def _check_invalid_cases(embedding_lookup_device): 884 """Checks for invalid embedding_lookup_device configurations.""" 885 if (tpu.under_tpu_inference_context() and 886 embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE): 887 raise ValueError( 888 'Using embedding_lookup_device=tpu_embedding_core during inference ' 889 'is not supported.') 890 if embedding_lookup_device == EmbeddingDevice.CPU: 891 if not tpu.under_tpu_inference_context(): 892 raise ValueError( 893 'Using TPUEmbeddingColumn with embedding_lookup_device="cpu" ' 894 'during training is not supported.') 895 896 897class _TPUDeviceSpecificEmbeddingColumnV2(_TPUEmbeddingColumnV2): 898 """TPUEmbeddingColumn which allows serving on TensorCore.""" 899 900 def __new__(cls, *args, **kwargs): 901 # For __new__, just capture the inference dense shape and call parent. 902 if 'tensor_core_shape' in kwargs: 903 cls._tensor_core_shape = kwargs['tensor_core_shape'] 904 del kwargs['tensor_core_shape'] 905 if 'embedding_lookup_device' in kwargs: 906 cls._embedding_lookup_device = kwargs['embedding_lookup_device'] 907 del kwargs['embedding_lookup_device'] 908 return _TPUEmbeddingColumnV2.__new__(cls, *args, **kwargs) 909 910 def __init__(self, *args, **kwargs): 911 # For __init__, just capture the inference dense shape and call parent. 912 if 'tensor_core_shape' in kwargs: 913 self._tensor_core_shape = kwargs['tensor_core_shape'] 914 del kwargs['tensor_core_shape'] 915 if 'embedding_lookup_device' in kwargs: 916 self._embedding_lookup_device = kwargs['embedding_lookup_device'] 917 del kwargs['embedding_lookup_device'] 918 _TPUEmbeddingColumnV2.__init__(self, *args, **kwargs) 919 920 def __deepcopy__(self, memo): 921 return _TPUDeviceSpecificEmbeddingColumnV2( 922 *(copy.deepcopy(a, memo) for a in self.__getnewargs__()), 923 tensor_core_shape=self._tensor_core_shape, 924 embedding_lookup_device=self._embedding_lookup_device) 925 926 def create_state(self, state_manager): 927 _check_invalid_cases(self._embedding_lookup_device) 928 # CPU case. 929 is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU 930 is_cpu = is_cpu or _is_running_on_cpu() 931 if is_cpu: 932 return fc_lib.EmbeddingColumn.create_state(self, state_manager) 933 # TPU_EMBEDDING_CORE case. 934 elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: 935 return super(_TPUDeviceSpecificEmbeddingColumnV2, 936 self).create_state(state_manager) 937 938 # TPU_EMBEDDING_CORE case. 939 return fc_lib.EmbeddingColumn.create_state(self, state_manager) 940 941 def get_dense_tensor(self, transformation_cache, state_manager): 942 """Private method that follows get_dense_tensor.""" 943 _check_invalid_cases(self._embedding_lookup_device) 944 # CPU Case. 945 is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU 946 is_cpu = is_cpu or _is_running_on_cpu() 947 if is_cpu: 948 return super(_TPUDeviceSpecificEmbeddingColumnV2, 949 self).get_dense_tensor(transformation_cache, state_manager) 950 # TPU_EMBEDDING_CORE case. 951 elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: 952 return super(_TPUDeviceSpecificEmbeddingColumnV2, 953 self).get_dense_tensor(transformation_cache, state_manager) 954 955 # TPU_EMBEDDING_CORE cases. 956 if tpu.under_tpu_inference_context(): 957 # For inference, use outside compile to densify and pad the input tensors. 958 sparse_tensor = transformation_cache.get(self.categorical_column.name, 959 state_manager) 960 961 def host_computation(): 962 return pad_sparse_embedding_lookup_indices(sparse_tensor, 963 self._tensor_core_shape[1]) 964 965 values, mask = tpu.outside_compilation(host_computation) 966 else: 967 # For training, the inputs should already have been densified and padded. 968 values = transformation_cache.get(self.categorical_column.name, 969 state_manager) 970 mask = transformation_cache.get( 971 self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX, 972 state_manager) 973 embedding_weights = state_manager.get_variable( 974 self, name='embedding_weights') 975 return sparse_embedding_aggregate_slice(embedding_weights, (values, mask), 976 self.get_combiner()) 977 978 def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None): 979 _check_invalid_cases(self._embedding_lookup_device) 980 # CPU Case. 981 is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU 982 is_cpu = is_cpu or _is_running_on_cpu() 983 if is_cpu: 984 return super(_TPUDeviceSpecificEmbeddingColumnV2, 985 self)._get_dense_tensor(inputs, weight_collections, 986 trainable) 987 # TPU_EMBEDDING_CORE case. 988 elif self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: 989 return super(_TPUDeviceSpecificEmbeddingColumnV2, 990 self)._get_dense_tensor(inputs, weight_collections, 991 trainable) 992 993 # TPU_EMBEDDING_CORE cases. 994 if tpu.under_tpu_inference_context(): 995 # For inference, use outside compile to densify and pad the input tensors. 996 sparse_tensor = inputs.get(self.get_feature_key_name()) 997 998 def host_computation(): 999 return pad_sparse_embedding_lookup_indices(sparse_tensor, 1000 self._tensor_core_shape[1]) 1001 1002 values, mask = tpu.outside_compilation(host_computation) 1003 else: 1004 # For training, the inputs should already have been densified and padded. 1005 values = inputs.get(self.get_feature_key_name()) 1006 mask = inputs.get(self.get_feature_key_name() + 1007 _TENSOR_CORE_MASK_KEY_SUFFIX) 1008 1009 embedding_shape = (self.categorical_column._num_buckets, self.dimension) # pylint: disable=protected-access 1010 if (weight_collections and 1011 ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections): 1012 weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES) 1013 embedding_weights = variable_scope.get_variable( 1014 name='embedding_weights', 1015 shape=embedding_shape, 1016 dtype=dtypes.float32, 1017 initializer=self.initializer, 1018 trainable=self.trainable and trainable, 1019 collections=weight_collections) 1020 return sparse_embedding_aggregate_slice(embedding_weights, (values, mask), 1021 self.get_combiner()) 1022 1023 1024class _TPUSharedDeviceSpecificEmbeddingColumnV2(_TPUSharedEmbeddingColumnV2): 1025 """TPUSharedEmbeddingColumnV2 which allows serving on TensorCore.""" 1026 1027 def __new__(cls, *args, **kwargs): 1028 # For __new__, just capture the inference dense shape and call parent. 1029 if 'tensor_core_shape' in kwargs: 1030 cls._tensor_core_shape = kwargs['tensor_core_shape'] 1031 del kwargs['tensor_core_shape'] 1032 if 'embedding_lookup_device' in kwargs: 1033 cls._embedding_lookup_device = kwargs['embedding_lookup_device'] 1034 del kwargs['embedding_lookup_device'] 1035 1036 return _TPUSharedEmbeddingColumnV2.__new__(cls, *args, **kwargs) 1037 1038 def __init__(self, *args, **kwargs): 1039 # For __init__, just capture the inference dense shape and call parent. 1040 if 'tensor_core_shape' in kwargs: 1041 self._tensor_core_shape = kwargs['tensor_core_shape'] 1042 del kwargs['tensor_core_shape'] 1043 if 'embedding_lookup_device' in kwargs: 1044 self._embedding_lookup_device = kwargs['embedding_lookup_device'] 1045 del kwargs['embedding_lookup_device'] 1046 _TPUSharedEmbeddingColumnV2.__init__(self, *args, **kwargs) 1047 1048 def __deepcopy__(self, memo): 1049 return _TPUSharedDeviceSpecificEmbeddingColumnV2( 1050 *(copy.deepcopy(a, memo) for a in self.__getnewargs__()), 1051 tensor_core_shape=self._tensor_core_shape, 1052 embedding_lookup_device=self._embedding_lookup_device) 1053 1054 def _get_dense_tensor_internal(self, transformation_cache, state_manager): 1055 """Private method that follows _get_dense_tensor_internal.""" 1056 _check_invalid_cases(self._embedding_lookup_device) 1057 # CPU Case. 1058 is_cpu = self._embedding_lookup_device == EmbeddingDevice.CPU 1059 is_cpu = is_cpu or _is_running_on_cpu() 1060 if is_cpu: 1061 return super(_TPUSharedDeviceSpecificEmbeddingColumnV2, 1062 self)._get_dense_tensor_internal(transformation_cache, 1063 state_manager) 1064 # TPU_EMBEDDING_CORE case. 1065 if self._embedding_lookup_device == EmbeddingDevice.TPU_EMBEDDING_CORE: 1066 return super(_TPUSharedDeviceSpecificEmbeddingColumnV2, 1067 self)._get_dense_tensor_internal(transformation_cache, 1068 state_manager) 1069 1070 # TPU_EMBEDDING_CORE cases. 1071 if tpu.under_tpu_inference_context(): 1072 # For inference, use outside compile to densify and pad the input tensors. 1073 sparse_tensor = transformation_cache.get(self.categorical_column.name, 1074 state_manager) 1075 1076 def host_computation(): 1077 return pad_sparse_embedding_lookup_indices(sparse_tensor, 1078 self._tensor_core_shape[1]) 1079 1080 values, mask = tpu.outside_compilation(host_computation) 1081 else: 1082 # For training, the inputs should already have been densified and padded. 1083 values = transformation_cache.get(self.categorical_column.name, 1084 state_manager) 1085 mask = transformation_cache.get( 1086 self.categorical_column.name + _TENSOR_CORE_MASK_KEY_SUFFIX, 1087 state_manager) 1088 1089 # Do a dense embedding lookup on TensorCore. 1090 embedding_weights = self.shared_embedding_column_creator.embedding_weights 1091 return sparse_embedding_aggregate_slice(embedding_weights, (values, mask), 1092 self.get_combiner()) 1093