1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for lookup ops.""" 16import os 17import tempfile 18import unittest 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python import tf2 24from tensorflow.python.checkpoint import checkpoint as trackable 25from tensorflow.python.checkpoint import graph_view 26from tensorflow.python.checkpoint import util as checkpoint_util 27from tensorflow.python.client import session 28from tensorflow.python.data.experimental.ops import counter 29from tensorflow.python.data.ops import dataset_ops 30from tensorflow.python.eager import backprop 31from tensorflow.python.eager import context 32from tensorflow.python.eager import def_function 33from tensorflow.python.eager import function 34from tensorflow.python.eager import wrap_function 35from tensorflow.python.framework import constant_op 36from tensorflow.python.framework import dtypes 37from tensorflow.python.framework import errors_impl 38from tensorflow.python.framework import ops 39from tensorflow.python.framework import sparse_tensor 40from tensorflow.python.framework import tensor_spec 41from tensorflow.python.framework import test_ops 42from tensorflow.python.framework import test_util 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import control_flow_ops 
45from tensorflow.python.ops import lookup_ops 46from tensorflow.python.ops import map_fn 47from tensorflow.python.ops import variables 48from tensorflow.python.ops.ragged import ragged_tensor 49from tensorflow.python.platform import test 50from tensorflow.python.saved_model import load as saved_model_load 51from tensorflow.python.saved_model import save as saved_model_save 52from tensorflow.python.trackable import asset 53from tensorflow.python.trackable import autotrackable 54from tensorflow.python.training import saver 55from tensorflow.python.training import server_lib 56from tensorflow.python.util import compat 57 58 59class BaseLookupTableTest(test.TestCase): 60 61 def getHashTable(self): 62 if tf2.enabled(): 63 return lookup_ops.StaticHashTable 64 else: 65 return lookup_ops.StaticHashTableV1 66 67 def getVocabularyTable(self): 68 if tf2.enabled(): 69 return lookup_ops.StaticVocabularyTable 70 else: 71 return lookup_ops.StaticVocabularyTableV1 72 73 def initialize_table(self, table): 74 if not tf2.enabled(): 75 self.evaluate(table.initializer) 76 77 78SKIP_ANONYMOUS_IN_TF1_REASON = ( 79 "In v1 graph mode, each self.evaluate call will execute the handle " 80 "creation op (e.g. AnonymousHashTable) which will create a new table " 81 "resource unrelated to other self.evaluate calls, so we can't test " 82 "anonymous resources with self.evaluate ." 
83) 84 85 86@parameterized.named_parameters( 87 (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True]) 88class StaticHashTableTest(BaseLookupTableTest, parameterized.TestCase): 89 90 def testStaticHashTable(self, is_anonymous): 91 if is_anonymous and not tf2.enabled(): 92 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 93 default_val = -1 94 keys = constant_op.constant(["brain", "salad", "surgery"]) 95 values = constant_op.constant([0, 1, 2], dtypes.int64) 96 table = self.getHashTable()( 97 lookup_ops.KeyValueTensorInitializer(keys, values), 98 default_val, 99 experimental_is_anonymous=is_anonymous) 100 self.assertEqual(table._is_anonymous, is_anonymous) 101 self.initialize_table(table) 102 103 self.assertAllEqual(3, self.evaluate(table.size())) 104 105 input_string = constant_op.constant(["brain", "salad", "tank"]) 106 output = table.lookup(input_string) 107 self.assertAllEqual([3], output.get_shape()) 108 109 result = self.evaluate(output) 110 self.assertAllEqual([0, 1, -1], result) 111 112 exported_keys_tensor, exported_values_tensor = table.export() 113 114 self.assertItemsEqual([b"brain", b"salad", b"surgery"], 115 self.evaluate(exported_keys_tensor)) 116 self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor)) 117 118 def testStaticHashTableFindHighRank(self, is_anonymous): 119 if is_anonymous and not tf2.enabled(): 120 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 121 default_val = -1 122 keys = constant_op.constant(["brain", "salad", "surgery"]) 123 values = constant_op.constant([0, 1, 2], dtypes.int64) 124 table = self.getHashTable()( 125 lookup_ops.KeyValueTensorInitializer(keys, values), 126 default_val, 127 experimental_is_anonymous=is_anonymous) 128 self.initialize_table(table) 129 130 self.assertAllEqual(3, self.evaluate(table.size())) 131 132 input_string = constant_op.constant([["brain", "salad"], 133 ["tank", "tarkus"]]) 134 output = table.lookup(input_string) 135 136 result = self.evaluate(output) 137 
self.assertAllEqual([[0, 1], [-1, -1]], result) 138 139 def testStaticHashTableInitWithPythonArrays(self, is_anonymous): 140 if is_anonymous and not tf2.enabled(): 141 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 142 default_val = -1 143 keys = ["brain", "salad", "surgery"] 144 values = [0, 1, 2] 145 table = self.getHashTable()( 146 lookup_ops.KeyValueTensorInitializer( 147 keys, values, value_dtype=dtypes.int64), 148 default_val, 149 experimental_is_anonymous=is_anonymous) 150 self.initialize_table(table) 151 152 self.assertAllEqual(3, self.evaluate(table.size())) 153 154 input_string = constant_op.constant(["brain", "salad", "tank"]) 155 output = table.lookup(input_string) 156 157 result = self.evaluate(output) 158 self.assertAllEqual([0, 1, -1], result) 159 160 def testStaticHashTableInitWithNumPyArrays(self, is_anonymous): 161 if is_anonymous and not tf2.enabled(): 162 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 163 default_val = -1 164 keys = np.array(["brain", "salad", "surgery"], dtype=np.str_) 165 values = np.array([0, 1, 2], dtype=np.int64) 166 table = self.getHashTable()( 167 lookup_ops.KeyValueTensorInitializer(keys, values), 168 default_val, 169 experimental_is_anonymous=is_anonymous) 170 self.initialize_table(table) 171 172 self.assertAllEqual(3, self.evaluate(table.size())) 173 174 input_string = constant_op.constant(["brain", "salad", "tank"]) 175 output = table.lookup(input_string) 176 177 result = self.evaluate(output) 178 self.assertAllEqual([0, 1, -1], result) 179 180 def testMultipleStaticHashTables(self, is_anonymous): 181 if is_anonymous and not tf2.enabled(): 182 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 183 default_val = -1 184 keys = constant_op.constant(["brain", "salad", "surgery"]) 185 values = constant_op.constant([0, 1, 2], dtypes.int64) 186 187 table1 = self.getHashTable()( 188 lookup_ops.KeyValueTensorInitializer(keys, values), 189 default_val, 190 experimental_is_anonymous=is_anonymous) 191 table2 = self.getHashTable()( 192 
lookup_ops.KeyValueTensorInitializer(keys, values), 193 default_val, 194 experimental_is_anonymous=is_anonymous) 195 table3 = self.getHashTable()( 196 lookup_ops.KeyValueTensorInitializer(keys, values), 197 default_val, 198 experimental_is_anonymous=is_anonymous) 199 200 self.initialize_table(table1) 201 self.initialize_table(table2) 202 self.initialize_table(table3) 203 self.assertAllEqual(3, self.evaluate(table1.size())) 204 self.assertAllEqual(3, self.evaluate(table2.size())) 205 self.assertAllEqual(3, self.evaluate(table3.size())) 206 207 input_string = constant_op.constant(["brain", "salad", "tank"]) 208 output1 = table1.lookup(input_string) 209 output2 = table2.lookup(input_string) 210 output3 = table3.lookup(input_string) 211 212 out1, out2, out3 = self.evaluate([output1, output2, output3]) 213 self.assertAllEqual([0, 1, -1], out1) 214 self.assertAllEqual([0, 1, -1], out2) 215 self.assertAllEqual([0, 1, -1], out3) 216 217 def testStaticHashTableWithTensorDefault(self, is_anonymous): 218 if is_anonymous and not tf2.enabled(): 219 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 220 default_val = constant_op.constant(-1, dtypes.int64) 221 keys = constant_op.constant(["brain", "salad", "surgery"]) 222 values = constant_op.constant([0, 1, 2], dtypes.int64) 223 table = self.getHashTable()( 224 lookup_ops.KeyValueTensorInitializer(keys, values), 225 default_val, 226 experimental_is_anonymous=is_anonymous) 227 self.initialize_table(table) 228 229 input_string = constant_op.constant(["brain", "salad", "tank"]) 230 output = table.lookup(input_string) 231 232 result = self.evaluate(output) 233 self.assertAllEqual([0, 1, -1], result) 234 235 def testStaticHashTableGetItem(self, is_anonymous): 236 if is_anonymous and not tf2.enabled(): 237 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 238 default_val = constant_op.constant(-1, dtypes.int64) 239 keys = constant_op.constant(["brain", "salad", "surgery"]) 240 values = constant_op.constant([0, 1, 2], dtypes.int64) 241 table = 
self.getHashTable()( 242 lookup_ops.KeyValueTensorInitializer(keys, values), 243 default_val, 244 experimental_is_anonymous=is_anonymous) 245 self.initialize_table(table) 246 247 input_string = constant_op.constant(["brain", "salad", "tank"]) 248 output = table[input_string] 249 250 result = self.evaluate(output) 251 self.assertAllEqual([0, 1, -1], result) 252 253 def testStaticHashTableWithSparseTensorInput(self, is_anonymous): 254 if is_anonymous and not tf2.enabled(): 255 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 256 default_val = constant_op.constant(-1, dtypes.int64) 257 keys = constant_op.constant(["brain", "salad", "surgery"]) 258 values = constant_op.constant([0, 1, 2], dtypes.int64) 259 table = self.getHashTable()( 260 lookup_ops.KeyValueTensorInitializer(keys, values), 261 default_val, 262 experimental_is_anonymous=is_anonymous) 263 self.initialize_table(table) 264 265 sp_indices = [[0, 0], [0, 1], [1, 0]] 266 sp_shape = [2, 2] 267 input_tensor = sparse_tensor.SparseTensor( 268 constant_op.constant(sp_indices, dtypes.int64), 269 constant_op.constant(["brain", "salad", "tank"]), 270 constant_op.constant(sp_shape, dtypes.int64)) 271 output = table.lookup(input_tensor) 272 273 out_indices, out_values, out_shape = self.evaluate(output) 274 275 self.assertAllEqual([0, 1, -1], out_values) 276 self.assertAllEqual(sp_indices, out_indices) 277 self.assertAllEqual(sp_shape, out_shape) 278 279 def testStaticHashTableWithRaggedTensorInput(self, is_anonymous): 280 if is_anonymous and not tf2.enabled(): 281 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 282 default_val = constant_op.constant(-1, dtypes.int64) 283 keys = constant_op.constant(["brain", "salad", "surgery"]) 284 values = constant_op.constant([0, 1, 2], dtypes.int64) 285 table = self.getHashTable()( 286 lookup_ops.KeyValueTensorInitializer(keys, values), 287 default_val, 288 experimental_is_anonymous=is_anonymous) 289 self.initialize_table(table) 290 291 row_splits = [0, 2, 3] 292 input_tensor = 
ragged_tensor.RaggedTensor.from_row_splits( 293 constant_op.constant(["brain", "salad", "tank"]), 294 constant_op.constant(row_splits, dtypes.int64)) 295 output = table.lookup(input_tensor) 296 297 out = self.evaluate(output) 298 299 self.assertAllEqual([0, 1, -1], out.values) 300 self.assertAllEqual(row_splits, out.row_splits) 301 302 def testSignatureMismatch(self, is_anonymous): 303 if is_anonymous and not tf2.enabled(): 304 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 305 default_val = -1 306 keys = constant_op.constant(["brain", "salad", "surgery"]) 307 values = constant_op.constant([0, 1, 2], dtypes.int64) 308 table = self.getHashTable()( 309 lookup_ops.KeyValueTensorInitializer(keys, values), 310 default_val, 311 experimental_is_anonymous=is_anonymous) 312 self.initialize_table(table) 313 314 # Ref types do not produce a lookup signature mismatch. 315 input_string_ref = variables.Variable("brain") 316 self.evaluate(input_string_ref.initializer) 317 self.assertEqual(0, self.evaluate(table.lookup(input_string_ref))) 318 319 input_string = constant_op.constant([1, 2, 3], dtypes.int64) 320 with self.assertRaises(TypeError): 321 table.lookup(input_string) 322 323 with self.assertRaises(TypeError): 324 self.getHashTable()( 325 lookup_ops.KeyValueTensorInitializer(keys, values), 326 "UNK", 327 experimental_is_anonymous=is_anonymous) 328 329 def testDTypes(self, is_anonymous): 330 default_val = -1 331 with self.assertRaises(TypeError): 332 self.getHashTable()( 333 lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string], 334 dtypes.int64), 335 default_val, 336 experimental_is_anonymous=is_anonymous) 337 338 @test_util.run_v1_only("(Cached) Sessions not available in TF2.0") 339 def testNotInitialized(self, is_anonymous): 340 with self.cached_session(): 341 default_val = -1 342 table = self.getHashTable()( 343 lookup_ops.KeyValueTensorInitializer(["a"], [1], 344 value_dtype=dtypes.int64), 345 default_val, 346 experimental_is_anonymous=is_anonymous) 347 348 
input_string = constant_op.constant(["brain", "salad", "surgery"]) 349 output = table.lookup(input_string) 350 351 with self.assertRaisesOpError("Table not initialized"): 352 self.evaluate(output) 353 354 @test_util.run_v1_only("(Cached) Sessions not available in TF2.0") 355 def testInitializeTwice(self, is_anonymous): 356 with self.cached_session(): 357 default_val = -1 358 keys = constant_op.constant(["brain", "salad", "surgery"]) 359 values = constant_op.constant([0, 1, 2], dtypes.int64) 360 table = self.getHashTable()( 361 lookup_ops.KeyValueTensorInitializer(keys, values), 362 default_val, 363 experimental_is_anonymous=is_anonymous) 364 self.initialize_table(table) 365 # Make sure that initializing twice doesn't throw any errors. 366 self.initialize_table(table) 367 368 def testInitializationWithInvalidDimensions(self, is_anonymous): 369 default_val = -1 370 keys = constant_op.constant(["brain", "salad", "surgery"]) 371 values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64) 372 373 raised_error = ValueError 374 if context.executing_eagerly(): 375 raised_error = errors_impl.InvalidArgumentError 376 with self.assertRaises(raised_error): 377 self.getHashTable()( 378 lookup_ops.KeyValueTensorInitializer(keys, values), 379 default_val, 380 experimental_is_anonymous=is_anonymous) 381 382 @test_util.run_v1_only("Sessions not available in TF2.0") 383 def testMultipleSessions(self, is_anonymous): 384 if is_anonymous and not tf2.enabled(): 385 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 386 # Start a server 387 server = server_lib.Server({"local0": ["localhost:0"]}, 388 protocol="grpc", 389 start=True) 390 # Create two sessions sharing the same state 391 session1 = session.Session(server.target) 392 session2 = session.Session(server.target) 393 394 default_val = -1 395 keys = constant_op.constant(["brain", "salad", "surgery"]) 396 values = constant_op.constant([0, 1, 2], dtypes.int64) 397 table = self.getHashTable()( 398 lookup_ops.KeyValueTensorInitializer(keys, 
values), 399 default_val, 400 name="t1", 401 experimental_is_anonymous=is_anonymous) 402 403 # Init the table in the first session. 404 with session1: 405 self.initialize_table(table) 406 self.assertAllEqual(3, self.evaluate(table.size())) 407 408 # Init the table in the second session and verify that we do not get a 409 # "Table already initialized" error. 410 with session2: 411 self.evaluate(table.initializer) 412 self.assertAllEqual(3, self.evaluate(table.size())) 413 414 @test_util.run_v2_only 415 def testImportedHashTable(self, is_anonymous): 416 g = ops.Graph() 417 with g.as_default(): 418 t = lookup_ops.StaticHashTable( 419 lookup_ops.KeyValueTensorInitializer(["a"], [1]), 420 2) 421 init_op = t._init_op 422 op = t.lookup(ops.convert_to_tensor(["a"])) 423 meta_graph = saver.export_meta_graph() 424 425 def f(): 426 saver.import_meta_graph(meta_graph) 427 return ops.get_default_graph().get_tensor_by_name(op.name) 428 429 wrapped = wrap_function.wrap_function(f, []) 430 pruned_init_fn = wrapped.prune( 431 (), [wrapped.graph.get_operation_by_name(init_op.name)]) 432 self.evaluate(pruned_init_fn()) 433 self.assertAllEqual([1], wrapped()) 434 435 def testStaticHashTableInt32String(self, is_anonymous): 436 if is_anonymous and not tf2.enabled(): 437 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 438 default_val = "n/a" 439 keys = constant_op.constant([0, 1, 2], dtypes.int32) 440 values = constant_op.constant(["brain", "salad", "surgery"]) 441 table = self.getHashTable()( 442 lookup_ops.KeyValueTensorInitializer(keys, values), 443 default_val, 444 experimental_is_anonymous=is_anonymous) 445 self.initialize_table(table) 446 447 input_tensor = constant_op.constant([0, 1, -1]) 448 output = table.lookup(input_tensor) 449 450 result = self.evaluate(output) 451 self.assertAllEqual([b"brain", b"salad", b"n/a"], result) 452 453 def testTableUseInFunction(self, is_anonymous): 454 if not context.executing_eagerly(): 455 self.skipTest("Only Eager mode test.") 456 keys = 
constant_op.constant([0, 1, 2], dtypes.int32) 457 values = constant_op.constant(["brain", "salad", "surgery"]) 458 table = self.getHashTable()( 459 lookup_ops.KeyValueTensorInitializer(keys, values), 460 "n/a", 461 experimental_is_anonymous=is_anonymous) 462 463 @function.defun() 464 def lookup_table_func(k): 465 return table.lookup(k) 466 467 result = lookup_table_func(constant_op.constant([0, 1, -1])) 468 self.assertAllEqual([b"brain", b"salad", b"n/a"], result) 469 result = lookup_table_func(constant_op.constant([2, -1, 1])) 470 self.assertAllEqual([b"surgery", b"n/a", b"salad"], result) 471 472 def testTableCreatedInFunction(self, is_anonymous): 473 if not context.executing_eagerly(): 474 self.skipTest("Only Eager mode test.") 475 keys = constant_op.constant([0, 1, 2], dtypes.int32) 476 values = constant_op.constant(["brain", "salad", "surgery"]) 477 478 @function.defun() 479 def lookup_table_func(k): 480 table = self.getHashTable()( 481 lookup_ops.KeyValueTensorInitializer(keys, values), 482 "n/a", 483 experimental_is_anonymous=is_anonymous) 484 return table.lookup(k) 485 486 result = lookup_table_func(constant_op.constant([0, 1, -1])) 487 self.assertAllEqual([b"brain", b"salad", b"n/a"], result) 488 result = lookup_table_func(constant_op.constant([2, -1, 1])) 489 self.assertAllEqual([b"surgery", b"n/a", b"salad"], result) 490 491 def testTwoTablesInControlFlow(self, is_anonymous): 492 if is_anonymous and not tf2.enabled(): 493 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 494 keys = constant_op.constant([1, 2, 3], dtypes.int32) 495 values = constant_op.constant([5, 10, 15], dtypes.int32) 496 497 def table_func1(x): 498 table = self.getHashTable()( 499 lookup_ops.KeyValueTensorInitializer(keys, values), 500 -1, 501 experimental_is_anonymous=is_anonymous) 502 return table.lookup(x) 503 504 elems = np.array([2, 4, 1], dtype=np.int32) 505 result1 = map_fn.map_fn(table_func1, elems, dtype=dtypes.int32) 506 507 def table_func2(x): 508 table = self.getHashTable()( 509 
lookup_ops.KeyValueTensorInitializer(keys, values), 510 -1, 511 experimental_is_anonymous=is_anonymous) 512 return table.lookup(x) 513 514 elems = np.array([2, 4, 1], dtype=np.int32) 515 result2 = map_fn.map_fn(table_func2, elems, dtype=dtypes.int32) 516 517 self.evaluate(lookup_ops.tables_initializer()) 518 519 self.assertAllEqual([10, -1, 5], self.evaluate(result1)) 520 self.assertAllEqual([10, -1, 5], self.evaluate(result2)) 521 522 @test_util.enable_control_flow_v2 523 def testLookupTableInWhileV2(self, is_anonymous): 524 lookup = self.getHashTable()( 525 lookup_ops.KeyValueTensorInitializer( 526 constant_op.constant([2, 5], dtype=dtypes.int64), 527 constant_op.constant([-10.0, 1], dtype=dtypes.float32)), 528 -1, 529 experimental_is_anonymous=is_anonymous) 530 531 beta = variables.Variable(1.0, trainable=True) 532 533 @def_function.function 534 def get_loss(unused_beta): 535 return map_fn.map_fn( 536 lookup.lookup, 537 constant_op.constant([2, 3], dtype=dtypes.int64), 538 dtype=dtypes.float32) 539 540 with backprop.GradientTape() as tape: 541 loss = get_loss(beta) 542 543 self.assertIsNone(tape.gradient(loss, beta)) 544 545 @test_util.enable_control_flow_v2 546 def testLookupTableInCondV2(self, is_anonymous): 547 if is_anonymous and not tf2.enabled(): 548 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 549 lookup = self.getHashTable()( 550 lookup_ops.KeyValueTensorInitializer( 551 constant_op.constant([2, 5], dtype=dtypes.int64), 552 constant_op.constant([-10.0, 1], dtype=dtypes.float32)), 553 -1, 554 experimental_is_anonymous=is_anonymous) 555 556 beta = variables.Variable(1.0, trainable=True) 557 558 @def_function.function 559 def get_loss(beta): 560 561 def true_fn(): 562 return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64)) 563 564 def false_fn(): 565 return constant_op.constant(0, dtype=dtypes.float32) 566 567 return beta * control_flow_ops.cond( 568 constant_op.constant(True), true_fn=true_fn, false_fn=false_fn) 569 570 with 
backprop.GradientTape() as tape: 571 loss = get_loss(beta) 572 grad = tape.gradient(loss, beta) 573 self.evaluate(variables.global_variables_initializer()) 574 self.evaluate(lookup_ops.tables_initializer()) 575 self.assertAllEqual(grad, -10.) 576 577 def testExportShapeInference(self, is_anonymous): 578 table = self.getHashTable()( 579 lookup_ops.KeyValueTensorInitializer( 580 constant_op.constant([2, 5], dtype=dtypes.int64), 581 constant_op.constant([-10.0, 1], dtype=dtypes.float32)), 582 -1, 583 experimental_is_anonymous=is_anonymous) 584 actual_shapes = [t.shape for t in table.export()] 585 inferred_shapes = [] 586 587 @def_function.function 588 def f(): 589 for t in table.export(): 590 inferred_shapes.append(t.shape) 591 592 f() 593 self.assertLen(actual_shapes, 2) 594 self.assertLen(inferred_shapes, 2) 595 self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0])) 596 self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1])) 597 598 @test_util.run_v2_only 599 def testSavedModelSaveRestore(self, is_anonymous): 600 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 601 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 602 603 root = autotrackable.AutoTrackable() 604 605 default_value = -1 606 keys = constant_op.constant([11, 12, 13], dtypes.int64) 607 values = constant_op.constant([0, 1, 2], dtypes.int64) 608 root.table = self.getHashTable()( 609 lookup_ops.KeyValueTensorInitializer(keys, values), 610 default_value, 611 experimental_is_anonymous=is_anonymous) 612 613 @def_function.function( 614 input_signature=[tensor_spec.TensorSpec((), dtypes.int64)]) 615 def lookup(key): 616 return root.table.lookup(key) 617 618 @def_function.function(input_signature=[]) 619 def size(): 620 return root.table.size() 621 622 @def_function.function(input_signature=[]) 623 def is_ref_counting(): 624 return test_ops.is_resource_handle_ref_counting( 625 root.table.resource_handle) 626 627 root.lookup = lookup 628 root.size = 
size 629 root.is_ref_counting = is_ref_counting 630 631 self.assertEqual(root.table.size(), 3) 632 self.assertEqual(root.lookup(12), 1) 633 self.assertEqual(root.lookup(10), -1) 634 self.assertLen(root.table.export()[0], 3) 635 self.assertEqual(root.is_ref_counting(), is_anonymous) 636 637 saved_model_save.save(root, save_path) 638 639 del root 640 loaded = saved_model_load.load(save_path) 641 self.assertEqual(loaded.size(), 3) 642 self.assertEqual(loaded.lookup(12), 1) 643 self.assertEqual(loaded.lookup(10), -1) 644 self.assertEqual(loaded.is_ref_counting(), is_anonymous) 645 646 647@parameterized.named_parameters( 648 (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True]) 649class KeyValueTensorInitializerTest(BaseLookupTableTest): 650 651 def test_string(self, is_anonymous): 652 init = lookup_ops.KeyValueTensorInitializer( 653 ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) 654 table = self.getHashTable()( 655 init, default_value=-1, experimental_is_anonymous=is_anonymous) 656 self.initialize_table(table) 657 658 def test_multiple_tables(self, is_anonymous): 659 with ops.name_scope("table_scope"): 660 init1 = lookup_ops.KeyValueTensorInitializer( 661 ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) 662 table1 = self.getHashTable()( 663 init1, default_value=-1, experimental_is_anonymous=is_anonymous) 664 if not context.executing_eagerly(): 665 self.assertEqual("hash_table", table1.name) 666 self.assertEqual("table_scope/hash_table", 667 table1.resource_handle.op.name) 668 init2 = lookup_ops.KeyValueTensorInitializer( 669 ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64) 670 table2 = self.getHashTable()( 671 init2, default_value=-1, experimental_is_anonymous=is_anonymous) 672 if not context.executing_eagerly(): 673 self.assertEqual("hash_table_1", table2.name) 674 self.assertEqual("table_scope/hash_table_1", 675 table2.resource_handle.op.name) 676 677 def test_int64(self, 
is_anonymous): 678 init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), 679 dtypes.int64, dtypes.int64) 680 table = self.getHashTable()( 681 init, default_value=-1, experimental_is_anonymous=is_anonymous) 682 self.initialize_table(table) 683 684 def test_int32(self, is_anonymous): 685 init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), 686 dtypes.int32, dtypes.int64) 687 with self.assertRaises(errors_impl.OpError): 688 table = self.getHashTable()( 689 init, default_value=-1, experimental_is_anonymous=is_anonymous) 690 self.initialize_table(table) 691 692 693@parameterized.named_parameters( 694 (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True]) 695class InitializeTableFromFileOpTest(BaseLookupTableTest): 696 697 def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): 698 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 699 with open(vocabulary_file, "w") as f: 700 f.write("\n".join(values) + "\n") 701 return vocabulary_file 702 703 def testInitializeStringTable(self, is_anonymous): 704 if is_anonymous and not tf2.enabled(): 705 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 706 vocabulary_file = self._createVocabFile("one_column_1.txt") 707 default_value = -1 708 init = lookup_ops.TextFileInitializer( 709 vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 710 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) 711 self.assertIn("one_column_1.txt_-2_-1", init._shared_name) 712 table = self.getHashTable()( 713 init, default_value, experimental_is_anonymous=is_anonymous) 714 self.initialize_table(table) 715 716 output = table.lookup(constant_op.constant(["brain", "salad", "tank"])) 717 718 result = self.evaluate(output) 719 self.assertAllEqual([0, 1, -1], result) 720 721 def testInitializeInt64Table(self, is_anonymous): 722 if is_anonymous and not tf2.enabled(): 723 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 724 vocabulary_file = self._createVocabFile( 725 
"one_column_int64.txt", values=("42", "1", "-1000")) 726 727 with self.cached_session(): 728 default_value = -1 729 init = lookup_ops.TextFileInitializer( 730 vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE, 731 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) 732 self.assertIn("one_column_int64.txt_-2_-1", init._shared_name) 733 table = self.getHashTable()( 734 init, default_value, experimental_is_anonymous=is_anonymous) 735 self.initialize_table(table) 736 737 output = table.lookup( 738 constant_op.constant((42, 1, 11), dtype=dtypes.int64)) 739 740 result = self.evaluate(output) 741 self.assertAllEqual([0, 1, -1], result) 742 743 def testInitializeIndexTable(self, is_anonymous): 744 if is_anonymous and not tf2.enabled(): 745 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 746 vocabulary_file = self._createVocabFile("one_column_2.txt") 747 748 with self.cached_session(): 749 default_value = "UNK" 750 key_index = lookup_ops.TextFileIndex.LINE_NUMBER 751 value_index = lookup_ops.TextFileIndex.WHOLE_LINE 752 init = lookup_ops.TextFileInitializer( 753 vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index) 754 self.assertIn("one_column_2.txt_-1_-2", init._shared_name) 755 table = self.getHashTable()( 756 init, default_value, experimental_is_anonymous=is_anonymous) 757 self.initialize_table(table) 758 759 input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 760 output = table.lookup(input_values) 761 762 result = self.evaluate(output) 763 self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], result) 764 765 def testMultiColumn(self, is_anonymous): 766 if is_anonymous and not tf2.enabled(): 767 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 768 vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt") 769 with open(vocabulary_file, "w") as f: 770 f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") 771 772 with self.cached_session(): 773 default_value = -1 774 key_index = 1 775 
value_index = 2 776 777 init = lookup_ops.TextFileInitializer( 778 vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index) 779 self.assertIn("three_columns.txt_1_2", init._shared_name) 780 table = self.getHashTable()( 781 init, default_value, experimental_is_anonymous=is_anonymous) 782 self.initialize_table(table) 783 784 input_string = constant_op.constant(["brain", "salad", "surgery"]) 785 output = table.lookup(input_string) 786 787 result = self.evaluate(output) 788 self.assertAllEqual([1, 5, 6], result) 789 790 def testInvalidDataTypeInMultiColumn(self, is_anonymous): 791 vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt") 792 with open(vocabulary_file, "w") as f: 793 f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n") 794 795 with self.cached_session(): 796 default_value = -1 797 key_index = 2 798 value_index = 1 799 init = lookup_ops.TextFileInitializer( 800 vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index) 801 self.assertIn("three_columns.txt_2_1", init._shared_name) 802 with self.assertRaisesOpError("is not a valid"): 803 table = self.getHashTable()( 804 init, default_value, experimental_is_anonymous=is_anonymous) 805 self.initialize_table(table) 806 807 def testInvalidDataType(self, is_anonymous): 808 vocabulary_file = self._createVocabFile("one_column_3.txt") 809 810 with self.cached_session(): 811 default_value = "UNK" 812 key_index = lookup_ops.TextFileIndex.WHOLE_LINE 813 value_index = lookup_ops.TextFileIndex.LINE_NUMBER 814 815 with self.assertRaises(ValueError): 816 init = lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64, 817 key_index, dtypes.string, 818 value_index) 819 self.assertIn("one_column_3.txt_-2_-1", init._shared_name) 820 self.getHashTable()( 821 init, default_value, experimental_is_anonymous=is_anonymous) 822 823 def testInvalidIndex(self, is_anonymous): 824 vocabulary_file = self._createVocabFile("one_column_4.txt") 825 with self.cached_session(): 
826 default_value = -1 827 key_index = 1 # second column of the line 828 value_index = lookup_ops.TextFileIndex.LINE_NUMBER 829 init = lookup_ops.TextFileInitializer( 830 vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index) 831 self.assertIn("one_column_4.txt_1_-1", init._shared_name) 832 833 with self.assertRaisesOpError("Invalid number of columns"): 834 table = self.getHashTable()( 835 init, default_value, experimental_is_anonymous=is_anonymous) 836 self.initialize_table(table) 837 838 def testInitializeSameTableWithMultipleNodes(self, is_anonymous): 839 if is_anonymous and not tf2.enabled(): 840 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 841 vocabulary_file = self._createVocabFile("one_column_5.txt") 842 843 with self.cached_session(): 844 default_value = -1 845 init1 = lookup_ops.TextFileInitializer( 846 vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 847 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) 848 self.assertIn("one_column_5.txt_-2_-1", init1._shared_name) 849 table1 = self.getHashTable()( 850 init1, default_value, experimental_is_anonymous=is_anonymous) 851 init2 = lookup_ops.TextFileInitializer( 852 vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 853 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) 854 self.assertIn("one_column_5.txt_-2_-1", init2._shared_name) 855 table2 = self.getHashTable()( 856 init2, default_value, experimental_is_anonymous=is_anonymous) 857 init3 = lookup_ops.TextFileInitializer( 858 vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 859 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) 860 self.assertIn("one_column_5.txt_-2_-1", init3._shared_name) 861 table3 = self.getHashTable()( 862 init3, default_value, experimental_is_anonymous=is_anonymous) 863 864 self.evaluate(lookup_ops.tables_initializer()) 865 866 input_string = constant_op.constant(["brain", "salad", "tank"]) 867 868 output1 = table1.lookup(input_string) 869 output2 = 
table2.lookup(input_string) 870 output3 = table3.lookup(input_string) 871 872 out1, out2, out3 = self.evaluate([output1, output2, output3]) 873 self.assertAllEqual([0, 1, -1], out1) 874 self.assertAllEqual([0, 1, -1], out2) 875 self.assertAllEqual([0, 1, -1], out3) 876 877 def testInitializeTableWithNoFilename(self, is_anonymous): 878 with self.cached_session(): 879 default_value = -1 880 with self.assertRaises(ValueError): 881 self.getHashTable()( 882 lookup_ops.TextFileInitializer( 883 "", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 884 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), 885 default_value, 886 experimental_is_anonymous=is_anonymous) 887 888 def testInitializeWithVocabSize(self, is_anonymous): 889 if is_anonymous and not tf2.enabled(): 890 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 891 with self.cached_session(): 892 default_value = -1 893 vocab_size = 3 894 vocabulary_file1 = self._createVocabFile("one_column6.txt") 895 init1 = lookup_ops.TextFileInitializer( 896 vocabulary_file1, 897 dtypes.string, 898 lookup_ops.TextFileIndex.WHOLE_LINE, 899 dtypes.int64, 900 lookup_ops.TextFileIndex.LINE_NUMBER, 901 vocab_size=vocab_size) 902 self.assertIn("one_column6.txt_3_-2_-1", init1._shared_name) 903 table1 = self.getHashTable()( 904 init1, default_value, experimental_is_anonymous=is_anonymous) 905 906 # Initialize from file. 
907 self.initialize_table(table1) 908 self.assertEqual(vocab_size, self.evaluate(table1.size())) 909 910 vocabulary_file2 = self._createVocabFile("one_column7.txt") 911 vocab_size = 5 912 init2 = lookup_ops.TextFileInitializer( 913 vocabulary_file2, 914 dtypes.string, 915 lookup_ops.TextFileIndex.WHOLE_LINE, 916 dtypes.int64, 917 lookup_ops.TextFileIndex.LINE_NUMBER, 918 vocab_size=vocab_size) 919 self.assertIn("one_column7.txt_5_-2_-1", init2._shared_name) 920 with self.assertRaisesOpError("Invalid vocab_size"): 921 table2 = self.getHashTable()( 922 init2, default_value, experimental_is_anonymous=is_anonymous) 923 self.initialize_table(table2) 924 925 vocab_size = 1 926 vocabulary_file3 = self._createVocabFile("one_column3.txt") 927 init3 = lookup_ops.TextFileInitializer( 928 vocabulary_file3, 929 dtypes.string, 930 lookup_ops.TextFileIndex.WHOLE_LINE, 931 dtypes.int64, 932 lookup_ops.TextFileIndex.LINE_NUMBER, 933 vocab_size=vocab_size) 934 self.assertIn("one_column3.txt_1_-2_-1", init3._shared_name) 935 table3 = self.getHashTable()( 936 init3, default_value, experimental_is_anonymous=is_anonymous) 937 938 # Smaller vocab size reads only vocab_size records. 939 self.initialize_table(table3) 940 self.assertEqual(vocab_size, self.evaluate(table3.size())) 941 942 @test_util.run_v1_only("placeholder usage") 943 def testFeedVocabularyName(self, is_anonymous): 944 if is_anonymous and not tf2.enabled(): 945 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 946 vocabulary_file = self._createVocabFile("feed_vocabulary.txt") 947 948 with self.cached_session(): 949 default_value = -1 950 init = lookup_ops.TextFileInitializer( 951 "old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 952 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER) 953 self.assertIn("old_file.txt_-2_-1", init._shared_name) 954 table = self.getHashTable()( 955 init, default_value, experimental_is_anonymous=is_anonymous) 956 957 # Initialize with non existing file (old_file.txt) should fail. 
958 # TODO(yleon): Update message, which might change per FileSystem. 959 with self.assertRaisesOpError("old_file.txt"): 960 self.evaluate(table.initializer) 961 962 # Initialize the model feeding the vocabulary file. 963 filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 964 table.initializer.run(feed_dict={filenames[0]: vocabulary_file}) 965 966 input_string = constant_op.constant(["brain", "salad", "tank"]) 967 output = table.lookup(input_string) 968 969 result = self.evaluate(output) 970 self.assertAllEqual([0, 1, -1], result) 971 972 def testInvalidFilenames(self, is_anonymous): 973 vocabulary_file = self._createVocabFile("filename_shape.txt") 974 975 with self.cached_session(): 976 default_value = -1 977 978 # Invalid data type 979 other_type = constant_op.constant(1) 980 with self.assertRaises(Exception) as cm: 981 self.getHashTable()( 982 lookup_ops.TextFileInitializer( 983 other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 984 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), 985 default_value, 986 experimental_is_anonymous=is_anonymous) 987 self.assertIsInstance(cm.exception, (ValueError, TypeError)) 988 989 # Non-scalar filename 990 filenames = constant_op.constant([vocabulary_file, vocabulary_file]) 991 if not context.executing_eagerly(): 992 with self.assertRaises(Exception) as cm: 993 self.getHashTable()( 994 lookup_ops.TextFileInitializer( 995 filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 996 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), 997 default_value, 998 experimental_is_anonymous=is_anonymous) 999 self.assertIsInstance(cm.exception, (ValueError, TypeError)) 1000 else: 1001 with self.assertRaises(errors_impl.InvalidArgumentError): 1002 self.getHashTable()( 1003 lookup_ops.TextFileInitializer( 1004 filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE, 1005 dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER), 1006 default_value, 1007 experimental_is_anonymous=is_anonymous) 1008 1009 def 
testIdToStringTable(self, is_anonymous): 1010 if is_anonymous and not tf2.enabled(): 1011 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1012 vocab_file = self._createVocabFile("feat_to_id_1.txt") 1013 with self.cached_session(): 1014 default_value = "UNK" 1015 vocab_size = 3 1016 init = lookup_ops.TextFileStringTableInitializer( 1017 vocab_file, vocab_size=vocab_size) 1018 self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name) 1019 table = self.getHashTable()( 1020 init, default_value, experimental_is_anonymous=is_anonymous) 1021 1022 self.initialize_table(table) 1023 1024 input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1025 1026 out = table.lookup(input_values) 1027 self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], 1028 self.evaluate(out)) 1029 self.assertEqual(vocab_size, self.evaluate(table.size())) 1030 1031 def testStringToIdTable(self, is_anonymous): 1032 if is_anonymous and not tf2.enabled(): 1033 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1034 vocab_file = self._createVocabFile("feat_to_id_2.txt") 1035 with self.cached_session(): 1036 default_value = -1 1037 vocab_size = 3 1038 init = lookup_ops.TextFileIdTableInitializer( 1039 vocab_file, vocab_size=vocab_size) 1040 self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name) 1041 table = self.getHashTable()( 1042 init, default_value, experimental_is_anonymous=is_anonymous) 1043 self.initialize_table(table) 1044 1045 input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) 1046 1047 out = table.lookup(input_string) 1048 self.assertAllEqual([0, 1, 2, -1], self.evaluate(out)) 1049 self.assertEqual(vocab_size, self.evaluate(table.size())) 1050 1051 def testInt64ToIdTable(self, is_anonymous): 1052 if is_anonymous and not tf2.enabled(): 1053 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1054 vocab_file = self._createVocabFile( 1055 "feat_to_id_3.txt", values=("42", "1", "-1000")) 1056 with self.cached_session(): 1057 default_value = -1 1058 vocab_size = 3 1059 
init = lookup_ops.TextFileIdTableInitializer( 1060 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64) 1061 self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name) 1062 table = self.getHashTable()( 1063 init, default_value, experimental_is_anonymous=is_anonymous) 1064 self.initialize_table(table) 1065 1066 out = table.lookup( 1067 constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)) 1068 self.assertAllEqual((0, 1, 2, -1), self.evaluate(out)) 1069 self.assertEqual(vocab_size, self.evaluate(table.size())) 1070 1071 1072@parameterized.named_parameters( 1073 (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True]) 1074class StaticVocabularyTableTest(BaseLookupTableTest): 1075 1076 def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): 1077 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 1078 with open(vocabulary_file, "w") as f: 1079 f.write("\n".join(values) + "\n") 1080 return vocabulary_file 1081 1082 def testStringStaticVocabularyTable(self, is_anonymous): 1083 if is_anonymous and not tf2.enabled(): 1084 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1085 vocab_file = self._createVocabFile("feat_to_id_1.txt") 1086 vocab_size = 3 1087 oov_buckets = 1 1088 table = self.getVocabularyTable()( 1089 lookup_ops.TextFileIdTableInitializer( 1090 vocab_file, vocab_size=vocab_size), 1091 oov_buckets, 1092 experimental_is_anonymous=is_anonymous) 1093 1094 self.initialize_table(table) 1095 1096 input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) 1097 1098 out = table.lookup(input_string) 1099 self.assertAllEqual([0, 1, 2, 3], self.evaluate(out)) 1100 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) 1101 1102 def testStaticVocabularyTableGetItem(self, is_anonymous): 1103 if is_anonymous and not tf2.enabled(): 1104 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1105 vocab_file = self._createVocabFile("feat_to_id_1.txt") 1106 vocab_size = 3 1107 oov_buckets = 1 1108 table 
= self.getVocabularyTable()( 1109 lookup_ops.TextFileIdTableInitializer( 1110 vocab_file, vocab_size=vocab_size), 1111 oov_buckets, 1112 experimental_is_anonymous=is_anonymous) 1113 1114 self.initialize_table(table) 1115 1116 input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"]) 1117 1118 out = table[input_string] 1119 self.assertAllEqual([0, 1, 2, 3], self.evaluate(out)) 1120 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) 1121 1122 def testInt32StaticVocabularyTable(self, is_anonymous): 1123 if is_anonymous and not tf2.enabled(): 1124 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1125 vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000")) 1126 vocab_size = 3 1127 oov_buckets = 1 1128 table = self.getVocabularyTable()( 1129 lookup_ops.TextFileIdTableInitializer( 1130 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), 1131 oov_buckets, 1132 lookup_key_dtype=dtypes.int32, 1133 experimental_is_anonymous=is_anonymous) 1134 1135 self.initialize_table(table) 1136 1137 values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32) 1138 1139 out = table.lookup(values) 1140 self.assertAllEqual([0, 1, 2, 3], self.evaluate(out)) 1141 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) 1142 1143 def testInt64StaticVocabularyTable(self, is_anonymous): 1144 if is_anonymous and not tf2.enabled(): 1145 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1146 vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000")) 1147 vocab_size = 3 1148 oov_buckets = 1 1149 table = self.getVocabularyTable()( 1150 lookup_ops.TextFileIdTableInitializer( 1151 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), 1152 oov_buckets, 1153 experimental_is_anonymous=is_anonymous) 1154 1155 self.initialize_table(table) 1156 1157 values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64) 1158 1159 out = table.lookup(values) 1160 self.assertAllEqual([0, 1, 2, 3], 
self.evaluate(out)) 1161 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size())) 1162 1163 def testStringStaticVocabularyTableNoInitializer(self, is_anonymous): 1164 oov_buckets = 5 1165 1166 # Set a table that only uses hash buckets, for each input value returns 1167 # an id calculated by fingerprint("input") mod oov_buckets. 1168 table = self.getVocabularyTable()( 1169 None, oov_buckets, experimental_is_anonymous=is_anonymous) 1170 self.initialize_table(table) 1171 1172 values = constant_op.constant(("brain", "salad", "surgery")) 1173 1174 out = table.lookup(values) 1175 self.assertAllEqual( 1176 [ 1177 3, # fingerprint("brain") mod 5. 1178 1, # fingerprint("salad") mod 5. 1179 4 # fingerprint("surgery") mod 5 1180 ], 1181 self.evaluate(out)) 1182 self.assertEqual(oov_buckets, self.evaluate(table.size())) 1183 1184 def testStaticVocabularyTableWithMultipleInitializers(self, is_anonymous): 1185 if is_anonymous and not tf2.enabled(): 1186 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1187 vocab_file = self._createVocabFile("feat_to_id_4.txt") 1188 vocab_size = 3 1189 oov_buckets = 3 1190 1191 init = lookup_ops.TextFileIdTableInitializer( 1192 vocab_file, vocab_size=vocab_size) 1193 table1 = self.getVocabularyTable()( 1194 init, 1195 oov_buckets, 1196 name="table1", 1197 experimental_is_anonymous=is_anonymous) 1198 1199 table2 = self.getVocabularyTable()( 1200 init, 1201 oov_buckets, 1202 name="table2", 1203 experimental_is_anonymous=is_anonymous) 1204 1205 self.evaluate(lookup_ops.tables_initializer()) 1206 1207 input_string = constant_op.constant( 1208 ["fruit", "brain", "salad", "surgery", "UNK"]) 1209 1210 out1 = table1.lookup(input_string) 1211 out2 = table2.lookup(input_string) 1212 1213 out1, out2 = self.evaluate([out1, out2]) 1214 self.assertAllEqual([5, 0, 1, 2, 5], out1) 1215 self.assertAllEqual([5, 0, 1, 2, 5], out2) 1216 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size())) 1217 self.assertEqual(vocab_size + oov_buckets, 
self.evaluate(table2.size())) 1218 1219 def testStaticVocabularyTableInitializationAcrossSessions(self, is_anonymous): 1220 if is_anonymous and not tf2.enabled(): 1221 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1222 vocab_file = self._createVocabFile("feat_to_id_5.txt") 1223 with self.cached_session(): 1224 vocab_size = 3 1225 oov_buckets = 1 1226 table1 = self.getVocabularyTable()( 1227 lookup_ops.TextFileIdTableInitializer( 1228 vocab_file, vocab_size=vocab_size), 1229 oov_buckets, 1230 experimental_is_anonymous=is_anonymous) 1231 1232 self.initialize_table(table1) 1233 1234 input_string_1 = constant_op.constant( 1235 ["brain", "salad", "surgery", "UNK"]) 1236 1237 out1 = table1.lookup(input_string_1) 1238 1239 self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1)) 1240 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size())) 1241 1242 with self.cached_session(): 1243 vocab_size = 3 1244 oov_buckets = 1 1245 1246 # Underlying lookup table already initialized in previous session. 
1247 # No need to initialize table2 1248 table2 = self.getVocabularyTable()( 1249 lookup_ops.TextFileIdTableInitializer( 1250 vocab_file, vocab_size=vocab_size), 1251 oov_buckets, 1252 experimental_is_anonymous=is_anonymous) 1253 1254 input_string_2 = constant_op.constant(["fruit", "salad", "UNK"]) 1255 1256 out2 = table2.lookup(input_string_2) 1257 1258 self.assertAllEqual([3, 1, 3], self.evaluate(out2)) 1259 self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size())) 1260 1261 def testStaticVocabularyTableAssetTracking(self, is_anonymous): 1262 vocab_file = self._createVocabFile("vocab.txt") 1263 vocab_size = 3 1264 oov_buckets = 1 1265 table = self.getVocabularyTable()( 1266 lookup_ops.TextFileIdTableInitializer( 1267 vocab_file, vocab_size=vocab_size), 1268 oov_buckets, 1269 experimental_is_anonymous=is_anonymous) 1270 objects = checkpoint_util.list_objects(graph_view.ObjectGraphView(table)) 1271 assets = list(filter(lambda obj: isinstance(obj, asset.Asset), objects)) 1272 self.assertLen(assets, 1) 1273 self.assertEqual( 1274 self.evaluate(assets[0].asset_path), compat.as_bytes(vocab_file)) 1275 1276 def testSparseTensor(self, is_anonymous): 1277 if is_anonymous and not tf2.enabled(): 1278 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1279 vocab_file = self._createVocabFile("feat_to_id_7.txt") 1280 input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] 1281 input_shape = [4, 4] 1282 sp_features = sparse_tensor.SparseTensor( 1283 constant_op.constant(input_indices, dtypes.int64), 1284 constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], 1285 dtypes.string), 1286 constant_op.constant(input_shape, dtypes.int64)) 1287 1288 table = self.getVocabularyTable()( 1289 lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), 1290 1, 1291 experimental_is_anonymous=is_anonymous) 1292 self.initialize_table(table) 1293 1294 sp_ids = table.lookup(sp_features) 1295 1296 self.assertAllEqual([5], sp_ids.values._shape_as_list()) 1297 1298 
sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate( 1299 [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) 1300 1301 self.assertAllEqual(input_indices, sp_ids_ind) 1302 self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) 1303 self.assertAllEqual(input_shape, sp_ids_shape) 1304 1305 def testRaggedTensor(self, is_anonymous): 1306 if is_anonymous and not tf2.enabled(): 1307 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1308 vocab_file = self._createVocabFile("feat_to_id_7.txt") 1309 input_row_splits = [0, 2, 4, 5] 1310 ragged_features = ragged_tensor.RaggedTensor.from_row_splits( 1311 constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"], 1312 dtypes.string), 1313 constant_op.constant(input_row_splits, dtypes.int64)) 1314 1315 table = self.getVocabularyTable()( 1316 lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3), 1317 1, 1318 experimental_is_anonymous=is_anonymous) 1319 self.initialize_table(table) 1320 1321 ragged_ids = table.lookup(ragged_features) 1322 1323 self.assertAllEqual([5], ragged_ids.values._shape_as_list()) 1324 1325 ragged_ids_val, ragged_ids_row_splits = self.evaluate( 1326 [ragged_ids.values, ragged_ids.row_splits]) 1327 1328 self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val) 1329 self.assertAllEqual(input_row_splits, ragged_ids_row_splits) 1330 1331 def testInt32SparseTensor(self, is_anonymous): 1332 if is_anonymous and not tf2.enabled(): 1333 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1334 input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] 1335 input_shape = [4, 4] 1336 sp_features = sparse_tensor.SparseTensor( 1337 constant_op.constant(input_indices, dtypes.int64), 1338 constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), 1339 constant_op.constant(input_shape, dtypes.int64)) 1340 1341 table = self.getVocabularyTable()( 1342 lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), 1343 dtypes.int64, dtypes.int64), 1344 1, 1345 lookup_key_dtype=dtypes.int32, 1346 
experimental_is_anonymous=is_anonymous) 1347 self.initialize_table(table) 1348 1349 sp_ids = table.lookup(sp_features) 1350 1351 self.assertAllEqual([5], sp_ids.values._shape_as_list()) 1352 1353 sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate( 1354 [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) 1355 1356 self.assertAllEqual(input_indices, sp_ids_ind) 1357 self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) 1358 self.assertAllEqual(input_shape, sp_ids_shape) 1359 1360 def testInt32RaggedTensor(self, is_anonymous): 1361 if is_anonymous and not tf2.enabled(): 1362 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1363 input_row_splits = [0, 2, 4, 5] 1364 ragged_features = ragged_tensor.RaggedTensor.from_row_splits( 1365 constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32), 1366 constant_op.constant(input_row_splits, dtypes.int64)) 1367 1368 table = self.getVocabularyTable()( 1369 lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), 1370 dtypes.int64, dtypes.int64), 1371 1, 1372 lookup_key_dtype=dtypes.int32, 1373 experimental_is_anonymous=is_anonymous) 1374 self.initialize_table(table) 1375 1376 ragged_ids = table.lookup(ragged_features) 1377 1378 self.assertAllEqual([5], ragged_ids.values._shape_as_list()) 1379 1380 ragged_ids_val, ragged_ids_row_splits = self.evaluate( 1381 [ragged_ids.values, ragged_ids.row_splits]) 1382 1383 self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val) 1384 self.assertAllEqual(input_row_splits, ragged_ids_row_splits) 1385 1386 def testInt64SparseTensor(self, is_anonymous): 1387 if is_anonymous and not tf2.enabled(): 1388 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1389 input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]] 1390 input_shape = [4, 4] 1391 sp_features = sparse_tensor.SparseTensor( 1392 constant_op.constant(input_indices, dtypes.int64), 1393 constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), 1394 constant_op.constant(input_shape, dtypes.int64)) 1395 1396 table = self.getVocabularyTable()( 1397 
lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), 1398 dtypes.int64, dtypes.int64), 1399 1, 1400 experimental_is_anonymous=is_anonymous) 1401 self.initialize_table(table) 1402 1403 sp_ids = table.lookup(sp_features) 1404 1405 self.assertAllEqual([5], sp_ids.values._shape_as_list()) 1406 1407 sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate( 1408 [sp_ids.indices, sp_ids.values, sp_ids.dense_shape]) 1409 1410 self.assertAllEqual(input_indices, sp_ids_ind) 1411 self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val) 1412 self.assertAllEqual(input_shape, sp_ids_shape) 1413 1414 def testInt64RaggedTensor(self, is_anonymous): 1415 if is_anonymous and not tf2.enabled(): 1416 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1417 input_row_splits = [0, 2, 4, 5] 1418 ragged_features = ragged_tensor.RaggedTensor.from_row_splits( 1419 constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64), 1420 constant_op.constant(input_row_splits, dtypes.int64)) 1421 1422 table = self.getVocabularyTable()( 1423 lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2), 1424 dtypes.int64, dtypes.int64), 1425 1, 1426 experimental_is_anonymous=is_anonymous) 1427 self.initialize_table(table) 1428 1429 ragged_ids = table.lookup(ragged_features) 1430 1431 self.assertAllEqual([5], ragged_ids.values._shape_as_list()) 1432 1433 ragged_ids_val, ragged_ids_row_splits = self.evaluate( 1434 [ragged_ids.values, ragged_ids.row_splits]) 1435 1436 self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val) 1437 self.assertAllEqual(input_row_splits, ragged_ids_row_splits) 1438 1439 def testStaticVocabularyTableNoInnerTable(self, is_anonymous): 1440 table = self.getVocabularyTable()( 1441 None, num_oov_buckets=1, experimental_is_anonymous=is_anonymous) 1442 self.assertIsNone(table.resource_handle) 1443 1444 @test_util.run_v2_only 1445 def testSavedModelSaveRestore(self, is_anonymous): 1446 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 1447 save_path = 
os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 1448 1449 root = autotrackable.AutoTrackable() 1450 1451 vocab_file = self._createVocabFile("feat_to_id_3.txt", ("11", "12", "13")) 1452 vocab_size = 3 1453 oov_buckets = 1 1454 root.table = self.getVocabularyTable()( 1455 lookup_ops.TextFileIdTableInitializer( 1456 vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64), 1457 oov_buckets, 1458 experimental_is_anonymous=is_anonymous) 1459 1460 @def_function.function( 1461 input_signature=[tensor_spec.TensorSpec((), dtypes.int64)]) 1462 def lookup(key): 1463 return root.table.lookup(key) 1464 1465 @def_function.function(input_signature=[]) 1466 def size(): 1467 return root.table.size() 1468 1469 @def_function.function(input_signature=[]) 1470 def is_ref_counting(): 1471 return test_ops.is_resource_handle_ref_counting( 1472 root.table.resource_handle) 1473 1474 root.lookup = lookup 1475 root.size = size 1476 root.is_ref_counting = is_ref_counting 1477 1478 self.assertEqual(root.table.size(), 4) 1479 self.assertEqual(root.lookup(12), 1) 1480 self.assertEqual(root.lookup(10), 3) 1481 self.assertEqual(root.is_ref_counting(), is_anonymous) 1482 1483 saved_model_save.save(root, save_path) 1484 1485 del root 1486 loaded = saved_model_load.load(save_path) 1487 self.assertEqual(loaded.size(), 4) 1488 self.assertEqual(loaded.lookup(12), 1) 1489 self.assertEqual(loaded.lookup(10), 3) 1490 self.assertEqual(loaded.is_ref_counting(), is_anonymous) 1491 1492 1493@parameterized.named_parameters( 1494 (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True]) 1495class DenseHashTableOpTest(test.TestCase): 1496 1497 def testBasic(self, is_anonymous): 1498 if is_anonymous and not tf2.enabled(): 1499 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1500 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1501 values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1502 table = lookup_ops.DenseHashTable( 1503 dtypes.int64, 1504 dtypes.int64, 1505 
default_value=-1, 1506 empty_key=0, 1507 deleted_key=-1, 1508 experimental_is_anonymous=is_anonymous) 1509 self.assertAllEqual(0, self.evaluate(table.size())) 1510 1511 self.evaluate(table.insert(keys, values)) 1512 self.assertAllEqual(4, self.evaluate(table.size())) 1513 1514 remove_string = constant_op.constant([12, 15], dtypes.int64) 1515 self.evaluate(table.remove(remove_string)) 1516 self.assertAllEqual(3, self.evaluate(table.size())) 1517 1518 input_string = constant_op.constant([11, 12, 15], dtypes.int64) 1519 output = table.lookup(input_string) 1520 self.assertAllEqual([3], output.get_shape()) 1521 1522 result = self.evaluate(output) 1523 self.assertAllEqual([0, -1, -1], result) 1524 1525 def testGetItem(self, is_anonymous): 1526 if is_anonymous and not tf2.enabled(): 1527 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1528 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1529 values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1530 table = lookup_ops.DenseHashTable( 1531 dtypes.int64, 1532 dtypes.int64, 1533 default_value=-1, 1534 empty_key=0, 1535 deleted_key=-1, 1536 experimental_is_anonymous=is_anonymous) 1537 1538 self.evaluate(table.insert(keys, values)) 1539 1540 input_string = constant_op.constant([11, 12, 15], dtypes.int64) 1541 output = table[input_string] 1542 self.assertAllEqual([3], output.get_shape()) 1543 1544 result = self.evaluate(output) 1545 self.assertAllEqual([0, 1, -1], result) 1546 1547 def testBasicBool(self, is_anonymous): 1548 if is_anonymous and not tf2.enabled(): 1549 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1550 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1551 values = constant_op.constant([True, True, True, True], dtypes.bool) 1552 table = lookup_ops.DenseHashTable( 1553 dtypes.int64, 1554 dtypes.bool, 1555 default_value=False, 1556 empty_key=0, 1557 deleted_key=-1, 1558 experimental_is_anonymous=is_anonymous) 1559 self.assertAllEqual(0, self.evaluate(table.size())) 1560 1561 
self.evaluate(table.insert(keys, values)) 1562 self.assertAllEqual(4, self.evaluate(table.size())) 1563 1564 remove_string = constant_op.constant([11, 15], dtypes.int64) 1565 self.evaluate(table.remove(remove_string)) 1566 self.assertAllEqual(3, self.evaluate(table.size())) 1567 1568 input_string = constant_op.constant([11, 12, 15], dtypes.int64) 1569 output = table.lookup(input_string) 1570 self.assertAllEqual([3], output.get_shape()) 1571 1572 result = self.evaluate(output) 1573 self.assertAllEqual([False, True, False], result) 1574 1575 def testSameEmptyAndDeletedKey(self, is_anonymous): 1576 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 1577 "Empty and deleted keys"): 1578 table = lookup_ops.DenseHashTable( 1579 dtypes.int64, 1580 dtypes.int64, 1581 default_value=-1, 1582 empty_key=42, 1583 deleted_key=42, 1584 experimental_is_anonymous=is_anonymous) 1585 self.assertAllEqual(0, self.evaluate(table.size())) 1586 1587 @test_util.run_v1_only("uses placeholders") 1588 def testLookupUnknownShape(self, is_anonymous): 1589 if is_anonymous and not tf2.enabled(): 1590 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1591 with self.cached_session(): 1592 keys = constant_op.constant([11, 12, 13], dtypes.int64) 1593 values = constant_op.constant([0, 1, 2], dtypes.int64) 1594 table = lookup_ops.DenseHashTable( 1595 dtypes.int64, 1596 dtypes.int64, 1597 default_value=-1, 1598 empty_key=0, 1599 deleted_key=-1, 1600 experimental_is_anonymous=is_anonymous) 1601 1602 self.evaluate(table.insert(keys, values)) 1603 self.assertAllEqual(3, self.evaluate(table.size())) 1604 1605 placeholder_keys = array_ops.placeholder(dtypes.int64) 1606 output = table.lookup(placeholder_keys) 1607 self.assertAllEqual(None, output.get_shape()) 1608 result = output.eval({placeholder_keys: [11, 12, 15]}) 1609 self.assertAllEqual([0, 1, -1], result) 1610 1611 def testMapStringToFloat(self, is_anonymous): 1612 if is_anonymous and not tf2.enabled(): 1613 
self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1614 keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string) 1615 values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32) 1616 default_value = constant_op.constant(-1.5, dtypes.float32) 1617 table = lookup_ops.DenseHashTable( 1618 dtypes.string, 1619 dtypes.float32, 1620 default_value=default_value, 1621 empty_key="", 1622 deleted_key="$", 1623 experimental_is_anonymous=is_anonymous) 1624 self.assertAllEqual(0, self.evaluate(table.size())) 1625 1626 self.evaluate(table.insert(keys, values)) 1627 self.assertAllEqual(4, self.evaluate(table.size())) 1628 1629 remove_string = constant_op.constant(["b", "e"]) 1630 self.evaluate(table.remove(remove_string)) 1631 self.assertAllEqual(3, self.evaluate(table.size())) 1632 1633 input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string) 1634 output = table.lookup(input_string) 1635 self.assertAllEqual([4], output.get_shape()) 1636 1637 result = self.evaluate(output) 1638 self.assertAllClose([0, -1.5, 3.3, -1.5], result) 1639 1640 def testMapInt64ToFloat(self, is_anonymous): 1641 if is_anonymous and not tf2.enabled(): 1642 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1643 for float_dtype in [dtypes.float32, dtypes.float64]: 1644 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1645 values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype) 1646 default_value = constant_op.constant(-1.5, float_dtype) 1647 table = lookup_ops.DenseHashTable( 1648 dtypes.int64, 1649 float_dtype, 1650 default_value=default_value, 1651 empty_key=0, 1652 deleted_key=-1, 1653 experimental_is_anonymous=is_anonymous) 1654 self.assertAllEqual(0, self.evaluate(table.size())) 1655 1656 self.evaluate(table.insert(keys, values)) 1657 self.assertAllEqual(4, self.evaluate(table.size())) 1658 1659 remove_string = constant_op.constant([12, 15], dtypes.int64) 1660 self.evaluate(table.remove(remove_string)) 1661 self.assertAllEqual(3, self.evaluate(table.size())) 1662 
1663 input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) 1664 output = table.lookup(input_string) 1665 self.assertAllEqual([4], output.get_shape()) 1666 1667 result = self.evaluate(output) 1668 self.assertAllClose([0, -1.5, 3.3, -1.5], result) 1669 1670 def testVectorValues(self, is_anonymous): 1671 if is_anonymous and not tf2.enabled(): 1672 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1673 keys = constant_op.constant([11, 12, 13], dtypes.int64) 1674 values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]], 1675 dtypes.int64) 1676 default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64) 1677 table = lookup_ops.DenseHashTable( 1678 dtypes.int64, 1679 dtypes.int64, 1680 default_value=default_value, 1681 empty_key=0, 1682 deleted_key=-1, 1683 initial_num_buckets=4, 1684 experimental_is_anonymous=is_anonymous) 1685 self.assertAllEqual(0, self.evaluate(table.size())) 1686 1687 self.evaluate(table.insert(keys, values)) 1688 self.assertAllEqual(3, self.evaluate(table.size())) 1689 self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) 1690 1691 self.evaluate( 1692 table.insert( 1693 constant_op.constant([14], dtypes.int64), 1694 constant_op.constant([[2, 3, 4, 5]], dtypes.int64))) 1695 self.assertAllEqual(4, self.evaluate(table.size())) 1696 self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) 1697 1698 remove_string = constant_op.constant([12, 16], dtypes.int64) 1699 self.evaluate(table.remove(remove_string)) 1700 self.assertAllEqual(3, self.evaluate(table.size())) 1701 self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) 1702 1703 input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64) 1704 output = table.lookup(input_string) 1705 self.assertAllEqual([4, 4], 1706 output.shape, 1707 msg="Saw shape: %s" % output.shape) 1708 1709 result = self.evaluate(output) 1710 self.assertAllEqual( 1711 [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]], 1712 result) 1713 1714 def 
testVectorKeys(self, is_anonymous): 1715 if is_anonymous and not tf2.enabled(): 1716 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1717 keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64) 1718 values = constant_op.constant([10, 11, 12], dtypes.int64) 1719 empty_key = constant_op.constant([0, 3], dtypes.int64) 1720 deleted_key = constant_op.constant([-1, -1], dtypes.int64) 1721 default_value = constant_op.constant(-1, dtypes.int64) 1722 table = lookup_ops.DenseHashTable( 1723 dtypes.int64, 1724 dtypes.int64, 1725 default_value=default_value, 1726 empty_key=empty_key, 1727 deleted_key=deleted_key, 1728 initial_num_buckets=8, 1729 experimental_is_anonymous=is_anonymous) 1730 self.assertAllEqual(0, self.evaluate(table.size())) 1731 1732 self.evaluate(table.insert(keys, values)) 1733 self.assertAllEqual(3, self.evaluate(table.size())) 1734 1735 self.evaluate( 1736 table.insert( 1737 constant_op.constant([[0, 0]], dtypes.int64), 1738 constant_op.constant([13], dtypes.int64))) 1739 self.assertAllEqual(4, self.evaluate(table.size())) 1740 self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) 1741 1742 remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64) 1743 self.evaluate(table.remove(remove_string)) 1744 self.assertAllEqual(3, self.evaluate(table.size())) 1745 self.assertAllEqual(8, len(self.evaluate(table.export()[0]))) 1746 1747 input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]], 1748 dtypes.int64) 1749 output = table.lookup(input_string) 1750 self.assertAllEqual([4], output.get_shape()) 1751 1752 result = self.evaluate(output) 1753 self.assertAllEqual([10, -1, 12, -1], result) 1754 1755 def testResize(self, is_anonymous): 1756 if is_anonymous and not tf2.enabled(): 1757 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1758 keys = constant_op.constant([11, 12, 13], dtypes.int64) 1759 values = constant_op.constant([0, 1, 2], dtypes.int64) 1760 table = lookup_ops.DenseHashTable( 1761 dtypes.int64, 1762 dtypes.int64, 1763 
default_value=-1, 1764 empty_key=0, 1765 deleted_key=-1, 1766 initial_num_buckets=4, 1767 experimental_is_anonymous=is_anonymous) 1768 self.assertAllEqual(0, self.evaluate(table.size())) 1769 1770 self.evaluate(table.insert(keys, values)) 1771 self.assertAllEqual(3, self.evaluate(table.size())) 1772 self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) 1773 1774 keys2 = constant_op.constant([12, 99], dtypes.int64) 1775 self.evaluate(table.remove(keys2)) 1776 self.assertAllEqual(2, self.evaluate(table.size())) 1777 self.assertAllEqual(4, len(self.evaluate(table.export()[0]))) 1778 1779 keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64) 1780 values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64) 1781 1782 self.evaluate(table.insert(keys3, values3)) 1783 self.assertAllEqual(6, self.evaluate(table.size())) 1784 self.assertAllEqual(16, len(self.evaluate(table.export()[0]))) 1785 1786 keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18], 1787 dtypes.int64) 1788 output = table.lookup(keys4) 1789 self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], self.evaluate(output)) 1790 1791 def testExport(self, is_anonymous): 1792 if is_anonymous and not tf2.enabled(): 1793 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1794 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1795 values = constant_op.constant([1, 2, 3, 4], dtypes.int64) 1796 table = lookup_ops.DenseHashTable( 1797 dtypes.int64, 1798 dtypes.int64, 1799 default_value=-1, 1800 empty_key=100, 1801 deleted_key=200, 1802 initial_num_buckets=8, 1803 experimental_is_anonymous=is_anonymous) 1804 self.assertAllEqual(0, self.evaluate(table.size())) 1805 1806 self.evaluate(table.insert(keys, values)) 1807 self.assertAllEqual(4, self.evaluate(table.size())) 1808 1809 keys2 = constant_op.constant([12, 15], dtypes.int64) 1810 self.evaluate(table.remove(keys2)) 1811 self.assertAllEqual(3, self.evaluate(table.size())) 1812 1813 exported_keys, exported_values = table.export() 1814 1815 
np_keys = self.evaluate(exported_keys) 1816 np_values = self.evaluate(exported_values) 1817 1818 self.assertAllEqual(8, len(np_keys)) 1819 self.assertAllEqual(8, len(np_values)) 1820 1821 # pair up keys and values, drop extra added dimension 1822 pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0] 1823 # sort by key 1824 pairs = pairs[pairs[:, 0].argsort()] 1825 self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0], 1826 [100, 0], [100, 0], [200, 2]], pairs) 1827 1828 @test_util.run_v1_only("Saver V1 only") 1829 def testSaveRestore(self, is_anonymous): 1830 if is_anonymous and not tf2.enabled(): 1831 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1832 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 1833 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 1834 1835 with self.session(graph=ops.Graph()) as sess: 1836 default_value = -1 1837 empty_key = 0 1838 deleted_key = -1 1839 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1840 values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1841 table = lookup_ops.DenseHashTable( 1842 dtypes.int64, 1843 dtypes.int64, 1844 default_value=default_value, 1845 empty_key=empty_key, 1846 deleted_key=deleted_key, 1847 name="t1", 1848 checkpoint=True, 1849 initial_num_buckets=32, 1850 experimental_is_anonymous=is_anonymous) 1851 1852 save = saver.Saver() 1853 1854 self.assertAllEqual(0, table.size()) 1855 table.insert(keys, values).run() 1856 self.assertAllEqual(4, table.size()) 1857 self.assertAllEqual(32, len(table.export()[0].eval())) 1858 1859 keys2 = constant_op.constant([12, 15], dtypes.int64) 1860 table.remove(keys2).run() 1861 self.assertAllEqual(3, table.size()) 1862 self.assertAllEqual(32, len(table.export()[0].eval())) 1863 1864 val = save.save(sess, save_path) 1865 self.assertIsInstance(val, str) 1866 self.assertEqual(save_path, val) 1867 1868 with self.session(graph=ops.Graph()) as sess: 1869 table = lookup_ops.DenseHashTable( 1870 dtypes.int64, 1871 
dtypes.int64, 1872 default_value=default_value, 1873 empty_key=empty_key, 1874 deleted_key=deleted_key, 1875 name="t1", 1876 checkpoint=True, 1877 initial_num_buckets=64, 1878 experimental_is_anonymous=is_anonymous) 1879 table.insert( 1880 constant_op.constant([11, 14], dtypes.int64), 1881 constant_op.constant([12, 24], dtypes.int64)).run() 1882 self.assertAllEqual(2, table.size()) 1883 self.assertAllEqual(64, len(table.export()[0].eval())) 1884 1885 save = saver.Saver() 1886 1887 # Restore the saved values in the parameter nodes. 1888 save.restore(sess, save_path) 1889 1890 self.assertAllEqual(3, table.size()) 1891 self.assertAllEqual(32, len(table.export()[0].eval())) 1892 1893 input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) 1894 output = table.lookup(input_string) 1895 self.assertAllEqual([-1, 0, -1, 2, 3], output) 1896 1897 @test_util.run_v1_only("Saver V1 only") 1898 def testSaveRestoreOnlyTable(self, is_anonymous): 1899 if is_anonymous and not tf2.enabled(): 1900 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1901 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 1902 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 1903 1904 with self.session(graph=ops.Graph()) as sess: 1905 default_value = -1 1906 empty_key = 0 1907 deleted_key = -1 1908 keys = constant_op.constant([11, 12, 13, 14], dtypes.int64) 1909 values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 1910 table = lookup_ops.DenseHashTable( 1911 dtypes.int64, 1912 dtypes.int64, 1913 default_value=default_value, 1914 empty_key=empty_key, 1915 deleted_key=deleted_key, 1916 name="t1", 1917 checkpoint=True, 1918 initial_num_buckets=32, 1919 experimental_is_anonymous=is_anonymous) 1920 1921 save = saver.Saver([table]) 1922 1923 self.assertAllEqual(0, table.size()) 1924 table.insert(keys, values).run() 1925 self.assertAllEqual(4, table.size()) 1926 self.assertAllEqual(32, len(table.export()[0].eval())) 1927 1928 keys2 = constant_op.constant([12, 15], 
dtypes.int64) 1929 table.remove(keys2).run() 1930 self.assertAllEqual(3, table.size()) 1931 self.assertAllEqual(32, len(table.export()[0].eval())) 1932 1933 val = save.save(sess, save_path) 1934 self.assertIsInstance(val, str) 1935 self.assertEqual(save_path, val) 1936 1937 with self.session(graph=ops.Graph()) as sess: 1938 table = lookup_ops.DenseHashTable( 1939 dtypes.int64, 1940 dtypes.int64, 1941 default_value=default_value, 1942 empty_key=empty_key, 1943 deleted_key=deleted_key, 1944 name="t1", 1945 checkpoint=True, 1946 initial_num_buckets=64, 1947 experimental_is_anonymous=is_anonymous) 1948 table.insert( 1949 constant_op.constant([11, 14], dtypes.int64), 1950 constant_op.constant([12, 24], dtypes.int64)).run() 1951 self.assertAllEqual(2, table.size()) 1952 self.assertAllEqual(64, len(table.export()[0].eval())) 1953 1954 save = saver.Saver([table]) 1955 1956 # Restore the saved values in the parameter nodes. 1957 save.restore(sess, save_path) 1958 1959 self.assertAllEqual(3, table.size()) 1960 self.assertAllEqual(32, len(table.export()[0].eval())) 1961 1962 input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) 1963 output = table.lookup(input_string) 1964 self.assertAllEqual([-1, 0, -1, 2, 3], output) 1965 1966 @test_util.run_in_graph_and_eager_modes 1967 def testObjectSaveRestore(self, is_anonymous): 1968 if is_anonymous and not context.executing_eagerly(): 1969 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 1970 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 1971 save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 1972 1973 default_value = -1 1974 empty_key = 0 1975 deleted_key = -1 1976 keys = constant_op.constant([11, 12, 13], dtypes.int64) 1977 values = constant_op.constant([0, 1, 2], dtypes.int64) 1978 save_table = lookup_ops.DenseHashTable( 1979 dtypes.int64, 1980 dtypes.int64, 1981 default_value=default_value, 1982 empty_key=empty_key, 1983 deleted_key=deleted_key, 1984 name="t1", 1985 checkpoint=True, 
1986 initial_num_buckets=32, 1987 experimental_is_anonymous=is_anonymous) 1988 1989 save_checkpoint = trackable.Checkpoint(table=save_table) 1990 1991 self.assertAllEqual(0, self.evaluate(save_table.size())) 1992 self.evaluate(save_table.insert(keys, values)) 1993 self.assertAllEqual(3, self.evaluate(save_table.size())) 1994 self.assertAllEqual(32, len(self.evaluate(save_table.export()[0]))) 1995 1996 save_path = save_checkpoint.save(save_prefix) 1997 del save_table, save_checkpoint 1998 1999 load_table = lookup_ops.DenseHashTable( 2000 dtypes.int64, 2001 dtypes.int64, 2002 default_value=default_value, 2003 empty_key=empty_key, 2004 deleted_key=deleted_key, 2005 name="t1", 2006 checkpoint=True, 2007 initial_num_buckets=64, 2008 experimental_is_anonymous=is_anonymous) 2009 self.evaluate( 2010 load_table.insert( 2011 constant_op.constant([11, 14], dtypes.int64), 2012 constant_op.constant([12, 24], dtypes.int64))) 2013 self.assertAllEqual(2, self.evaluate(load_table.size())) 2014 self.assertAllEqual(64, len(self.evaluate(load_table.export()[0]))) 2015 2016 restore_checkpoint = trackable.Checkpoint(table=load_table) 2017 2018 # Restore the saved values in the parameter nodes. 
2019 restore_checkpoint.restore(save_path).run_restore_ops() 2020 2021 self.assertAllEqual(3, self.evaluate(load_table.size())) 2022 self.assertAllEqual(32, len(self.evaluate(load_table.export()[0]))) 2023 2024 input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64) 2025 output = load_table.lookup(input_string) 2026 self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) 2027 2028 @test_util.run_v2_only 2029 def testSavedModelSaveRestore(self, is_anonymous): 2030 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 2031 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 2032 2033 root = autotrackable.AutoTrackable() 2034 2035 default_value = -1 2036 empty_key = 0 2037 deleted_key = -1 2038 keys = constant_op.constant([11, 12, 13], dtypes.int64) 2039 values = constant_op.constant([0, 1, 2], dtypes.int64) 2040 root.table = lookup_ops.DenseHashTable( 2041 dtypes.int64, 2042 dtypes.int64, 2043 default_value=default_value, 2044 empty_key=empty_key, 2045 deleted_key=deleted_key, 2046 name="t1", 2047 checkpoint=True, 2048 initial_num_buckets=32, 2049 experimental_is_anonymous=is_anonymous) 2050 2051 @def_function.function( 2052 input_signature=[tensor_spec.TensorSpec((), dtypes.int64)]) 2053 def lookup(key): 2054 return root.table.lookup(key) 2055 2056 @def_function.function(input_signature=[]) 2057 def size(): 2058 return root.table.size() 2059 2060 @def_function.function(input_signature=[]) 2061 def is_ref_counting(): 2062 return test_ops.is_resource_handle_ref_counting( 2063 root.table.resource_handle) 2064 2065 root.lookup = lookup 2066 root.size = size 2067 root.is_ref_counting = is_ref_counting 2068 2069 self.assertEqual(root.table.size(), 0) 2070 root.table.insert(keys, values) 2071 self.assertEqual(root.table.size(), 3) 2072 self.assertEqual(root.table.lookup(12), 1) 2073 self.assertEqual(root.table.lookup(10), -1) 2074 self.assertEqual(len(root.table.export()[0]), 32) 2075 self.assertEqual(root.is_ref_counting(), 
is_anonymous) 2076 2077 saved_model_save.save(root, save_path) 2078 2079 del root 2080 loaded = saved_model_load.load(save_path) 2081 self.assertEqual(loaded.size(), 3) 2082 self.assertEqual(loaded.lookup(12), 1) 2083 self.assertEqual(loaded.lookup(10), -1) 2084 self.assertEqual(loaded.is_ref_counting(), is_anonymous) 2085 2086 @test_util.run_v1_only("Saver V1 only") 2087 def testVectorSaveRestore(self, is_anonymous): 2088 if is_anonymous and not tf2.enabled(): 2089 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 2090 save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore") 2091 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 2092 2093 with self.session(graph=ops.Graph()) as sess: 2094 empty_key = constant_op.constant([11, 13], dtypes.int64) 2095 deleted_key = constant_op.constant([-2, -3], dtypes.int64) 2096 default_value = constant_op.constant([-1, -2], dtypes.int64) 2097 keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], 2098 dtypes.int64) 2099 values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]], 2100 dtypes.int64) 2101 table = lookup_ops.DenseHashTable( 2102 dtypes.int64, 2103 dtypes.int64, 2104 default_value=default_value, 2105 empty_key=empty_key, 2106 deleted_key=deleted_key, 2107 name="t1", 2108 checkpoint=True, 2109 initial_num_buckets=32, 2110 experimental_is_anonymous=is_anonymous) 2111 2112 save = saver.Saver() 2113 2114 self.assertAllEqual(0, table.size()) 2115 table.insert(keys, values).run() 2116 self.assertAllEqual(4, table.size()) 2117 self.assertAllEqual(32, len(table.export()[0].eval())) 2118 2119 keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64) 2120 table.remove(keys2).run() 2121 self.assertAllEqual(3, table.size()) 2122 self.assertAllEqual(32, len(table.export()[0].eval())) 2123 2124 val = save.save(sess, save_path) 2125 self.assertIsInstance(val, str) 2126 self.assertEqual(save_path, val) 2127 2128 with self.session(graph=ops.Graph()) as sess: 2129 empty_key = 
constant_op.constant([11, 13], dtypes.int64) 2130 deleted_key = constant_op.constant([-2, -3], dtypes.int64) 2131 default_value = constant_op.constant([-1, -2], dtypes.int64) 2132 table = lookup_ops.DenseHashTable( 2133 dtypes.int64, 2134 dtypes.int64, 2135 default_value=default_value, 2136 empty_key=empty_key, 2137 deleted_key=deleted_key, 2138 name="t1", 2139 checkpoint=True, 2140 initial_num_buckets=64, 2141 experimental_is_anonymous=is_anonymous) 2142 table.insert( 2143 constant_op.constant([[11, 12], [13, 15]], dtypes.int64), 2144 constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run() 2145 self.assertAllEqual(2, table.size()) 2146 self.assertAllEqual(64, len(table.export()[0].eval())) 2147 2148 save = saver.Saver() 2149 2150 # Restore the saved values in the parameter nodes. 2151 save.restore(sess, save_path) 2152 2153 self.assertAllEqual(3, table.size()) 2154 self.assertAllEqual(32, len(table.export()[0].eval())) 2155 2156 input_string = constant_op.constant( 2157 [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) 2158 output = table.lookup(input_string) 2159 self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]], 2160 self.evaluate(output)) 2161 2162 @test_util.run_v1_only("Saver V1 only") 2163 def testVectorScalarSaveRestore(self, is_anonymous): 2164 if is_anonymous and not tf2.enabled(): 2165 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 2166 save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore") 2167 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 2168 2169 with self.session(graph=ops.Graph()) as sess: 2170 empty_key = constant_op.constant([11, 13], dtypes.int64) 2171 deleted_key = constant_op.constant([-1, -1], dtypes.int64) 2172 default_value = constant_op.constant(-1, dtypes.int64) 2173 keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]], 2174 dtypes.int64) 2175 values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 2176 table = lookup_ops.DenseHashTable( 2177 
dtypes.int64, 2178 dtypes.int64, 2179 default_value=default_value, 2180 empty_key=empty_key, 2181 deleted_key=deleted_key, 2182 name="t2", 2183 checkpoint=True, 2184 initial_num_buckets=32, 2185 experimental_is_anonymous=is_anonymous) 2186 2187 save = saver.Saver() 2188 2189 self.assertAllEqual(0, table.size()) 2190 table.insert(keys, values).run() 2191 self.assertAllEqual(4, table.size()) 2192 self.assertAllEqual(32, len(table.export()[0].eval())) 2193 2194 keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64) 2195 table.remove(keys2).run() 2196 self.assertAllEqual(3, table.size()) 2197 self.assertAllEqual(32, len(table.export()[0].eval())) 2198 2199 val = save.save(sess, save_path) 2200 self.assertIsInstance(val, str) 2201 self.assertEqual(save_path, val) 2202 2203 with self.session(graph=ops.Graph()) as sess: 2204 empty_key = constant_op.constant([11, 13], dtypes.int64) 2205 deleted_key = constant_op.constant([-1, -1], dtypes.int64) 2206 default_value = constant_op.constant(-1, dtypes.int64) 2207 table = lookup_ops.DenseHashTable( 2208 dtypes.int64, 2209 dtypes.int64, 2210 default_value=default_value, 2211 empty_key=empty_key, 2212 deleted_key=deleted_key, 2213 name="t2", 2214 checkpoint=True, 2215 initial_num_buckets=64, 2216 experimental_is_anonymous=is_anonymous) 2217 table.insert( 2218 constant_op.constant([[11, 12], [13, 15]], dtypes.int64), 2219 constant_op.constant([3, 4], dtypes.int64)).run() 2220 self.assertAllEqual(2, table.size()) 2221 self.assertAllEqual(64, len(table.export()[0].eval())) 2222 2223 save = saver.Saver() 2224 2225 # Restore the saved values in the parameter nodes. 
2226 save.restore(sess, save_path) 2227 2228 self.assertAllEqual(3, table.size()) 2229 self.assertAllEqual(32, len(table.export()[0].eval())) 2230 2231 input_string = constant_op.constant( 2232 [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64) 2233 output = table.lookup(input_string) 2234 self.assertAllEqual([0, 1, -1, 3, -1], output) 2235 2236 def testReprobe(self, is_anonymous): 2237 if is_anonymous and not tf2.enabled(): 2238 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 2239 # Insert 6 keys into a table with 8 buckets. 2240 # The values are chosen to make sure collisions occur when using GCC STL 2241 keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64) 2242 values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64) 2243 table = lookup_ops.DenseHashTable( 2244 dtypes.int64, 2245 dtypes.int64, 2246 default_value=-1, 2247 empty_key=0, 2248 deleted_key=-1, 2249 initial_num_buckets=8, 2250 experimental_is_anonymous=is_anonymous) 2251 self.assertAllEqual(0, self.evaluate(table.size())) 2252 2253 self.evaluate(table.insert(keys, values)) 2254 self.assertAllEqual(6, self.evaluate(table.size())) 2255 2256 input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22], 2257 dtypes.int64) 2258 output = table.lookup(input_string) 2259 self.assertAllEqual([9], output.get_shape()) 2260 2261 result = self.evaluate(output) 2262 self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result) 2263 2264 def testCustomEmptyKey(self, is_anonymous): 2265 if is_anonymous and not tf2.enabled(): 2266 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 2267 keys = constant_op.constant([11, 0, 13], dtypes.int64) 2268 values = constant_op.constant([0, 1, 2], dtypes.int64) 2269 table = lookup_ops.DenseHashTable( 2270 dtypes.int64, 2271 dtypes.int64, 2272 default_value=-1, 2273 empty_key=12, 2274 deleted_key=-1, 2275 experimental_is_anonymous=is_anonymous) 2276 self.assertAllEqual(0, self.evaluate(table.size())) 2277 2278 
self.evaluate(table.insert(keys, values)) 2279 self.assertAllEqual(3, self.evaluate(table.size())) 2280 2281 input_string = constant_op.constant([11, 0, 15], dtypes.int64) 2282 output = table.lookup(input_string) 2283 self.assertAllEqual([3], output.get_shape()) 2284 2285 result = self.evaluate(output) 2286 self.assertAllEqual([0, 1, -1], result) 2287 2288 def testErrors(self, is_anonymous): 2289 table = lookup_ops.DenseHashTable( 2290 dtypes.int64, 2291 dtypes.int64, 2292 default_value=-1, 2293 empty_key=0, 2294 deleted_key=-1, 2295 experimental_is_anonymous=is_anonymous) 2296 2297 # Inserting the empty key returns an error 2298 keys1 = constant_op.constant([11, 0], dtypes.int64) 2299 values1 = constant_op.constant([0, 1], dtypes.int64) 2300 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2301 "empty_key"): 2302 self.evaluate(table.insert(keys1, values1)) 2303 2304 # Looking up the empty key returns an error 2305 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2306 "empty_key"): 2307 self.evaluate(table.lookup(keys1)) 2308 2309 # Inserting the deleted key returns an error 2310 keys2 = constant_op.constant([11, -1], dtypes.int64) 2311 values2 = constant_op.constant([0, 1], dtypes.int64) 2312 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2313 "deleted_key"): 2314 self.evaluate(table.insert(keys2, values2)) 2315 2316 # Looking up the empty key returns an error 2317 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2318 "deleted_key"): 2319 self.evaluate(table.lookup(keys2)) 2320 2321 # Arbitrary tensors of keys are not supported 2322 keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) 2323 values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64) 2324 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2325 "Expected key shape"): 2326 self.evaluate(table.lookup(keys)) 2327 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2328 "Expected key shape"): 2329 
self.evaluate(table.insert(keys, values)) 2330 2331 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2332 "Number of buckets must be"): 2333 table2 = lookup_ops.DenseHashTable( 2334 dtypes.int64, 2335 dtypes.int64, 2336 default_value=-1, 2337 empty_key=17, 2338 deleted_key=-1, 2339 initial_num_buckets=12, 2340 experimental_is_anonymous=is_anonymous) 2341 self.assertAllEqual(0, self.evaluate(table2.size())) 2342 2343 with self.assertRaisesRegex( 2344 errors_impl.InvalidArgumentError, 2345 "Empty and deleted keys must have same shape"): 2346 table3 = lookup_ops.DenseHashTable( 2347 dtypes.int64, 2348 dtypes.int64, 2349 default_value=-1, 2350 empty_key=42, 2351 deleted_key=[1, 2], 2352 experimental_is_anonymous=is_anonymous) 2353 self.assertAllEqual(0, self.evaluate(table3.size())) 2354 2355 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2356 "Empty and deleted keys cannot be equal"): 2357 table4 = lookup_ops.DenseHashTable( 2358 dtypes.int64, 2359 dtypes.int64, 2360 default_value=-1, 2361 empty_key=42, 2362 deleted_key=42, 2363 experimental_is_anonymous=is_anonymous) 2364 self.assertAllEqual(0, self.evaluate(table4.size())) 2365 2366 with self.assertRaisesRegex(errors_impl.InvalidArgumentError, 2367 "Empty and deleted keys cannot be equal"): 2368 table5 = lookup_ops.DenseHashTable( 2369 dtypes.int64, 2370 dtypes.int64, 2371 default_value=-1, 2372 empty_key=[1, 2, 3], 2373 deleted_key=[1, 2, 3], 2374 experimental_is_anonymous=is_anonymous) 2375 self.assertAllEqual(0, self.evaluate(table5.size())) 2376 2377 @test_util.run_in_graph_and_eager_modes 2378 def testStringToResource(self, is_anonymous): 2379 v = variables.Variable(1.) 2380 v1 = variables.Variable(1.) 
2381 table = lookup_ops.DenseHashTable( 2382 dtypes.string, 2383 dtypes.resource, 2384 default_value=v.handle, 2385 empty_key="<empty>", 2386 deleted_key="<deleted>", 2387 experimental_is_anonymous=is_anonymous) 2388 self.assertEqual([], table.lookup("not_found").shape) 2389 table.insert("v1", v1.handle) 2390 self.assertEqual([], table.lookup("v1").shape) 2391 2392 def testExportShapeInference(self, is_anonymous): 2393 default_value = -1 2394 empty_key = 0 2395 deleted_key = -1 2396 table = lookup_ops.DenseHashTable( 2397 dtypes.int64, 2398 dtypes.int64, 2399 default_value=default_value, 2400 empty_key=empty_key, 2401 deleted_key=deleted_key, 2402 experimental_is_anonymous=is_anonymous) 2403 actual_shapes = [t.shape for t in table.export()] 2404 inferred_shapes = [] 2405 2406 @def_function.function 2407 def f(): 2408 for t in table.export(): 2409 inferred_shapes.append(t.shape) 2410 2411 f() 2412 self.assertLen(actual_shapes, 2) 2413 self.assertLen(inferred_shapes, 2) 2414 self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0])) 2415 self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1])) 2416 2417 2418class IndexTableFromFile(test.TestCase): 2419 2420 def _createVocabFile(self, basename, values=("brain", "salad", "surgery")): 2421 vocabulary_file = os.path.join(self.get_temp_dir(), basename) 2422 with open(vocabulary_file, "w") as f: 2423 f.write("\n".join(values) + "\n") 2424 return vocabulary_file 2425 2426 def test_string_index_table_from_file(self): 2427 vocabulary_file = self._createVocabFile("f2i_vocab1.txt") 2428 2429 table = lookup_ops.index_table_from_file( 2430 vocabulary_file=vocabulary_file, num_oov_buckets=1) 2431 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 2432 2433 if not context.executing_eagerly(): 2434 with self.assertRaises(errors_impl.OpError): 2435 self.evaluate(ids) 2436 self.evaluate(lookup_ops.tables_initializer()) 2437 self.assertAllEqual((1, 2, 3), self.evaluate(ids)) 2438 2439 
def test_string_index_table_from_multicolumn_file(self): 2440 vocabulary_file = self._createVocabFile( 2441 "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1")) 2442 table = lookup_ops.index_table_from_file( 2443 vocabulary_file=vocabulary_file, 2444 num_oov_buckets=1, 2445 key_column_index=0, 2446 value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER) 2447 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 2448 2449 if not context.executing_eagerly(): 2450 with self.assertRaises(errors_impl.OpError): 2451 self.evaluate(ids) 2452 self.evaluate(lookup_ops.tables_initializer()) 2453 self.assertAllEqual((1, 2, 3), self.evaluate(ids)) 2454 2455 def test_string_index_table_from_multicolumn_file_custom_delimiter(self): 2456 vocabulary_file = self._createVocabFile( 2457 "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1")) 2458 table = lookup_ops.index_table_from_file( 2459 vocabulary_file=vocabulary_file, 2460 num_oov_buckets=1, 2461 key_column_index=0, 2462 value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER, 2463 delimiter=" ") 2464 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 2465 2466 if not context.executing_eagerly(): 2467 with self.assertRaises(errors_impl.OpError): 2468 self.evaluate(ids) 2469 self.evaluate(lookup_ops.tables_initializer()) 2470 self.assertAllEqual((1, 2, 3), self.evaluate(ids)) 2471 2472 def test_string_index_table_from_file_tensor_filename(self): 2473 vocabulary_file = self._createVocabFile("f2i_vocab1.txt") 2474 vocabulary_file = constant_op.constant(vocabulary_file) 2475 table = lookup_ops.index_table_from_file( 2476 vocabulary_file=vocabulary_file, num_oov_buckets=1) 2477 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 2478 2479 if not context.executing_eagerly(): 2480 with self.assertRaises(errors_impl.OpError): 2481 self.evaluate(ids) 2482 self.evaluate(lookup_ops.tables_initializer()) 2483 self.assertAllEqual((1, 2, 3), 
self.evaluate(ids)) 2484 if not context.executing_eagerly(): 2485 self.assertEqual(1, 2486 len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) 2487 2488 @test_util.run_v1_only("placeholder usage") 2489 def test_string_index_table_from_file_placeholder_filename(self): 2490 vocabulary_file = self._createVocabFile("f2i_vocab1.txt") 2491 with self.cached_session(): 2492 vocabulary_placeholder = array_ops.placeholder(dtypes.string, []) 2493 table = lookup_ops.index_table_from_file( 2494 vocabulary_file=vocabulary_placeholder, num_oov_buckets=1) 2495 ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"])) 2496 2497 with self.assertRaises(errors_impl.OpError): 2498 self.evaluate(ids) 2499 2500 feed_dict = {vocabulary_placeholder.name: vocabulary_file} 2501 lookup_ops.tables_initializer().run(feed_dict=feed_dict) 2502 self.assertAllEqual((1, 2, 3), self.evaluate(ids)) 2503 self.assertEqual(0, 2504 len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))) 2505 2506 def test_int32_index_table_from_file(self): 2507 vocabulary_file = self._createVocabFile( 2508 "f2i_vocab2.txt", values=("42", "1", "-1000")) 2509 table = lookup_ops.index_table_from_file( 2510 vocabulary_file=vocabulary_file, 2511 num_oov_buckets=1, 2512 key_dtype=dtypes.int32) 2513 ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int32)) 2514 2515 if not context.executing_eagerly(): 2516 with self.assertRaises(errors_impl.OpError): 2517 self.evaluate(ids) 2518 self.evaluate(lookup_ops.tables_initializer()) 2519 self.assertAllEqual((1, 2, 3), self.evaluate(ids)) 2520 2521 def test_int64_index_table_from_file(self): 2522 vocabulary_file = self._createVocabFile( 2523 "f2i_vocab3.txt", values=("42", "1", "-1000")) 2524 table = lookup_ops.index_table_from_file( 2525 vocabulary_file=vocabulary_file, 2526 num_oov_buckets=1, 2527 key_dtype=dtypes.int64) 2528 ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int64)) 2529 2530 if not context.executing_eagerly(): 
# (Continued: remainder of a test method whose beginning precedes this span.)
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_index_table_from_file_with_default_value(self):
    """Out-of-vocabulary keys map to the supplied `default_value`."""
    default_value = -42
    vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file, default_value=default_value)
    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

    # In graph mode the lookup must fail until the table initializer has run.
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, 2, default_value), self.evaluate(ids))

  def test_index_table_from_file_with_oov_buckets(self):
    """Out-of-vocabulary keys are hashed into ids after the vocabulary."""
    vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file, num_oov_buckets=1000)
    ids = table.lookup(
        constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    # OOV ids are offset by the vocabulary size (3) and reduced modulo
    # num_oov_buckets (1000), so they fall in [3, 1003).
    self.assertAllEqual(
        (
            1,  # From vocabulary file.
            2,  # From vocabulary file.
            867,  # 3 + fingerprint("tarkus") mod 1000.
            860),  # 3 + fingerprint("toccata") mod 1000.
        self.evaluate(ids))

  def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
    """An empty `vocabulary_file` string is rejected with ValueError."""
    self.assertRaises(
        ValueError, lookup_ops.index_table_from_file, vocabulary_file="")

  def test_index_table_from_file_fails_with_empty_vocabulary(self):
    """`vocabulary_file=None` is rejected with ValueError."""
    self.assertRaises(
        ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)

  def test_index_table_from_file_str_fails_with_zero_size_vocabulary(self):
    """`vocab_size=0` with a str path raises ValueError naming the file."""
    vocabulary_file = self._createVocabFile("zero_vocab_str.txt")
    self.assertRaisesRegex(
        ValueError, "`vocab_size` must be greater than 0, got 0 for "
        "vocabulary_file: .*zero_vocab_str.txt",
        lookup_ops.index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

  def test_index_table_from_file_tensor_fails_with_zero_size_vocabulary(self):
    """`vocab_size=0` with a tensor path raises ValueError naming the file."""
    vocabulary_file = constant_op.constant(
        self._createVocabFile("zero_vocab_tensor.txt"))
    self.assertRaisesRegex(
        ValueError, "`vocab_size` must be greater than 0, got 0 for "
        "vocabulary_file: .*zero_vocab_tensor.txt",
        lookup_ops.index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

  def test_index_table_from_file_with_vocab_size_too_small(self):
    """With `vocab_size=2` only the first two lines are loaded."""
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file, vocab_size=2)
    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, -1, -1), self.evaluate(ids))
    self.assertEqual(2, self.evaluate(table.size()))

  def test_index_table_from_file_with_vocab_size_too_large(self):
    """A `vocab_size` larger than the file fails when initializing."""
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                "Invalid vocab_size"):
      table = lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=4)
      self.evaluate(table.initializer)

  def test_index_table_from_file_with_vocab_size(self):
    """`vocab_size=0` is rejected; `vocab_size=3` loads three entries."""
    vocabulary_file = self._createVocabFile("f2i_vocab8.txt")

    self.assertRaises(
        ValueError,
        lookup_ops.index_table_from_file,
        vocabulary_file=vocabulary_file,
        vocab_size=0)

    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file, vocab_size=3)
    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, 2, -1), self.evaluate(ids))
    self.assertEqual(3, self.evaluate(table.size()))

  def test_index_table_from_file_with_invalid_hashers(self):
    """A non-HasherSpec hasher raises TypeError; an unknown name ValueError."""
    vocabulary_file = self._createVocabFile("invalid_hasher.txt")
    with self.assertRaises(TypeError):
      lookup_ops.index_table_from_file(
          vocabulary_file=vocabulary_file,
          vocab_size=3,
          num_oov_buckets=1,
          hasher_spec=1)

    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file,
        vocab_size=3,
        num_oov_buckets=1,
        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

    # The unknown hasher name is only detected when the lookup is built.
    self.assertRaises(ValueError, table.lookup,
                      constant_op.constant(["salad", "surgery", "tarkus"]))

  def test_index_table_from_file_table_ref_with_oov_buckets(self):
    """A table with OOV buckets exposes a non-None `resource_handle`."""
    vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file, num_oov_buckets=1)
    self.assertIsNotNone(table.resource_handle)

  def test_index_table_from_file_table_ref_without_oov_buckets(self):
    """A table without OOV buckets exposes a non-None `resource_handle`."""
    vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
    table = lookup_ops.index_table_from_file(
        vocabulary_file=vocabulary_file, num_oov_buckets=0)
    self.assertIsNotNone(table.resource_handle)


class IndexTableFromTensor(test.TestCase):
  """Tests for `lookup_ops.index_table_from_tensor`."""

  @test_util.run_in_graph_and_eager_modes
  def test_index_table_from_tensor_with_tensor_init(self):
    """Lookups fail before init in graph mode; eager rebuilds the table."""
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(
            table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))))
    else:
      # Reinitializing a table in eager should work.
      table = lookup_ops.index_table_from_tensor(
          vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
    self.evaluate(lookup_ops.tables_initializer())
    ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
    self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_int32_index_table_from_tensor_with_tensor_init(self):
    """int32 keys map to their list position; 11 lands in the OOV bucket."""
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
    ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int32))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.FailedPreconditionError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_int64_index_table_from_tensor_with_tensor_init(self):
    """int64 keys map to their list position; 11 lands in the OOV bucket."""
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
    ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int64))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.FailedPreconditionError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, 2, 3), self.evaluate(ids))

  def test_index_table_from_tensor_with_default_value(self):
    """Out-of-vocabulary keys map to the supplied `default_value`."""
    default_value = -42
    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=["brain", "salad", "surgery"],
        default_value=default_value)
    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.FailedPreconditionError):
        self.evaluate(ids)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((1, 2, default_value), self.evaluate(ids))

  def test_index_table_from_tensor_missing_vocabulary_list(self):
    """`vocabulary_list=None` is rejected with ValueError."""
    with self.assertRaisesRegex(ValueError,
                                "`vocabulary_list` must be specified"):
      lookup_ops.index_table_from_tensor(
          vocabulary_list=None, num_oov_buckets=1)

  def test_index_table_from_tensor_empty_vocabulary_list(self):
    """An empty vocabulary fails when the table initializer runs."""
    with self.assertRaisesRegex(errors_impl.OpError,
                                "keys and values cannot be empty"):
      _ = lookup_ops.index_table_from_tensor(
          vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
      self.evaluate(lookup_ops.tables_initializer())

  def test_index_table_from_tensor_with_invalid_hashers(self):
    """A non-HasherSpec hasher raises TypeError; an unknown name ValueError."""
    with self.assertRaises(TypeError):
      lookup_ops.index_table_from_tensor(
          vocabulary_list=["brain", "salad", "surgery"],
          num_oov_buckets=1,
          hasher_spec=1)

    table = lookup_ops.index_table_from_tensor(
        vocabulary_list=["brain", "salad", "surgery"],
        num_oov_buckets=1,
        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

    self.assertRaises(ValueError, table.lookup,
                      constant_op.constant(["salad", "surgery", "tarkus"]))


class IndexToStringTableFromFileTest(test.TestCase):
  """Tests for `lookup_ops.index_to_string_table_from_file`."""

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    """Writes `values` one per line to `basename` in the test temp dir.

    Returns:
      The path of the written file.
    """
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def test_index_to_string_table(self):
    """Indices map to strings for str and tensor paths; unknown -> "UNK"."""
    vocabulary_path = self._createVocabFile("i2f_vocab1.txt")
    # vocabulary_file supports both a Python string and a string tensor.
    type_funcs = [str, constant_op.constant]
    for type_func in type_funcs:
      vocabulary_file = type_func(vocabulary_path)
      table = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file)
      features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
      if not context.executing_eagerly():
        with self.assertRaises(errors_impl.OpError):
          self.evaluate(features)
      self.evaluate(lookup_ops.tables_initializer())
      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          self.evaluate(features))

  def test_index_to_string_table_from_multicolumn_file(self):
    """Reads column 0 of a tab-separated file, keyed by line number."""
    vocabulary_file = self._createVocabFile(
        "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
    table = lookup_ops.index_to_string_table_from_file(
        vocabulary_file=vocabulary_file,
        key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
        value_column_index=0)
    features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                        self.evaluate(features))

  def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
    """Reads column 0 of a space-separated file via `delimiter=" "`."""
    vocabulary_file = self._createVocabFile(
        "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
    table = lookup_ops.index_to_string_table_from_file(
        vocabulary_file=vocabulary_file,
        key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
        value_column_index=0,
        delimiter=" ")
    features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                        self.evaluate(features))

  def test_index_to_string_table_with_default_value(self):
    """Out-of-range indices map to the supplied `default_value`."""
    default_value = b"NONE"
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    table = lookup_ops.index_to_string_table_from_file(
        vocabulary_file=vocabulary_file, default_value=default_value)
    features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"salad", b"surgery", default_value),
                        self.evaluate(features))

  def test_index_to_string_table_with_vocab_size_too_small(self):
    """With `vocab_size=2`, index 2 is out of range and gets the default."""
    default_value = b"NONE"
    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
    table = lookup_ops.index_to_string_table_from_file(
        vocabulary_file=vocabulary_file,
        vocab_size=2,
        default_value=default_value)
    features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"salad", default_value, default_value),
                        self.evaluate(features))

  def test_index_to_string_table_with_vocab_size_too_large(self):
    """A `vocab_size` larger than the file fails when initializing."""
    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                "Invalid vocab_size"):
      _ = lookup_ops.index_to_string_table_from_file(
          vocabulary_file=vocabulary_file, vocab_size=4)
      self.evaluate(lookup_ops.tables_initializer())

  def test_index_to_string_table_with_vocab_size(self):
    """`vocab_size=3` loads all entries; index 4 maps to "UNK"."""
    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
    table = lookup_ops.index_to_string_table_from_file(
        vocabulary_file=vocabulary_file, vocab_size=3)
    features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))

    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"salad", b"surgery", b"UNK"), self.evaluate(features))


class IndexToStringTableFromTensorTest(test.TestCase):
  """Tests for `lookup_ops.index_to_string_table_from_tensor`."""

  def test_index_to_string_table_from_tensor(self):
    """Indices map to vocabulary strings; out-of-range index -> "UNK"."""
    vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
    table = lookup_ops.index_to_string_table_from_tensor(
        vocabulary_list=vocabulary_list)

    indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
    features = table.lookup(indices)
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())

    self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                        self.evaluate(features))

  def test_duplicate_entries(self):
    """Duplicate vocabulary entries are accepted, one per index."""
    vocabulary_list = constant_op.constant(["hello", "hello"])
    table = lookup_ops.index_to_string_table_from_tensor(
        vocabulary_list=vocabulary_list)
    indices = constant_op.constant([0, 1, 4], dtypes.int64)
    features = table.lookup(indices)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"hello", b"hello", b"UNK"), self.evaluate(features))

  def test_index_to_string_with_default_value(self):
    """Out-of-range indices map to the supplied `default_value`."""
    default_value = b"NONE"
    vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
    table = lookup_ops.index_to_string_table_from_tensor(
        vocabulary_list=vocabulary_list,
        default_value=default_value)
    indices = constant_op.constant([1, 2, 4], dtypes.int64)
    features = table.lookup(indices)
    if not context.executing_eagerly():
      with self.assertRaises(errors_impl.OpError):
        self.evaluate(features)
    self.evaluate(lookup_ops.tables_initializer())
    self.assertAllEqual((b"salad", b"surgery", default_value),
                        self.evaluate(features))


class IdTableWithHashBucketsTest(test.TestCase):
  """Tests for `lookup_ops.IdTableWithHashBuckets`."""

  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
    """Writes `values` one per line to `basename` in the test temp dir.

    Returns:
      The path of the written file.
    """
    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
    with open(vocabulary_file, "w") as f:
      f.write("\n".join(values) + "\n")
    return vocabulary_file

  def testStringIdTableWithHashBuckets(self):
    """Known strings map to their line index; "UNK" to the OOV bucket."""
    vocab_file = self._createVocabFile("feat_to_id_1.txt")
    default_value = -1
    vocab_size = 3
    oov_buckets = 1
    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size), default_value),
        oov_buckets)

    self.evaluate(table.initializer)

    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

    out = table.lookup(input_string)
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))

  def testInt32IdTableWithHashBuckets(self):
    """int32 keys are looked up in an int64-keyed vocabulary table."""
    vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
    default_value = -1
    vocab_size = 3
    oov_buckets = 1
    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
            default_value),
        oov_buckets,
        key_dtype=dtypes.int32)

    self.evaluate(table.initializer)

    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)

    out = table.lookup(values)
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))

  def testInt64IdTableWithHashBuckets(self):
    """int64 keys map to their line index; 11 lands in the OOV bucket."""
    vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
    default_value = -1
    vocab_size = 3
    oov_buckets = 1
    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
            default_value), oov_buckets)

    self.evaluate(table.initializer)

    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)

    out = table.lookup(values)
    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))

  def testStringIdTableWithOnlyHashBucket(self):
    """With no inner table, every string key is hashed into a bucket."""
    oov_buckets = 5

    # Build a table that only uses hash buckets: each input value maps to
    # fingerprint(input) mod oov_buckets.
    table = lookup_ops.IdTableWithHashBuckets(None, oov_buckets)
    self.evaluate(table.initializer)

    values = constant_op.constant(("brain", "salad", "surgery"))

    out = table.lookup(values)
    self.assertAllEqual(
        [
            3,  # fingerprint("brain") mod 5.
            1,  # fingerprint("salad") mod 5.
            4  # fingerprint("surgery") mod 5.
        ],
        self.evaluate(out))
    self.assertEqual(oov_buckets, self.evaluate(table.size()))

  def testInt32IdTableWithOnlyHashBucket(self):
    """With no inner table, every int32 key is hashed into a bucket."""
    oov_buckets = 5

    # Build a table that only uses hash buckets: each input value maps to
    # fingerprint(input) mod oov_buckets.
    table = lookup_ops.IdTableWithHashBuckets(
        None, oov_buckets, key_dtype=dtypes.int32)
    self.evaluate(table.initializer)

    input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32)

    out = table.lookup(input_string)
    self.assertAllEqual(
        [
            1,  # fingerprint("42") mod 5.
            4,  # fingerprint("1") mod 5.
            2  # fingerprint("-1000") mod 5.
        ],
        self.evaluate(out))
    self.assertEqual(oov_buckets, self.evaluate(table.size()))

  def testFloat64IdTableWithOnlyHashBucket(self):
    """float64 keys are rejected with TypeError."""
    with self.assertRaisesRegex(TypeError, "Invalid `key_dtype`"):
      lookup_ops.IdTableWithHashBuckets(
          None, num_oov_buckets=5, key_dtype=dtypes.float64)

  def testBoolIdTableWithOnlyHashBucket(self):
    """bool keys are rejected with TypeError."""
    with self.assertRaisesRegex(TypeError, "Invalid `key_dtype`"):
      lookup_ops.IdTableWithHashBuckets(
          None, num_oov_buckets=5, key_dtype=dtypes.bool)

  def testIdTableWithHashBucketsWithMultipleInitializers(self):
    """Two wrappers share one vocab table but use different hashers."""
    vocab_file = self._createVocabFile("feat_to_id_4.txt")
    default_value = -1
    vocab_size = 3
    oov_buckets = 3

    vocab_table = lookup_ops.StaticHashTable(
        lookup_ops.TextFileIdTableInitializer(
            vocab_file, vocab_size=vocab_size), default_value)
    table1 = lookup_ops.IdTableWithHashBuckets(
        vocab_table,
        oov_buckets,
        hasher_spec=lookup_ops.FastHashSpec,
        name="table1")

    table2 = lookup_ops.IdTableWithHashBuckets(
        vocab_table,
        oov_buckets,
        hasher_spec=lookup_ops.StrongHashSpec((1, 2)),
        name="table2")

    self.evaluate(lookup_ops.tables_initializer())

    input_string = constant_op.constant(
        ["fruit", "brain", "salad", "surgery", "UNK"])

    out1 = table1.lookup(input_string)
    out2 = table2.lookup(input_string)

    out1, out2 = self.evaluate([out1, out2])
    # "fruit" and "UNK" are OOV and receive ids >= vocab_size; the two
    # hashers place "UNK" in different buckets.
    self.assertAllEqual([5, 0, 1, 2, 5], out1)
    self.assertAllEqual([5, 0, 1, 2, 3], out2)
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
    if not context.executing_eagerly():
      test_util.assert_ops_in_graph({
          "table1_Lookup/hash_bucket": "StringToHashBucketFast",
          "table2_Lookup/hash_bucket": "StringToHashBucketStrong",
      }, ops.get_default_graph())

  def testIdTableWithHashBucketsInitializationAcrossSessions(self):
    """A second table over the same file works without its own init call."""
    vocab_file = self._createVocabFile("feat_to_id_5.txt")
    default_value = -1
    vocab_size = 3
    oov_buckets = 1
    table1 = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size), default_value), oov_buckets)

    self.evaluate(table1.initializer)

    input_string_1 = constant_op.constant(["brain", "salad", "surgery", "UNK"])

    out1 = table1.lookup(input_string_1)

    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))

    default_value = -1
    vocab_size = 3
    oov_buckets = 1

    # Underlying lookup table already initialized in previous session.
    # No need to call self.evaluate(table2.initializer)
    table2 = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size), default_value), oov_buckets)

    input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])

    out2 = table2.lookup(input_string_2)

    self.assertAllEqual([3, 1, 3], self.evaluate(out2))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))

  def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
    """Tables over the same file keep their own default values."""
    vocab_file = self._createVocabFile("feat_to_id_6.txt")
    default_value1 = -1
    vocab_size = 3
    oov_buckets = 0
    table1 = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size), default_value1),
        oov_buckets)

    default_value2 = -2
    table2 = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(
                vocab_file, vocab_size=vocab_size), default_value2),
        oov_buckets)

    self.evaluate(lookup_ops.tables_initializer())

    input_string_1 = constant_op.constant(
        ["brain", "salad", "surgery", "UNK"])
    input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])

    out1 = table1.lookup(input_string_1)
    out2 = table2.lookup(input_string_2)

    out1, out2 = self.evaluate([out1, out2])
    self.assertAllEqual([0, 1, 2, -1], out1)
    self.assertAllEqual([-2, 1, -2], out2)
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))

  def testSparseTensor(self):
    """String SparseTensor: indices/shape pass through, values become ids."""
    vocab_file = self._createVocabFile("feat_to_id_7.txt")
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    sp_features = sparse_tensor.SparseTensor(
        constant_op.constant(input_indices, dtypes.int64),
        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
                             dtypes.string),
        constant_op.constant(input_shape, dtypes.int64))

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
            -1), 1)
    self.evaluate(table.initializer)

    sp_ids = table.lookup(sp_features)

    self.assertAllEqual([5], sp_ids.values._shape_as_list())

    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

    self.assertAllEqual(input_indices, sp_ids_ind)
    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
    self.assertAllEqual(input_shape, sp_ids_shape)

  def testRaggedTensor(self):
    """String RaggedTensor: row splits pass through, values become ids."""
    vocab_file = self._createVocabFile("feat_to_id_7.txt")
    input_row_splits = [0, 2, 4, 5]
    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
                             dtypes.string),
        constant_op.constant(input_row_splits, dtypes.int64))

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
            -1), 1)
    self.evaluate(table.initializer)

    ragged_ids = table.lookup(ragged_features)
    self.assertAllEqual([5], ragged_ids.values._shape_as_list())

    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
        [ragged_ids.values, ragged_ids.row_splits])

    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)

  def testInt32SparseTensor(self):
    """int32 SparseTensor keys are looked up in an int64-keyed table."""
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    sp_features = sparse_tensor.SparseTensor(
        constant_op.constant(input_indices, dtypes.int64),
        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
        constant_op.constant(input_shape, dtypes.int64))

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.KeyValueTensorInitializer(
                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
        1,
        key_dtype=dtypes.int32)
    self.evaluate(table.initializer)

    sp_ids = table.lookup(sp_features)

    self.assertAllEqual([5], sp_ids.values._shape_as_list())

    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

    self.assertAllEqual(input_indices, sp_ids_ind)
    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
    self.assertAllEqual(input_shape, sp_ids_shape)

  def testInt32RaggedTensor(self):
    """int32 RaggedTensor keys are looked up in an int64-keyed table."""
    input_row_splits = [0, 2, 4, 5]
    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
        constant_op.constant(input_row_splits, dtypes.int32))

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.KeyValueTensorInitializer(
                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
        1,
        key_dtype=dtypes.int32)
    self.evaluate(table.initializer)

    ragged_ids = table.lookup(ragged_features)

    self.assertAllEqual([5], ragged_ids.values._shape_as_list())

    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
        [ragged_ids.values, ragged_ids.row_splits])

    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)

  def testInt64SparseTensor(self):
    """int64 SparseTensor keys map to ids; 11 lands in the OOV bucket."""
    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
    input_shape = [4, 4]
    sp_features = sparse_tensor.SparseTensor(
        constant_op.constant(input_indices, dtypes.int64),
        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
        constant_op.constant(input_shape, dtypes.int64))

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.KeyValueTensorInitializer(
                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
        1,
        key_dtype=dtypes.int64)
    self.evaluate(table.initializer)

    sp_ids = table.lookup(sp_features)

    self.assertAllEqual([5], sp_ids.values._shape_as_list())

    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])

    self.assertAllEqual(input_indices, sp_ids_ind)
    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
    self.assertAllEqual(input_shape, sp_ids_shape)

  def testInt64RaggedTensor(self):
    """int64 RaggedTensor keys map to ids; 11 lands in the OOV bucket."""
    input_row_splits = [0, 2, 4, 5]
    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
        constant_op.constant(input_row_splits, dtypes.int64))

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_ops.StaticHashTable(
            lookup_ops.KeyValueTensorInitializer(
                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
        1,
        key_dtype=dtypes.int64)
    self.evaluate(table.initializer)

    ragged_ids = table.lookup(ragged_features)

    self.assertAllEqual([5], ragged_ids.values._shape_as_list())

    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
        [ragged_ids.values, ragged_ids.row_splits])

    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)

  def testIdTableWithHashBucketsWithInvalidHashers(self):
    """Rejects bad hasher specs and StrongHashSpec keys of wrong length/type."""
    vocab_file = self._createVocabFile("feat_to_id_4.txt")
    default_value = -1
    vocab_size = 3
    oov_buckets = 1
    lookup_table = lookup_ops.StaticHashTable(
        lookup_ops.TextFileIdTableInitializer(
            vocab_file, vocab_size=vocab_size), default_value)

    with self.assertRaises(TypeError):
      lookup_ops.IdTableWithHashBuckets(
          lookup_table, oov_buckets, hasher_spec=1)

    table = lookup_ops.IdTableWithHashBuckets(
        lookup_table,
        oov_buckets,
        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))

    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])

    # The unknown hasher name is only detected when the lookup is built.
    with self.assertRaises(ValueError):
      table.lookup(input_string)

    with self.assertRaises(ValueError):
      table = lookup_ops.IdTableWithHashBuckets(
          lookup_table, oov_buckets, hasher_spec=lookup_ops.StrongHashSpec([]))

    with self.assertRaises(ValueError):
      table = lookup_ops.IdTableWithHashBuckets(
          lookup_table,
          oov_buckets,
          hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3]))

    with self.assertRaises(TypeError):
      table = lookup_ops.IdTableWithHashBuckets(
          lookup_table,
          oov_buckets,
          hasher_spec=lookup_ops.StrongHashSpec([None, 2]))

  def testIdTableWithHashBucketsNoInnerTable(self):
    """A hash-bucket-only table has no `resource_handle`."""
    table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
    self.assertIsNone(table.resource_handle)


@parameterized.named_parameters(
    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
class MutableHashTableOpTest(test.TestCase):
  """Tests for `lookup_ops.MutableHashTable`, named and anonymous variants."""

  def testMutableHashTable(self, is_anonymous):
    """Exercises insert, remove, lookup, size and export."""
    if is_anonymous and not tf2.enabled():
      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
    default_val = -1
    keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
    values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
    table = lookup_ops.MutableHashTable(
        dtypes.string,
        dtypes.int64,
        default_val,
        experimental_is_anonymous=is_anonymous)
    self.assertAllEqual(0, self.evaluate(table.size()))

    self.evaluate(table.insert(keys, values))
    self.assertAllEqual(4, self.evaluate(table.size()))

    # "tank" is absent, so only "tarkus" is actually removed.
    remove_string = constant_op.constant(["tarkus", "tank"])
    self.evaluate(table.remove(remove_string))
    self.assertAllEqual(3, self.evaluate(table.size()))

    input_string = constant_op.constant(["brain", "salad", "tank"])
    output = table.lookup(input_string)
    self.assertAllEqual([3], output.get_shape())

    result = self.evaluate(output)
    self.assertAllEqual([0, 1, -1], result)

    exported_keys, exported_values = table.export()

    # exported data is in the order of the internal map, i.e. undefined
    sorted_keys = np.sort(self.evaluate(exported_keys))
    sorted_values = np.sort(self.evaluate(exported_values))
    self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
    self.assertAllEqual([0, 1, 2], sorted_values)

  # TODO(https://github.com/tensorflow/tensorflow/issues/24439): remove
  # expectedFailure when fixed.
  @unittest.expectedFailure
  @test_util.run_v2_only
  def testImportedHashTable(self, is_anonymous):
    """Looks up a table imported from an exported meta graph (known failure)."""
    g = ops.Graph()
    with g.as_default():
      default_val = -1
      keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
      table = lookup_ops.MutableHashTable(
          dtypes.string,
          dtypes.int64,
          default_val,
          experimental_is_anonymous=is_anonymous)
      self.evaluate(table.insert(keys, values))
      op = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
      meta_graph = saver.export_meta_graph()

    def f():
      # Re-import the exported graph and fetch the lookup output by name.
      saver.import_meta_graph(meta_graph)
      return ops.get_default_graph().get_tensor_by_name(op.name)

    wrapped = wrap_function.wrap_function(f, [])
    self.assertAllEqual([0, 1, -1], wrapped())

  @test_util.run_v1_only("SaverV1")
  def testSaveRestore(self, is_anonymous):
    """Round-trips variables and the table through a V1 Saver checkpoint.

    Entries inserted after the save ("a") are discarded by the restore.
    """
    if is_anonymous and not tf2.enabled():
      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")

      default_val = -1
      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(
          dtypes.string,
          dtypes.int64,
          default_val,
          name="t1",
          checkpoint=True,
          experimental_is_anonymous=is_anonymous)

      save = saver.Saver()
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      self.assertAllEqual(0, self.evaluate(table.size()))
      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      val = save.save(sess, save_path)
      self.assertIsInstance(val, str)
      self.assertEqual(save_path, val)

    with self.session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(-1.0, name="v0")
      v1 = variables.Variable(-1.0, name="v1")
      default_val = -1
      table = lookup_ops.MutableHashTable(
          dtypes.string,
          dtypes.int64,
          default_val,
          name="t1",
          checkpoint=True,
          experimental_is_anonymous=is_anonymous)
      self.evaluate(
          table.insert(
              constant_op.constant(["a", "c"], dtypes.string),
              constant_op.constant([12, 24], dtypes.int64)))
      self.assertAllEqual(2, self.evaluate(table.size()))

      save = saver.Saver()

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      self.assertAllEqual(3, self.evaluate(table.size()))

      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
                                          dtypes.string)
      output = table.lookup(input_string)
      self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))

  @test_util.run_v1_only("SaverV1")
  def testSaveRestoreOnlyTable(self, is_anonymous):
    """Saves and restores using a Saver built over the table alone."""
    if is_anonymous and not tf2.enabled():
      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")

    with self.session(graph=ops.Graph()) as sess:
      v0 = variables.Variable(10.0, name="v0")
      v1 = variables.Variable(20.0, name="v1")

      default_val = -1
      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
      values = constant_op.constant([0, 1, 2], dtypes.int64)
      table = lookup_ops.MutableHashTable(
          dtypes.string,
          dtypes.int64,
          default_val,
          name="t1",
          checkpoint=True,
          experimental_is_anonymous=is_anonymous)

      save = saver.Saver([table])
      self.evaluate(variables.global_variables_initializer())

      # Check that the parameter nodes have been initialized.
      self.assertEqual(10.0, self.evaluate(v0))
      self.assertEqual(20.0, self.evaluate(v1))

      self.assertAllEqual(0, self.evaluate(table.size()))
      self.evaluate(table.insert(keys, values))
      self.assertAllEqual(3, self.evaluate(table.size()))

      val = save.save(sess, save_path)
      self.assertIsInstance(val, str)
      self.assertEqual(save_path, val)

    with self.session(graph=ops.Graph()) as sess:
      default_val = -1
      table = lookup_ops.MutableHashTable(
          dtypes.string,
          dtypes.int64,
          default_val,
          name="t1",
          checkpoint=True,
          experimental_is_anonymous=is_anonymous)
      self.evaluate(
          table.insert(
              constant_op.constant(["a", "c"], dtypes.string),
              constant_op.constant([12, 24], dtypes.int64)))
      self.assertAllEqual(2, self.evaluate(table.size()))

      save = saver.Saver([table])

      # Restore the saved values in the parameter nodes.
      save.restore(sess, save_path)
      # Check that the parameter nodes have been restored.
3509 3510 self.assertAllEqual(3, self.evaluate(table.size())) 3511 3512 input_string = constant_op.constant(["a", "b", "c", "d", "e"], 3513 dtypes.string) 3514 output = table.lookup(input_string) 3515 self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) 3516 3517 @test_util.run_in_graph_and_eager_modes 3518 def testObjectSaveRestore(self, is_anonymous): 3519 if is_anonymous and not context.executing_eagerly(): 3520 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3521 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 3522 save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 3523 3524 v0 = variables.Variable(10.0, name="v0") 3525 v1 = variables.Variable(20.0, name="v1") 3526 3527 default_val = -1 3528 keys = constant_op.constant(["b", "c", "d"], dtypes.string) 3529 values = constant_op.constant([0, 1, 2], dtypes.int64) 3530 table = lookup_ops.MutableHashTable( 3531 dtypes.string, 3532 dtypes.int64, 3533 default_val, 3534 name="t1", 3535 checkpoint=True, 3536 experimental_is_anonymous=is_anonymous) 3537 3538 checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1) 3539 self.evaluate([v0.initializer, v1.initializer]) 3540 3541 # Check that the parameter nodes have been initialized. 
3542 self.assertEqual(10.0, self.evaluate(v0)) 3543 self.assertEqual(20.0, self.evaluate(v1)) 3544 3545 self.assertAllEqual(0, self.evaluate(table.size())) 3546 self.evaluate(table.insert(keys, values)) 3547 self.assertAllEqual(3, self.evaluate(table.size())) 3548 3549 save_path = checkpoint.save(save_prefix) 3550 del table, checkpoint, v0, v1 3551 3552 v0 = variables.Variable(-1.0, name="v0") 3553 v1 = variables.Variable(-1.0, name="v1") 3554 default_val = -1 3555 table = lookup_ops.MutableHashTable( 3556 dtypes.string, 3557 dtypes.int64, 3558 default_val, 3559 name="t1", 3560 checkpoint=True, 3561 experimental_is_anonymous=is_anonymous) 3562 self.evaluate( 3563 table.insert( 3564 constant_op.constant(["a", "c"], dtypes.string), 3565 constant_op.constant([12, 24], dtypes.int64))) 3566 self.assertAllEqual(2, self.evaluate(table.size())) 3567 3568 checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1) 3569 3570 # Restore the saved values in the parameter nodes. 3571 checkpoint.restore(save_path).run_restore_ops() 3572 # Check that the parameter nodes have been restored. 
3573 self.assertEqual(10.0, self.evaluate(v0)) 3574 self.assertEqual(20.0, self.evaluate(v1)) 3575 3576 self.assertAllEqual(3, self.evaluate(table.size())) 3577 3578 input_string = constant_op.constant(["a", "b", "c", "d", "e"], 3579 dtypes.string) 3580 output = table.lookup(input_string) 3581 self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output)) 3582 3583 @test_util.run_v2_only 3584 def testSavedModelSaveRestore(self, is_anonymous): 3585 save_dir = os.path.join(self.get_temp_dir(), "save_restore") 3586 save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash") 3587 3588 root = autotrackable.AutoTrackable() 3589 3590 default_value = -1 3591 keys = constant_op.constant([11, 12, 13], dtypes.int64) 3592 values = constant_op.constant([0, 1, 2], dtypes.int64) 3593 root.table = lookup_ops.MutableHashTable( 3594 dtypes.int64, 3595 dtypes.int64, 3596 default_value, 3597 experimental_is_anonymous=is_anonymous) 3598 3599 @def_function.function( 3600 input_signature=[tensor_spec.TensorSpec((), dtypes.int64)]) 3601 def lookup(key): 3602 return root.table.lookup(key) 3603 3604 @def_function.function(input_signature=[]) 3605 def size(): 3606 return root.table.size() 3607 3608 @def_function.function(input_signature=[]) 3609 def is_ref_counting(): 3610 return test_ops.is_resource_handle_ref_counting( 3611 root.table.resource_handle) 3612 3613 root.lookup = lookup 3614 root.size = size 3615 root.is_ref_counting = is_ref_counting 3616 3617 self.assertEqual(root.table.size(), 0) 3618 root.table.insert(keys, values) 3619 self.assertEqual(root.table.size(), 3) 3620 self.assertEqual(root.table.lookup(12), 1) 3621 self.assertEqual(root.table.lookup(10), -1) 3622 self.assertEqual(len(root.table.export()[0]), 3) 3623 self.assertEqual(root.is_ref_counting(), is_anonymous) 3624 3625 saved_model_save.save(root, save_path) 3626 3627 del root 3628 loaded = saved_model_load.load(save_path) 3629 self.assertEqual(loaded.size(), 3) 3630 self.assertEqual(loaded.lookup(12), 1) 3631 
self.assertEqual(loaded.lookup(10), -1) 3632 self.assertEqual(loaded.is_ref_counting(), is_anonymous) 3633 3634 @test_util.run_v1_only("Multiple sessions") 3635 def testSharing(self, is_anonymous): 3636 if is_anonymous and not tf2.enabled(): 3637 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3638 # Start a server to store the table state 3639 server = server_lib.Server({"local0": ["localhost:0"]}, 3640 protocol="grpc", 3641 start=True) 3642 # Create two sessions sharing the same state 3643 session1 = session.Session(server.target) 3644 session2 = session.Session(server.target) 3645 3646 table = lookup_ops.MutableHashTable( 3647 dtypes.int64, 3648 dtypes.string, 3649 "-", 3650 name="t1", 3651 experimental_is_anonymous=is_anonymous) 3652 3653 # Populate the table in the first session 3654 with session1: 3655 self.assertAllEqual(0, table.size()) 3656 3657 keys = constant_op.constant([11, 12], dtypes.int64) 3658 values = constant_op.constant(["a", "b"]) 3659 table.insert(keys, values).run() 3660 self.assertAllEqual(2, table.size()) 3661 3662 output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64)) 3663 self.assertAllEqual([b"a", b"b", b"-"], output) 3664 3665 # Verify that we can access the shared data from the second session 3666 with session2: 3667 self.assertAllEqual(2, table.size()) 3668 3669 output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64)) 3670 self.assertAllEqual([b"-", b"a", b"b"], output) 3671 3672 def testMutableHashTableOfTensors(self, is_anonymous): 3673 if is_anonymous and not tf2.enabled(): 3674 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3675 default_val = constant_op.constant([-1, -1], dtypes.int64) 3676 keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"]) 3677 values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]], 3678 dtypes.int64) 3679 table = lookup_ops.MutableHashTable( 3680 dtypes.string, 3681 dtypes.int64, 3682 default_val, 3683 experimental_is_anonymous=is_anonymous) 3684 
self.assertAllEqual(0, self.evaluate(table.size())) 3685 3686 self.evaluate(table.insert(keys, values)) 3687 self.assertAllEqual(4, self.evaluate(table.size())) 3688 3689 remove_string = constant_op.constant(["tarkus", "tank"]) 3690 self.evaluate(table.remove(remove_string)) 3691 self.assertAllEqual(3, self.evaluate(table.size())) 3692 3693 input_string = constant_op.constant(["brain", "salad", "tank"]) 3694 output = table.lookup(input_string) 3695 self.assertAllEqual([3, 2], output.get_shape()) 3696 3697 result = self.evaluate(output) 3698 self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result) 3699 3700 exported_keys, exported_values = table.export() 3701 # exported data is in the order of the internal map, i.e. undefined 3702 sorted_keys = np.sort(self.evaluate(exported_keys)) 3703 sorted_values = np.sort(self.evaluate(exported_values), axis=0) 3704 self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys) 3705 sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0) 3706 self.assertAllEqual(sorted_expected_values, sorted_values) 3707 3708 def testMutableHashTableExportInsert(self, is_anonymous): 3709 if is_anonymous and not tf2.enabled(): 3710 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3711 default_val = constant_op.constant([-1, -1], dtypes.int64) 3712 keys = constant_op.constant(["brain", "salad", "surgery"]) 3713 values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) 3714 table1 = lookup_ops.MutableHashTable( 3715 dtypes.string, 3716 dtypes.int64, 3717 default_val, 3718 experimental_is_anonymous=is_anonymous) 3719 self.assertAllEqual(0, self.evaluate(table1.size())) 3720 self.evaluate(table1.insert(keys, values)) 3721 self.assertAllEqual(3, self.evaluate(table1.size())) 3722 3723 input_string = constant_op.constant(["brain", "salad", "tank"]) 3724 expected_output = [[0, 1], [2, 3], [-1, -1]] 3725 output1 = table1.lookup(input_string) 3726 self.assertAllEqual(expected_output, self.evaluate(output1)) 3727 3728 exported_keys, 
exported_values = table1.export() 3729 self.assertAllEqual(3, self.evaluate(exported_keys).size) 3730 self.assertAllEqual(6, self.evaluate(exported_values).size) 3731 3732 # Populate a second table from the exported data 3733 table2 = lookup_ops.MutableHashTable( 3734 dtypes.string, 3735 dtypes.int64, 3736 default_val, 3737 experimental_is_anonymous=is_anonymous) 3738 self.assertAllEqual(0, self.evaluate(table2.size())) 3739 self.evaluate(table2.insert(exported_keys, exported_values)) 3740 self.assertAllEqual(3, self.evaluate(table2.size())) 3741 3742 # Verify lookup result is still the same 3743 output2 = table2.lookup(input_string) 3744 self.assertAllEqual(expected_output, self.evaluate(output2)) 3745 3746 def testMutableHashTableOfTensorsInvalidShape(self, is_anonymous): 3747 if is_anonymous and not tf2.enabled(): 3748 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3749 default_val = constant_op.constant([-1, -1], dtypes.int64) 3750 keys = constant_op.constant(["brain", "salad", "surgery"]) 3751 table = lookup_ops.MutableHashTable( 3752 dtypes.string, 3753 dtypes.int64, 3754 default_val, 3755 experimental_is_anonymous=is_anonymous) 3756 3757 # Shape [6] instead of [3, 2] 3758 values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64) 3759 with self.assertRaisesOpError("Expected shape"): 3760 self.evaluate(table.insert(keys, values)) 3761 3762 # Shape [2,3] instead of [3, 2] 3763 values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64) 3764 with self.assertRaisesOpError("Expected shape"): 3765 self.evaluate(table.insert(keys, values)) 3766 3767 # Shape [2, 2] instead of [3, 2] 3768 values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) 3769 with self.assertRaisesOpError("Expected shape"): 3770 self.evaluate(table.insert(keys, values)) 3771 3772 # Shape [3, 1] instead of [3, 2] 3773 values = constant_op.constant([[0], [2], [4]], dtypes.int64) 3774 with self.assertRaisesOpError("Expected shape"): 3775 self.evaluate(table.insert(keys, values)) 
3776 3777 # Valid Insert 3778 values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) 3779 self.evaluate(table.insert(keys, values)) 3780 self.assertAllEqual(3, self.evaluate(table.size())) 3781 3782 def testMutableHashTableInvalidDefaultValue(self, is_anonymous): 3783 default_val = constant_op.constant([[-1, -1]], dtypes.int64) 3784 with self.assertRaisesOpError("Default value must be a vector"): 3785 table = lookup_ops.MutableHashTable( 3786 dtypes.string, 3787 dtypes.int64, 3788 default_val, 3789 experimental_is_anonymous=is_anonymous) 3790 self.assertAllEqual(0, self.evaluate(table.size())) 3791 3792 def testMutableHashTableDuplicateInsert(self, is_anonymous): 3793 if is_anonymous and not tf2.enabled(): 3794 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3795 default_val = -1 3796 keys = constant_op.constant(["brain", "salad", "surgery", "brain"]) 3797 values = constant_op.constant([0, 1, 2, 3], dtypes.int64) 3798 table = lookup_ops.MutableHashTable( 3799 dtypes.string, 3800 dtypes.int64, 3801 default_val, 3802 experimental_is_anonymous=is_anonymous) 3803 self.assertAllEqual(0, self.evaluate(table.size())) 3804 3805 self.evaluate(table.insert(keys, values)) 3806 self.assertAllEqual(3, self.evaluate(table.size())) 3807 3808 input_string = constant_op.constant(["brain", "salad", "tank"]) 3809 output = table.lookup(input_string) 3810 3811 result = self.evaluate(output) 3812 self.assertAllEqual([3, 1, -1], result) 3813 3814 def testMutableHashTableFindHighRank(self, is_anonymous): 3815 if is_anonymous and not tf2.enabled(): 3816 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3817 default_val = -1 3818 keys = constant_op.constant(["brain", "salad", "surgery"]) 3819 values = constant_op.constant([0, 1, 2], dtypes.int64) 3820 table = lookup_ops.MutableHashTable( 3821 dtypes.string, 3822 dtypes.int64, 3823 default_val, 3824 experimental_is_anonymous=is_anonymous) 3825 3826 self.evaluate(table.insert(keys, values)) 3827 self.assertAllEqual(3, 
self.evaluate(table.size())) 3828 3829 input_string = constant_op.constant([["brain", "salad"], 3830 ["tank", "tarkus"]]) 3831 output = table.lookup(input_string) 3832 self.assertAllEqual([2, 2], output.get_shape()) 3833 3834 result = self.evaluate(output) 3835 self.assertAllEqual([[0, 1], [-1, -1]], result) 3836 3837 def testMutableHashTableFindWithInvalidShapeDefaultValue(self, is_anonymous): 3838 default_val = [-1, -1] 3839 table = lookup_ops.MutableHashTable( 3840 dtypes.string, 3841 dtypes.int64, 3842 default_val, 3843 experimental_is_anonymous=is_anonymous) 3844 3845 input_string = constant_op.constant([["brain", "salad"], ["tank", 3846 "tarkus"]]) 3847 3848 invalid_default_val = constant_op.constant( 3849 [[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64) 3850 3851 with self.assertRaisesRegex( 3852 (ValueError, errors_impl.InvalidArgumentError), 3853 "Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"): 3854 self.evaluate(table.lookup(input_string, invalid_default_val)) 3855 3856 invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]], 3857 dtypes.int64) 3858 with self.assertRaisesRegex( 3859 (ValueError, errors_impl.InvalidArgumentError), 3860 "Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"): 3861 self.evaluate(table.lookup(input_string, invalid_default_val)) 3862 3863 def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue( 3864 self, is_anonymous): 3865 if is_anonymous and not tf2.enabled(): 3866 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3867 default_val = -1 3868 keys = constant_op.constant(["brain", "salad", "surgery"]) 3869 values = constant_op.constant([0, 1, 2], dtypes.int64) 3870 table = lookup_ops.MutableHashTable( 3871 dtypes.string, 3872 dtypes.int64, 3873 default_val, 3874 experimental_is_anonymous=is_anonymous) 3875 3876 self.evaluate(table.insert(keys, values)) 3877 self.assertAllEqual(3, self.evaluate(table.size())) 3878 3879 input_string = constant_op.constant([["brain", "salad"], 
["tank", 3880 "tarkus"]]) 3881 3882 dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]], 3883 dtypes.int64) 3884 output = table.lookup(input_string, dynamic_default_val) 3885 self.assertAllEqual([2, 2], output.get_shape()) 3886 3887 result = self.evaluate(output) 3888 self.assertAllEqual([[0, 1], [-4, -5]], result) 3889 3890 def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue( 3891 self, is_anonymous): 3892 if is_anonymous and not tf2.enabled(): 3893 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3894 default_val = [-1, -1] 3895 keys = constant_op.constant(["brain", "salad", "surgery"]) 3896 values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64) 3897 table = lookup_ops.MutableHashTable( 3898 dtypes.string, 3899 dtypes.int64, 3900 default_val, 3901 experimental_is_anonymous=is_anonymous) 3902 3903 self.evaluate(table.insert(keys, values)) 3904 self.assertAllEqual(3, self.evaluate(table.size())) 3905 3906 input_string = constant_op.constant([["brain", "salad"], ["tank", 3907 "tarkus"]]) 3908 3909 dynamic_default_val = constant_op.constant( 3910 [[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64) 3911 output = table.lookup(input_string, dynamic_default_val) 3912 self.assertAllEqual([2, 2, 2], output.get_shape()) 3913 3914 result = self.evaluate(output) 3915 self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result) 3916 3917 def testMutableHashTableInsertHighRank(self, is_anonymous): 3918 if is_anonymous and not tf2.enabled(): 3919 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3920 default_val = -1 3921 keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) 3922 values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) 3923 table = lookup_ops.MutableHashTable( 3924 dtypes.string, 3925 dtypes.int64, 3926 default_val, 3927 experimental_is_anonymous=is_anonymous) 3928 3929 self.evaluate(table.insert(keys, values)) 3930 self.assertAllEqual(4, self.evaluate(table.size())) 3931 3932 input_string = 
constant_op.constant(["brain", "salad", "tank", "tarkus"]) 3933 output = table.lookup(input_string) 3934 3935 result = self.evaluate(output) 3936 self.assertAllEqual([0, 1, 3, -1], result) 3937 3938 def testMutableHashTableRemoveHighRank(self, is_anonymous): 3939 if is_anonymous and not tf2.enabled(): 3940 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3941 default_val = -1 3942 keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]]) 3943 values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64) 3944 table = lookup_ops.MutableHashTable( 3945 dtypes.string, 3946 dtypes.int64, 3947 default_val, 3948 experimental_is_anonymous=is_anonymous) 3949 3950 self.evaluate(table.insert(keys, values)) 3951 self.assertAllEqual(4, self.evaluate(table.size())) 3952 3953 remove_string = constant_op.constant(["salad", "tarkus"]) 3954 self.evaluate(table.remove(remove_string)) 3955 self.assertAllEqual(3, self.evaluate(table.size())) 3956 3957 input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"]) 3958 output = table.lookup(input_string) 3959 3960 result = self.evaluate(output) 3961 self.assertAllEqual([0, -1, 3, -1], result) 3962 3963 def testMutableHashTableOfTensorsFindHighRank(self, is_anonymous): 3964 if is_anonymous and not tf2.enabled(): 3965 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3966 default_val = constant_op.constant([-1, -1, -1], dtypes.int64) 3967 keys = constant_op.constant(["brain", "salad", "surgery"]) 3968 values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], 3969 dtypes.int64) 3970 table = lookup_ops.MutableHashTable( 3971 dtypes.string, 3972 dtypes.int64, 3973 default_val, 3974 experimental_is_anonymous=is_anonymous) 3975 3976 self.evaluate(table.insert(keys, values)) 3977 self.assertAllEqual(3, self.evaluate(table.size())) 3978 3979 input_string = constant_op.constant([["brain", "salad"], 3980 ["tank", "tarkus"]]) 3981 output = table.lookup(input_string) 3982 self.assertAllEqual([2, 2, 3], output.get_shape()) 3983 
3984 result = self.evaluate(output) 3985 self.assertAllEqual( 3986 [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result) 3987 3988 def testMutableHashTableOfTensorsRemoveHighRank(self, is_anonymous): 3989 if is_anonymous and not tf2.enabled(): 3990 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 3991 default_val = constant_op.constant([-1, -1, -1], dtypes.int64) 3992 keys = constant_op.constant(["brain", "salad", "surgery"]) 3993 values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]], 3994 dtypes.int64) 3995 table = lookup_ops.MutableHashTable( 3996 dtypes.string, 3997 dtypes.int64, 3998 default_val, 3999 experimental_is_anonymous=is_anonymous) 4000 4001 self.evaluate(table.insert(keys, values)) 4002 self.assertAllEqual(3, self.evaluate(table.size())) 4003 4004 remove_string = constant_op.constant([["brain", "tank"]]) 4005 self.evaluate(table.remove(remove_string)) 4006 self.assertAllEqual(2, self.evaluate(table.size())) 4007 4008 input_string = constant_op.constant([["brain", "salad"], 4009 ["surgery", "tank"]]) 4010 output = table.lookup(input_string) 4011 self.assertAllEqual([2, 2, 3], output.get_shape()) 4012 4013 result = self.evaluate(output) 4014 self.assertAllEqual( 4015 [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result) 4016 4017 def testMultipleMutableHashTables(self, is_anonymous): 4018 if is_anonymous and not tf2.enabled(): 4019 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 4020 default_val = -1 4021 keys = constant_op.constant(["brain", "salad", "surgery"]) 4022 values = constant_op.constant([0, 1, 2], dtypes.int64) 4023 4024 table1 = lookup_ops.MutableHashTable( 4025 dtypes.string, 4026 dtypes.int64, 4027 default_val, 4028 experimental_is_anonymous=is_anonymous) 4029 table2 = lookup_ops.MutableHashTable( 4030 dtypes.string, 4031 dtypes.int64, 4032 default_val, 4033 experimental_is_anonymous=is_anonymous) 4034 table3 = lookup_ops.MutableHashTable( 4035 dtypes.string, 4036 dtypes.int64, 4037 default_val, 4038 
experimental_is_anonymous=is_anonymous) 4039 self.evaluate(table1.insert(keys, values)) 4040 self.evaluate(table2.insert(keys, values)) 4041 self.evaluate(table3.insert(keys, values)) 4042 4043 self.assertAllEqual(3, self.evaluate(table1.size())) 4044 self.assertAllEqual(3, self.evaluate(table2.size())) 4045 self.assertAllEqual(3, self.evaluate(table3.size())) 4046 4047 input_string = constant_op.constant(["brain", "salad", "tank"]) 4048 output1 = table1.lookup(input_string) 4049 output2 = table2.lookup(input_string) 4050 output3 = table3.lookup(input_string) 4051 4052 out1, out2, out3 = self.evaluate([output1, output2, output3]) 4053 self.assertAllEqual([0, 1, -1], out1) 4054 self.assertAllEqual([0, 1, -1], out2) 4055 self.assertAllEqual([0, 1, -1], out3) 4056 4057 def testMutableHashTableWithTensorDefault(self, is_anonymous): 4058 if is_anonymous and not tf2.enabled(): 4059 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 4060 default_val = constant_op.constant(-1, dtypes.int64) 4061 keys = constant_op.constant(["brain", "salad", "surgery"]) 4062 values = constant_op.constant([0, 1, 2], dtypes.int64) 4063 table = lookup_ops.MutableHashTable( 4064 dtypes.string, 4065 dtypes.int64, 4066 default_val, 4067 experimental_is_anonymous=is_anonymous) 4068 4069 self.evaluate(table.insert(keys, values)) 4070 self.assertAllEqual(3, self.evaluate(table.size())) 4071 4072 input_string = constant_op.constant(["brain", "salad", "tank"]) 4073 output = table.lookup(input_string) 4074 4075 result = self.evaluate(output) 4076 self.assertAllEqual([0, 1, -1], result) 4077 4078 def testSignatureMismatch(self, is_anonymous): 4079 if is_anonymous and not tf2.enabled(): 4080 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 4081 default_val = -1 4082 keys = constant_op.constant(["brain", "salad", "surgery"]) 4083 values = constant_op.constant([0, 1, 2], dtypes.int64) 4084 table = lookup_ops.MutableHashTable( 4085 dtypes.string, 4086 dtypes.int64, 4087 default_val, 4088 
experimental_is_anonymous=is_anonymous) 4089 4090 # insert with keys of the wrong type 4091 with self.assertRaises(ValueError): 4092 self.evaluate(table.insert(constant_op.constant([4, 5, 6]), values)) 4093 4094 # insert with values of the wrong type 4095 with self.assertRaises(ValueError): 4096 self.evaluate(table.insert(keys, constant_op.constant(["a", "b", "c"]))) 4097 4098 self.assertAllEqual(0, self.evaluate(table.size())) 4099 4100 self.evaluate(table.insert(keys, values)) 4101 self.assertAllEqual(3, self.evaluate(table.size())) 4102 4103 input_string_ref = variables.Variable("brain") 4104 input_int64_ref = variables.Variable(-1, dtype=dtypes.int64) 4105 self.evaluate(variables.global_variables_initializer()) 4106 4107 # Ref types do not produce an insert signature mismatch. 4108 self.evaluate(table.insert(input_string_ref, input_int64_ref)) 4109 self.assertAllEqual(3, self.evaluate(table.size())) 4110 4111 # Ref types do not produce a lookup signature mismatch. 4112 self.assertEqual(-1, self.evaluate(table.lookup(input_string_ref))) 4113 4114 # lookup with keys of the wrong type 4115 input_string = constant_op.constant([1, 2, 3], dtypes.int64) 4116 with self.assertRaises(ValueError): 4117 self.evaluate(table.lookup(input_string)) 4118 4119 # default value of the wrong type 4120 with self.assertRaises(TypeError): 4121 lookup_ops.MutableHashTable( 4122 dtypes.string, 4123 dtypes.int64, 4124 "UNK", 4125 experimental_is_anonymous=is_anonymous) 4126 4127 def testMutableHashTableStringFloat(self, is_anonymous): 4128 if is_anonymous and not tf2.enabled(): 4129 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 4130 default_val = -1.5 4131 keys = constant_op.constant(["brain", "salad", "surgery"]) 4132 values = constant_op.constant([0, 1.1, 2.2], dtypes.float32) 4133 table = lookup_ops.MutableHashTable( 4134 dtypes.string, 4135 dtypes.float32, 4136 default_val, 4137 experimental_is_anonymous=is_anonymous) 4138 self.assertAllEqual(0, self.evaluate(table.size())) 4139 4140 
self.evaluate(table.insert(keys, values)) 4141 self.assertAllEqual(3, self.evaluate(table.size())) 4142 4143 input_string = constant_op.constant(["brain", "salad", "tank"]) 4144 output = table.lookup(input_string) 4145 4146 result = self.evaluate(output) 4147 self.assertAllClose([0, 1.1, default_val], result) 4148 4149 def testMutableHashTableIntFloat(self, is_anonymous): 4150 if is_anonymous and not tf2.enabled(): 4151 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 4152 default_val = -1.0 4153 keys = constant_op.constant([3, 7, 0], dtypes.int64) 4154 values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32) 4155 table = lookup_ops.MutableHashTable( 4156 dtypes.int64, 4157 dtypes.float32, 4158 default_val, 4159 experimental_is_anonymous=is_anonymous) 4160 self.assertAllEqual(0, self.evaluate(table.size())) 4161 4162 self.evaluate(table.insert(keys, values)) 4163 self.assertAllEqual(3, self.evaluate(table.size())) 4164 4165 input_string = constant_op.constant([7, 0, 11], dtypes.int64) 4166 output = table.lookup(input_string) 4167 4168 result = self.evaluate(output) 4169 self.assertAllClose([-1.2, 9.9, default_val], result) 4170 4171 def testMutableHashTableInt64String(self, is_anonymous): 4172 if is_anonymous and not tf2.enabled(): 4173 self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON) 4174 default_val = "n/a" 4175 keys = constant_op.constant([0, 1, 2], dtypes.int64) 4176 values = constant_op.constant(["brain", "salad", "surgery"]) 4177 table = lookup_ops.MutableHashTable( 4178 dtypes.int64, 4179 dtypes.string, 4180 default_val, 4181 experimental_is_anonymous=is_anonymous) 4182 self.assertAllEqual(0, self.evaluate(table.size())) 4183 4184 self.evaluate(table.insert(keys, values)) 4185 self.assertAllEqual(3, self.evaluate(table.size())) 4186 4187 input_string = constant_op.constant([0, 1, 3], dtypes.int64) 4188 output = table.lookup(input_string) 4189 4190 result = self.evaluate(output) 4191 self.assertAllEqual((b"brain", b"salad", b"n/a"), result) 4192 4193 def 
testExportShapeInference(self, is_anonymous): 4194 default_value = -1 4195 table = lookup_ops.MutableHashTable( 4196 dtypes.int64, 4197 dtypes.int64, 4198 default_value=default_value, 4199 experimental_is_anonymous=is_anonymous) 4200 actual_shapes = [t.shape for t in table.export()] 4201 inferred_shapes = [] 4202 4203 @def_function.function 4204 def f(): 4205 for t in table.export(): 4206 inferred_shapes.append(t.shape) 4207 4208 f() 4209 self.assertLen(actual_shapes, 2) 4210 self.assertLen(inferred_shapes, 2) 4211 self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0])) 4212 self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1])) 4213 4214 4215class MutableHashTableBenchmark(test.Benchmark): 4216 4217 def _create_table(self): 4218 return lookup_ops.MutableHashTable(dtypes.int64, dtypes.float32, 0.0) 4219 4220 def benchmark_single_repeated_scalar_insert_scalar(self): 4221 table = self._create_table() 4222 value = variables.Variable(1.0) 4223 insert = table.insert(0, value) 4224 size = table.size() 4225 with session.Session() as sess: 4226 sess.run(value.initializer) 4227 self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) 4228 assert sess.run(size) == 1 4229 4230 def benchmark_many_repeated_scalar_insert_scalar(self): 4231 table = self._create_table() 4232 c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() 4233 value = variables.Variable(1.0) 4234 insert = table.insert(c, value) 4235 size = table.size() 4236 with session.Session() as sess: 4237 sess.run(value.initializer) 4238 self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000) 4239 assert sess.run(size) >= 10000 4240 4241 def benchmark_single_repeated_batch_32_insert_scalar(self): 4242 table = self._create_table() 4243 value = variables.Variable([1.0] * 32) 4244 insert = table.insert(list(range(32)), value) 4245 size = table.size() 4246 with session.Session() as sess: 4247 sess.run(value.initializer) 4248 self.run_op_benchmark(sess, 
insert, burn_iters=10, min_iters=1000) 4249 assert sess.run(size) == 32 4250 4251 def benchmark_many_repeated_batch_32_insert_scalar(self): 4252 table = self._create_table() 4253 c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next() 4254 value = variables.Variable([1.0] * 32) 4255 insert = table.insert(32 * c + list(range(32)), value) 4256 size = table.size() 4257 with session.Session() as sess: 4258 sess.run(value.initializer) 4259 self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000) 4260 assert sess.run(size) >= 1000 * 32 4261 4262 4263class DenseHashTableBenchmark(MutableHashTableBenchmark): 4264 4265 def _create_table(self): 4266 return lookup_ops.DenseHashTable( 4267 dtypes.int64, 4268 dtypes.float32, 4269 default_value=0.0, 4270 empty_key=-1, 4271 deleted_key=-2) 4272 4273 4274if __name__ == "__main__": 4275 test.main() 4276