xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/data_structures/lookup_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for lookup ops."""
16import os
17import tempfile
18import unittest
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python import tf2
24from tensorflow.python.checkpoint import checkpoint as trackable
25from tensorflow.python.checkpoint import graph_view
26from tensorflow.python.checkpoint import util as checkpoint_util
27from tensorflow.python.client import session
28from tensorflow.python.data.experimental.ops import counter
29from tensorflow.python.data.ops import dataset_ops
30from tensorflow.python.eager import backprop
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import function
34from tensorflow.python.eager import wrap_function
35from tensorflow.python.framework import constant_op
36from tensorflow.python.framework import dtypes
37from tensorflow.python.framework import errors_impl
38from tensorflow.python.framework import ops
39from tensorflow.python.framework import sparse_tensor
40from tensorflow.python.framework import tensor_spec
41from tensorflow.python.framework import test_ops
42from tensorflow.python.framework import test_util
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import lookup_ops
46from tensorflow.python.ops import map_fn
47from tensorflow.python.ops import variables
48from tensorflow.python.ops.ragged import ragged_tensor
49from tensorflow.python.platform import test
50from tensorflow.python.saved_model import load as saved_model_load
51from tensorflow.python.saved_model import save as saved_model_save
52from tensorflow.python.trackable import asset
53from tensorflow.python.trackable import autotrackable
54from tensorflow.python.training import saver
55from tensorflow.python.training import server_lib
56from tensorflow.python.util import compat
57
58
59class BaseLookupTableTest(test.TestCase):
60
61  def getHashTable(self):
62    if tf2.enabled():
63      return lookup_ops.StaticHashTable
64    else:
65      return lookup_ops.StaticHashTableV1
66
67  def getVocabularyTable(self):
68    if tf2.enabled():
69      return lookup_ops.StaticVocabularyTable
70    else:
71      return lookup_ops.StaticVocabularyTableV1
72
73  def initialize_table(self, table):
74    if not tf2.enabled():
75      self.evaluate(table.initializer)
76
77
78SKIP_ANONYMOUS_IN_TF1_REASON = (
79    "In v1 graph mode, each self.evaluate call will execute the handle "
80    "creation op (e.g. AnonymousHashTable) which will create a new table "
81    "resource unrelated to other self.evaluate calls, so we can't test "
82    "anonymous resources with self.evaluate ."
83)
84
85
86@parameterized.named_parameters(
87    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
88class StaticHashTableTest(BaseLookupTableTest, parameterized.TestCase):
89
90  def testStaticHashTable(self, is_anonymous):
91    if is_anonymous and not tf2.enabled():
92      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
93    default_val = -1
94    keys = constant_op.constant(["brain", "salad", "surgery"])
95    values = constant_op.constant([0, 1, 2], dtypes.int64)
96    table = self.getHashTable()(
97        lookup_ops.KeyValueTensorInitializer(keys, values),
98        default_val,
99        experimental_is_anonymous=is_anonymous)
100    self.assertEqual(table._is_anonymous, is_anonymous)
101    self.initialize_table(table)
102
103    self.assertAllEqual(3, self.evaluate(table.size()))
104
105    input_string = constant_op.constant(["brain", "salad", "tank"])
106    output = table.lookup(input_string)
107    self.assertAllEqual([3], output.get_shape())
108
109    result = self.evaluate(output)
110    self.assertAllEqual([0, 1, -1], result)
111
112    exported_keys_tensor, exported_values_tensor = table.export()
113
114    self.assertItemsEqual([b"brain", b"salad", b"surgery"],
115                          self.evaluate(exported_keys_tensor))
116    self.assertItemsEqual([0, 1, 2], self.evaluate(exported_values_tensor))
117
118  def testStaticHashTableFindHighRank(self, is_anonymous):
119    if is_anonymous and not tf2.enabled():
120      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
121    default_val = -1
122    keys = constant_op.constant(["brain", "salad", "surgery"])
123    values = constant_op.constant([0, 1, 2], dtypes.int64)
124    table = self.getHashTable()(
125        lookup_ops.KeyValueTensorInitializer(keys, values),
126        default_val,
127        experimental_is_anonymous=is_anonymous)
128    self.initialize_table(table)
129
130    self.assertAllEqual(3, self.evaluate(table.size()))
131
132    input_string = constant_op.constant([["brain", "salad"],
133                                         ["tank", "tarkus"]])
134    output = table.lookup(input_string)
135
136    result = self.evaluate(output)
137    self.assertAllEqual([[0, 1], [-1, -1]], result)
138
139  def testStaticHashTableInitWithPythonArrays(self, is_anonymous):
140    if is_anonymous and not tf2.enabled():
141      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
142    default_val = -1
143    keys = ["brain", "salad", "surgery"]
144    values = [0, 1, 2]
145    table = self.getHashTable()(
146        lookup_ops.KeyValueTensorInitializer(
147            keys, values, value_dtype=dtypes.int64),
148        default_val,
149        experimental_is_anonymous=is_anonymous)
150    self.initialize_table(table)
151
152    self.assertAllEqual(3, self.evaluate(table.size()))
153
154    input_string = constant_op.constant(["brain", "salad", "tank"])
155    output = table.lookup(input_string)
156
157    result = self.evaluate(output)
158    self.assertAllEqual([0, 1, -1], result)
159
160  def testStaticHashTableInitWithNumPyArrays(self, is_anonymous):
161    if is_anonymous and not tf2.enabled():
162      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
163    default_val = -1
164    keys = np.array(["brain", "salad", "surgery"], dtype=np.str_)
165    values = np.array([0, 1, 2], dtype=np.int64)
166    table = self.getHashTable()(
167        lookup_ops.KeyValueTensorInitializer(keys, values),
168        default_val,
169        experimental_is_anonymous=is_anonymous)
170    self.initialize_table(table)
171
172    self.assertAllEqual(3, self.evaluate(table.size()))
173
174    input_string = constant_op.constant(["brain", "salad", "tank"])
175    output = table.lookup(input_string)
176
177    result = self.evaluate(output)
178    self.assertAllEqual([0, 1, -1], result)
179
180  def testMultipleStaticHashTables(self, is_anonymous):
181    if is_anonymous and not tf2.enabled():
182      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
183    default_val = -1
184    keys = constant_op.constant(["brain", "salad", "surgery"])
185    values = constant_op.constant([0, 1, 2], dtypes.int64)
186
187    table1 = self.getHashTable()(
188        lookup_ops.KeyValueTensorInitializer(keys, values),
189        default_val,
190        experimental_is_anonymous=is_anonymous)
191    table2 = self.getHashTable()(
192        lookup_ops.KeyValueTensorInitializer(keys, values),
193        default_val,
194        experimental_is_anonymous=is_anonymous)
195    table3 = self.getHashTable()(
196        lookup_ops.KeyValueTensorInitializer(keys, values),
197        default_val,
198        experimental_is_anonymous=is_anonymous)
199
200    self.initialize_table(table1)
201    self.initialize_table(table2)
202    self.initialize_table(table3)
203    self.assertAllEqual(3, self.evaluate(table1.size()))
204    self.assertAllEqual(3, self.evaluate(table2.size()))
205    self.assertAllEqual(3, self.evaluate(table3.size()))
206
207    input_string = constant_op.constant(["brain", "salad", "tank"])
208    output1 = table1.lookup(input_string)
209    output2 = table2.lookup(input_string)
210    output3 = table3.lookup(input_string)
211
212    out1, out2, out3 = self.evaluate([output1, output2, output3])
213    self.assertAllEqual([0, 1, -1], out1)
214    self.assertAllEqual([0, 1, -1], out2)
215    self.assertAllEqual([0, 1, -1], out3)
216
217  def testStaticHashTableWithTensorDefault(self, is_anonymous):
218    if is_anonymous and not tf2.enabled():
219      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
220    default_val = constant_op.constant(-1, dtypes.int64)
221    keys = constant_op.constant(["brain", "salad", "surgery"])
222    values = constant_op.constant([0, 1, 2], dtypes.int64)
223    table = self.getHashTable()(
224        lookup_ops.KeyValueTensorInitializer(keys, values),
225        default_val,
226        experimental_is_anonymous=is_anonymous)
227    self.initialize_table(table)
228
229    input_string = constant_op.constant(["brain", "salad", "tank"])
230    output = table.lookup(input_string)
231
232    result = self.evaluate(output)
233    self.assertAllEqual([0, 1, -1], result)
234
235  def testStaticHashTableGetItem(self, is_anonymous):
236    if is_anonymous and not tf2.enabled():
237      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
238    default_val = constant_op.constant(-1, dtypes.int64)
239    keys = constant_op.constant(["brain", "salad", "surgery"])
240    values = constant_op.constant([0, 1, 2], dtypes.int64)
241    table = self.getHashTable()(
242        lookup_ops.KeyValueTensorInitializer(keys, values),
243        default_val,
244        experimental_is_anonymous=is_anonymous)
245    self.initialize_table(table)
246
247    input_string = constant_op.constant(["brain", "salad", "tank"])
248    output = table[input_string]
249
250    result = self.evaluate(output)
251    self.assertAllEqual([0, 1, -1], result)
252
253  def testStaticHashTableWithSparseTensorInput(self, is_anonymous):
254    if is_anonymous and not tf2.enabled():
255      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
256    default_val = constant_op.constant(-1, dtypes.int64)
257    keys = constant_op.constant(["brain", "salad", "surgery"])
258    values = constant_op.constant([0, 1, 2], dtypes.int64)
259    table = self.getHashTable()(
260        lookup_ops.KeyValueTensorInitializer(keys, values),
261        default_val,
262        experimental_is_anonymous=is_anonymous)
263    self.initialize_table(table)
264
265    sp_indices = [[0, 0], [0, 1], [1, 0]]
266    sp_shape = [2, 2]
267    input_tensor = sparse_tensor.SparseTensor(
268        constant_op.constant(sp_indices, dtypes.int64),
269        constant_op.constant(["brain", "salad", "tank"]),
270        constant_op.constant(sp_shape, dtypes.int64))
271    output = table.lookup(input_tensor)
272
273    out_indices, out_values, out_shape = self.evaluate(output)
274
275    self.assertAllEqual([0, 1, -1], out_values)
276    self.assertAllEqual(sp_indices, out_indices)
277    self.assertAllEqual(sp_shape, out_shape)
278
279  def testStaticHashTableWithRaggedTensorInput(self, is_anonymous):
280    if is_anonymous and not tf2.enabled():
281      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
282    default_val = constant_op.constant(-1, dtypes.int64)
283    keys = constant_op.constant(["brain", "salad", "surgery"])
284    values = constant_op.constant([0, 1, 2], dtypes.int64)
285    table = self.getHashTable()(
286        lookup_ops.KeyValueTensorInitializer(keys, values),
287        default_val,
288        experimental_is_anonymous=is_anonymous)
289    self.initialize_table(table)
290
291    row_splits = [0, 2, 3]
292    input_tensor = ragged_tensor.RaggedTensor.from_row_splits(
293        constant_op.constant(["brain", "salad", "tank"]),
294        constant_op.constant(row_splits, dtypes.int64))
295    output = table.lookup(input_tensor)
296
297    out = self.evaluate(output)
298
299    self.assertAllEqual([0, 1, -1], out.values)
300    self.assertAllEqual(row_splits, out.row_splits)
301
302  def testSignatureMismatch(self, is_anonymous):
303    if is_anonymous and not tf2.enabled():
304      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
305    default_val = -1
306    keys = constant_op.constant(["brain", "salad", "surgery"])
307    values = constant_op.constant([0, 1, 2], dtypes.int64)
308    table = self.getHashTable()(
309        lookup_ops.KeyValueTensorInitializer(keys, values),
310        default_val,
311        experimental_is_anonymous=is_anonymous)
312    self.initialize_table(table)
313
314    # Ref types do not produce a lookup signature mismatch.
315    input_string_ref = variables.Variable("brain")
316    self.evaluate(input_string_ref.initializer)
317    self.assertEqual(0, self.evaluate(table.lookup(input_string_ref)))
318
319    input_string = constant_op.constant([1, 2, 3], dtypes.int64)
320    with self.assertRaises(TypeError):
321      table.lookup(input_string)
322
323    with self.assertRaises(TypeError):
324      self.getHashTable()(
325          lookup_ops.KeyValueTensorInitializer(keys, values),
326          "UNK",
327          experimental_is_anonymous=is_anonymous)
328
329  def testDTypes(self, is_anonymous):
330    default_val = -1
331    with self.assertRaises(TypeError):
332      self.getHashTable()(
333          lookup_ops.KeyValueTensorInitializer(["a"], [1], [dtypes.string],
334                                               dtypes.int64),
335          default_val,
336          experimental_is_anonymous=is_anonymous)
337
338  @test_util.run_v1_only("(Cached) Sessions not available in TF2.0")
339  def testNotInitialized(self, is_anonymous):
340    with self.cached_session():
341      default_val = -1
342      table = self.getHashTable()(
343          lookup_ops.KeyValueTensorInitializer(["a"], [1],
344                                               value_dtype=dtypes.int64),
345          default_val,
346          experimental_is_anonymous=is_anonymous)
347
348      input_string = constant_op.constant(["brain", "salad", "surgery"])
349      output = table.lookup(input_string)
350
351      with self.assertRaisesOpError("Table not initialized"):
352        self.evaluate(output)
353
354  @test_util.run_v1_only("(Cached) Sessions not available in TF2.0")
355  def testInitializeTwice(self, is_anonymous):
356    with self.cached_session():
357      default_val = -1
358      keys = constant_op.constant(["brain", "salad", "surgery"])
359      values = constant_op.constant([0, 1, 2], dtypes.int64)
360      table = self.getHashTable()(
361          lookup_ops.KeyValueTensorInitializer(keys, values),
362          default_val,
363          experimental_is_anonymous=is_anonymous)
364      self.initialize_table(table)
365      # Make sure that initializing twice doesn't throw any errors.
366      self.initialize_table(table)
367
368  def testInitializationWithInvalidDimensions(self, is_anonymous):
369    default_val = -1
370    keys = constant_op.constant(["brain", "salad", "surgery"])
371    values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
372
373    raised_error = ValueError
374    if context.executing_eagerly():
375      raised_error = errors_impl.InvalidArgumentError
376    with self.assertRaises(raised_error):
377      self.getHashTable()(
378          lookup_ops.KeyValueTensorInitializer(keys, values),
379          default_val,
380          experimental_is_anonymous=is_anonymous)
381
382  @test_util.run_v1_only("Sessions not available in TF2.0")
383  def testMultipleSessions(self, is_anonymous):
384    if is_anonymous and not tf2.enabled():
385      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
386    # Start a server
387    server = server_lib.Server({"local0": ["localhost:0"]},
388                               protocol="grpc",
389                               start=True)
390    # Create two sessions sharing the same state
391    session1 = session.Session(server.target)
392    session2 = session.Session(server.target)
393
394    default_val = -1
395    keys = constant_op.constant(["brain", "salad", "surgery"])
396    values = constant_op.constant([0, 1, 2], dtypes.int64)
397    table = self.getHashTable()(
398        lookup_ops.KeyValueTensorInitializer(keys, values),
399        default_val,
400        name="t1",
401        experimental_is_anonymous=is_anonymous)
402
403    # Init the table in the first session.
404    with session1:
405      self.initialize_table(table)
406      self.assertAllEqual(3, self.evaluate(table.size()))
407
408    # Init the table in the second session and verify that we do not get a
409    # "Table already initialized" error.
410    with session2:
411      self.evaluate(table.initializer)
412      self.assertAllEqual(3, self.evaluate(table.size()))
413
414  @test_util.run_v2_only
415  def testImportedHashTable(self, is_anonymous):
416    g = ops.Graph()
417    with g.as_default():
418      t = lookup_ops.StaticHashTable(
419          lookup_ops.KeyValueTensorInitializer(["a"], [1]),
420          2)
421      init_op = t._init_op
422      op = t.lookup(ops.convert_to_tensor(["a"]))
423      meta_graph = saver.export_meta_graph()
424
425    def f():
426      saver.import_meta_graph(meta_graph)
427      return ops.get_default_graph().get_tensor_by_name(op.name)
428
429    wrapped = wrap_function.wrap_function(f, [])
430    pruned_init_fn = wrapped.prune(
431        (), [wrapped.graph.get_operation_by_name(init_op.name)])
432    self.evaluate(pruned_init_fn())
433    self.assertAllEqual([1], wrapped())
434
435  def testStaticHashTableInt32String(self, is_anonymous):
436    if is_anonymous and not tf2.enabled():
437      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
438    default_val = "n/a"
439    keys = constant_op.constant([0, 1, 2], dtypes.int32)
440    values = constant_op.constant(["brain", "salad", "surgery"])
441    table = self.getHashTable()(
442        lookup_ops.KeyValueTensorInitializer(keys, values),
443        default_val,
444        experimental_is_anonymous=is_anonymous)
445    self.initialize_table(table)
446
447    input_tensor = constant_op.constant([0, 1, -1])
448    output = table.lookup(input_tensor)
449
450    result = self.evaluate(output)
451    self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
452
453  def testTableUseInFunction(self, is_anonymous):
454    if not context.executing_eagerly():
455      self.skipTest("Only Eager mode test.")
456    keys = constant_op.constant([0, 1, 2], dtypes.int32)
457    values = constant_op.constant(["brain", "salad", "surgery"])
458    table = self.getHashTable()(
459        lookup_ops.KeyValueTensorInitializer(keys, values),
460        "n/a",
461        experimental_is_anonymous=is_anonymous)
462
463    @function.defun()
464    def lookup_table_func(k):
465      return table.lookup(k)
466
467    result = lookup_table_func(constant_op.constant([0, 1, -1]))
468    self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
469    result = lookup_table_func(constant_op.constant([2, -1, 1]))
470    self.assertAllEqual([b"surgery", b"n/a", b"salad"], result)
471
472  def testTableCreatedInFunction(self, is_anonymous):
473    if not context.executing_eagerly():
474      self.skipTest("Only Eager mode test.")
475    keys = constant_op.constant([0, 1, 2], dtypes.int32)
476    values = constant_op.constant(["brain", "salad", "surgery"])
477
478    @function.defun()
479    def lookup_table_func(k):
480      table = self.getHashTable()(
481          lookup_ops.KeyValueTensorInitializer(keys, values),
482          "n/a",
483          experimental_is_anonymous=is_anonymous)
484      return table.lookup(k)
485
486    result = lookup_table_func(constant_op.constant([0, 1, -1]))
487    self.assertAllEqual([b"brain", b"salad", b"n/a"], result)
488    result = lookup_table_func(constant_op.constant([2, -1, 1]))
489    self.assertAllEqual([b"surgery", b"n/a", b"salad"], result)
490
491  def testTwoTablesInControlFlow(self, is_anonymous):
492    if is_anonymous and not tf2.enabled():
493      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
494    keys = constant_op.constant([1, 2, 3], dtypes.int32)
495    values = constant_op.constant([5, 10, 15], dtypes.int32)
496
497    def table_func1(x):
498      table = self.getHashTable()(
499          lookup_ops.KeyValueTensorInitializer(keys, values),
500          -1,
501          experimental_is_anonymous=is_anonymous)
502      return table.lookup(x)
503
504    elems = np.array([2, 4, 1], dtype=np.int32)
505    result1 = map_fn.map_fn(table_func1, elems, dtype=dtypes.int32)
506
507    def table_func2(x):
508      table = self.getHashTable()(
509          lookup_ops.KeyValueTensorInitializer(keys, values),
510          -1,
511          experimental_is_anonymous=is_anonymous)
512      return table.lookup(x)
513
514    elems = np.array([2, 4, 1], dtype=np.int32)
515    result2 = map_fn.map_fn(table_func2, elems, dtype=dtypes.int32)
516
517    self.evaluate(lookup_ops.tables_initializer())
518
519    self.assertAllEqual([10, -1, 5], self.evaluate(result1))
520    self.assertAllEqual([10, -1, 5], self.evaluate(result2))
521
522  @test_util.enable_control_flow_v2
523  def testLookupTableInWhileV2(self, is_anonymous):
524    lookup = self.getHashTable()(
525        lookup_ops.KeyValueTensorInitializer(
526            constant_op.constant([2, 5], dtype=dtypes.int64),
527            constant_op.constant([-10.0, 1], dtype=dtypes.float32)),
528        -1,
529        experimental_is_anonymous=is_anonymous)
530
531    beta = variables.Variable(1.0, trainable=True)
532
533    @def_function.function
534    def get_loss(unused_beta):
535      return map_fn.map_fn(
536          lookup.lookup,
537          constant_op.constant([2, 3], dtype=dtypes.int64),
538          dtype=dtypes.float32)
539
540    with backprop.GradientTape() as tape:
541      loss = get_loss(beta)
542
543    self.assertIsNone(tape.gradient(loss, beta))
544
545  @test_util.enable_control_flow_v2
546  def testLookupTableInCondV2(self, is_anonymous):
547    if is_anonymous and not tf2.enabled():
548      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
549    lookup = self.getHashTable()(
550        lookup_ops.KeyValueTensorInitializer(
551            constant_op.constant([2, 5], dtype=dtypes.int64),
552            constant_op.constant([-10.0, 1], dtype=dtypes.float32)),
553        -1,
554        experimental_is_anonymous=is_anonymous)
555
556    beta = variables.Variable(1.0, trainable=True)
557
558    @def_function.function
559    def get_loss(beta):
560
561      def true_fn():
562        return lookup.lookup(constant_op.constant(2, dtype=dtypes.int64))
563
564      def false_fn():
565        return constant_op.constant(0, dtype=dtypes.float32)
566
567      return beta * control_flow_ops.cond(
568          constant_op.constant(True), true_fn=true_fn, false_fn=false_fn)
569
570    with backprop.GradientTape() as tape:
571      loss = get_loss(beta)
572    grad = tape.gradient(loss, beta)
573    self.evaluate(variables.global_variables_initializer())
574    self.evaluate(lookup_ops.tables_initializer())
575    self.assertAllEqual(grad, -10.)
576
577  def testExportShapeInference(self, is_anonymous):
578    table = self.getHashTable()(
579        lookup_ops.KeyValueTensorInitializer(
580            constant_op.constant([2, 5], dtype=dtypes.int64),
581            constant_op.constant([-10.0, 1], dtype=dtypes.float32)),
582        -1,
583        experimental_is_anonymous=is_anonymous)
584    actual_shapes = [t.shape for t in table.export()]
585    inferred_shapes = []
586
587    @def_function.function
588    def f():
589      for t in table.export():
590        inferred_shapes.append(t.shape)
591
592    f()
593    self.assertLen(actual_shapes, 2)
594    self.assertLen(inferred_shapes, 2)
595    self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0]))
596    self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1]))
597
598  @test_util.run_v2_only
599  def testSavedModelSaveRestore(self, is_anonymous):
600    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
601    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
602
603    root = autotrackable.AutoTrackable()
604
605    default_value = -1
606    keys = constant_op.constant([11, 12, 13], dtypes.int64)
607    values = constant_op.constant([0, 1, 2], dtypes.int64)
608    root.table = self.getHashTable()(
609        lookup_ops.KeyValueTensorInitializer(keys, values),
610        default_value,
611        experimental_is_anonymous=is_anonymous)
612
613    @def_function.function(
614        input_signature=[tensor_spec.TensorSpec((), dtypes.int64)])
615    def lookup(key):
616      return root.table.lookup(key)
617
618    @def_function.function(input_signature=[])
619    def size():
620      return root.table.size()
621
622    @def_function.function(input_signature=[])
623    def is_ref_counting():
624      return test_ops.is_resource_handle_ref_counting(
625          root.table.resource_handle)
626
627    root.lookup = lookup
628    root.size = size
629    root.is_ref_counting = is_ref_counting
630
631    self.assertEqual(root.table.size(), 3)
632    self.assertEqual(root.lookup(12), 1)
633    self.assertEqual(root.lookup(10), -1)
634    self.assertLen(root.table.export()[0], 3)
635    self.assertEqual(root.is_ref_counting(), is_anonymous)
636
637    saved_model_save.save(root, save_path)
638
639    del root
640    loaded = saved_model_load.load(save_path)
641    self.assertEqual(loaded.size(), 3)
642    self.assertEqual(loaded.lookup(12), 1)
643    self.assertEqual(loaded.lookup(10), -1)
644    self.assertEqual(loaded.is_ref_counting(), is_anonymous)
645
646
647@parameterized.named_parameters(
648    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
649class KeyValueTensorInitializerTest(BaseLookupTableTest):
650
651  def test_string(self, is_anonymous):
652    init = lookup_ops.KeyValueTensorInitializer(
653        ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
654    table = self.getHashTable()(
655        init, default_value=-1, experimental_is_anonymous=is_anonymous)
656    self.initialize_table(table)
657
658  def test_multiple_tables(self, is_anonymous):
659    with ops.name_scope("table_scope"):
660      init1 = lookup_ops.KeyValueTensorInitializer(
661          ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
662      table1 = self.getHashTable()(
663          init1, default_value=-1, experimental_is_anonymous=is_anonymous)
664      if not context.executing_eagerly():
665        self.assertEqual("hash_table", table1.name)
666        self.assertEqual("table_scope/hash_table",
667                         table1.resource_handle.op.name)
668      init2 = lookup_ops.KeyValueTensorInitializer(
669          ("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
670      table2 = self.getHashTable()(
671          init2, default_value=-1, experimental_is_anonymous=is_anonymous)
672      if not context.executing_eagerly():
673        self.assertEqual("hash_table_1", table2.name)
674        self.assertEqual("table_scope/hash_table_1",
675                         table2.resource_handle.op.name)
676
677  def test_int64(self, is_anonymous):
678    init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
679                                                dtypes.int64, dtypes.int64)
680    table = self.getHashTable()(
681        init, default_value=-1, experimental_is_anonymous=is_anonymous)
682    self.initialize_table(table)
683
684  def test_int32(self, is_anonymous):
685    init = lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
686                                                dtypes.int32, dtypes.int64)
687    with self.assertRaises(errors_impl.OpError):
688      table = self.getHashTable()(
689          init, default_value=-1, experimental_is_anonymous=is_anonymous)
690      self.initialize_table(table)
691
692
693@parameterized.named_parameters(
694    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
695class InitializeTableFromFileOpTest(BaseLookupTableTest):
696
697  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
698    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
699    with open(vocabulary_file, "w") as f:
700      f.write("\n".join(values) + "\n")
701    return vocabulary_file
702
703  def testInitializeStringTable(self, is_anonymous):
704    if is_anonymous and not tf2.enabled():
705      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
706    vocabulary_file = self._createVocabFile("one_column_1.txt")
707    default_value = -1
708    init = lookup_ops.TextFileInitializer(
709        vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
710        dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
711    self.assertIn("one_column_1.txt_-2_-1", init._shared_name)
712    table = self.getHashTable()(
713        init, default_value, experimental_is_anonymous=is_anonymous)
714    self.initialize_table(table)
715
716    output = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
717
718    result = self.evaluate(output)
719    self.assertAllEqual([0, 1, -1], result)
720
721  def testInitializeInt64Table(self, is_anonymous):
722    if is_anonymous and not tf2.enabled():
723      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
724    vocabulary_file = self._createVocabFile(
725        "one_column_int64.txt", values=("42", "1", "-1000"))
726
727    with self.cached_session():
728      default_value = -1
729      init = lookup_ops.TextFileInitializer(
730          vocabulary_file, dtypes.int64, lookup_ops.TextFileIndex.WHOLE_LINE,
731          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
732      self.assertIn("one_column_int64.txt_-2_-1", init._shared_name)
733      table = self.getHashTable()(
734          init, default_value, experimental_is_anonymous=is_anonymous)
735      self.initialize_table(table)
736
737      output = table.lookup(
738          constant_op.constant((42, 1, 11), dtype=dtypes.int64))
739
740      result = self.evaluate(output)
741      self.assertAllEqual([0, 1, -1], result)
742
743  def testInitializeIndexTable(self, is_anonymous):
744    if is_anonymous and not tf2.enabled():
745      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
746    vocabulary_file = self._createVocabFile("one_column_2.txt")
747
748    with self.cached_session():
749      default_value = "UNK"
750      key_index = lookup_ops.TextFileIndex.LINE_NUMBER
751      value_index = lookup_ops.TextFileIndex.WHOLE_LINE
752      init = lookup_ops.TextFileInitializer(
753          vocabulary_file, dtypes.int64, key_index, dtypes.string, value_index)
754      self.assertIn("one_column_2.txt_-1_-2", init._shared_name)
755      table = self.getHashTable()(
756          init, default_value, experimental_is_anonymous=is_anonymous)
757      self.initialize_table(table)
758
759      input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
760      output = table.lookup(input_values)
761
762      result = self.evaluate(output)
763      self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"], result)
764
765  def testMultiColumn(self, is_anonymous):
766    if is_anonymous and not tf2.enabled():
767      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
768    vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
769    with open(vocabulary_file, "w") as f:
770      f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
771
772    with self.cached_session():
773      default_value = -1
774      key_index = 1
775      value_index = 2
776
777      init = lookup_ops.TextFileInitializer(
778          vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
779      self.assertIn("three_columns.txt_1_2", init._shared_name)
780      table = self.getHashTable()(
781          init, default_value, experimental_is_anonymous=is_anonymous)
782      self.initialize_table(table)
783
784      input_string = constant_op.constant(["brain", "salad", "surgery"])
785      output = table.lookup(input_string)
786
787      result = self.evaluate(output)
788      self.assertAllEqual([1, 5, 6], result)
789
790  def testInvalidDataTypeInMultiColumn(self, is_anonymous):
791    vocabulary_file = os.path.join(self.get_temp_dir(), "three_columns.txt")
792    with open(vocabulary_file, "w") as f:
793      f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
794
795    with self.cached_session():
796      default_value = -1
797      key_index = 2
798      value_index = 1
799      init = lookup_ops.TextFileInitializer(
800          vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
801      self.assertIn("three_columns.txt_2_1", init._shared_name)
802      with self.assertRaisesOpError("is not a valid"):
803        table = self.getHashTable()(
804            init, default_value, experimental_is_anonymous=is_anonymous)
805        self.initialize_table(table)
806
807  def testInvalidDataType(self, is_anonymous):
808    vocabulary_file = self._createVocabFile("one_column_3.txt")
809
810    with self.cached_session():
811      default_value = "UNK"
812      key_index = lookup_ops.TextFileIndex.WHOLE_LINE
813      value_index = lookup_ops.TextFileIndex.LINE_NUMBER
814
815      with self.assertRaises(ValueError):
816        init = lookup_ops.TextFileInitializer(vocabulary_file, dtypes.int64,
817                                              key_index, dtypes.string,
818                                              value_index)
819        self.assertIn("one_column_3.txt_-2_-1", init._shared_name)
820        self.getHashTable()(
821            init, default_value, experimental_is_anonymous=is_anonymous)
822
823  def testInvalidIndex(self, is_anonymous):
824    vocabulary_file = self._createVocabFile("one_column_4.txt")
825    with self.cached_session():
826      default_value = -1
827      key_index = 1  # second column of the line
828      value_index = lookup_ops.TextFileIndex.LINE_NUMBER
829      init = lookup_ops.TextFileInitializer(
830          vocabulary_file, dtypes.string, key_index, dtypes.int64, value_index)
831      self.assertIn("one_column_4.txt_1_-1", init._shared_name)
832
833      with self.assertRaisesOpError("Invalid number of columns"):
834        table = self.getHashTable()(
835            init, default_value, experimental_is_anonymous=is_anonymous)
836        self.initialize_table(table)
837
838  def testInitializeSameTableWithMultipleNodes(self, is_anonymous):
839    if is_anonymous and not tf2.enabled():
840      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
841    vocabulary_file = self._createVocabFile("one_column_5.txt")
842
843    with self.cached_session():
844      default_value = -1
845      init1 = lookup_ops.TextFileInitializer(
846          vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
847          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
848      self.assertIn("one_column_5.txt_-2_-1", init1._shared_name)
849      table1 = self.getHashTable()(
850          init1, default_value, experimental_is_anonymous=is_anonymous)
851      init2 = lookup_ops.TextFileInitializer(
852          vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
853          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
854      self.assertIn("one_column_5.txt_-2_-1", init2._shared_name)
855      table2 = self.getHashTable()(
856          init2, default_value, experimental_is_anonymous=is_anonymous)
857      init3 = lookup_ops.TextFileInitializer(
858          vocabulary_file, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
859          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
860      self.assertIn("one_column_5.txt_-2_-1", init3._shared_name)
861      table3 = self.getHashTable()(
862          init3, default_value, experimental_is_anonymous=is_anonymous)
863
864      self.evaluate(lookup_ops.tables_initializer())
865
866      input_string = constant_op.constant(["brain", "salad", "tank"])
867
868      output1 = table1.lookup(input_string)
869      output2 = table2.lookup(input_string)
870      output3 = table3.lookup(input_string)
871
872      out1, out2, out3 = self.evaluate([output1, output2, output3])
873      self.assertAllEqual([0, 1, -1], out1)
874      self.assertAllEqual([0, 1, -1], out2)
875      self.assertAllEqual([0, 1, -1], out3)
876
877  def testInitializeTableWithNoFilename(self, is_anonymous):
878    with self.cached_session():
879      default_value = -1
880      with self.assertRaises(ValueError):
881        self.getHashTable()(
882            lookup_ops.TextFileInitializer(
883                "", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
884                dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
885            default_value,
886            experimental_is_anonymous=is_anonymous)
887
888  def testInitializeWithVocabSize(self, is_anonymous):
889    if is_anonymous and not tf2.enabled():
890      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
891    with self.cached_session():
892      default_value = -1
893      vocab_size = 3
894      vocabulary_file1 = self._createVocabFile("one_column6.txt")
895      init1 = lookup_ops.TextFileInitializer(
896          vocabulary_file1,
897          dtypes.string,
898          lookup_ops.TextFileIndex.WHOLE_LINE,
899          dtypes.int64,
900          lookup_ops.TextFileIndex.LINE_NUMBER,
901          vocab_size=vocab_size)
902      self.assertIn("one_column6.txt_3_-2_-1", init1._shared_name)
903      table1 = self.getHashTable()(
904          init1, default_value, experimental_is_anonymous=is_anonymous)
905
906      # Initialize from file.
907      self.initialize_table(table1)
908      self.assertEqual(vocab_size, self.evaluate(table1.size()))
909
910      vocabulary_file2 = self._createVocabFile("one_column7.txt")
911      vocab_size = 5
912      init2 = lookup_ops.TextFileInitializer(
913          vocabulary_file2,
914          dtypes.string,
915          lookup_ops.TextFileIndex.WHOLE_LINE,
916          dtypes.int64,
917          lookup_ops.TextFileIndex.LINE_NUMBER,
918          vocab_size=vocab_size)
919      self.assertIn("one_column7.txt_5_-2_-1", init2._shared_name)
920      with self.assertRaisesOpError("Invalid vocab_size"):
921        table2 = self.getHashTable()(
922            init2, default_value, experimental_is_anonymous=is_anonymous)
923        self.initialize_table(table2)
924
925      vocab_size = 1
926      vocabulary_file3 = self._createVocabFile("one_column3.txt")
927      init3 = lookup_ops.TextFileInitializer(
928          vocabulary_file3,
929          dtypes.string,
930          lookup_ops.TextFileIndex.WHOLE_LINE,
931          dtypes.int64,
932          lookup_ops.TextFileIndex.LINE_NUMBER,
933          vocab_size=vocab_size)
934      self.assertIn("one_column3.txt_1_-2_-1", init3._shared_name)
935      table3 = self.getHashTable()(
936          init3, default_value, experimental_is_anonymous=is_anonymous)
937
938      # Smaller vocab size reads only vocab_size records.
939      self.initialize_table(table3)
940      self.assertEqual(vocab_size, self.evaluate(table3.size()))
941
942  @test_util.run_v1_only("placeholder usage")
943  def testFeedVocabularyName(self, is_anonymous):
944    if is_anonymous and not tf2.enabled():
945      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
946    vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
947
948    with self.cached_session():
949      default_value = -1
950      init = lookup_ops.TextFileInitializer(
951          "old_file.txt", dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
952          dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER)
953      self.assertIn("old_file.txt_-2_-1", init._shared_name)
954      table = self.getHashTable()(
955          init, default_value, experimental_is_anonymous=is_anonymous)
956
957      # Initialize with non existing file (old_file.txt) should fail.
958      # TODO(yleon): Update message, which might change per FileSystem.
959      with self.assertRaisesOpError("old_file.txt"):
960        self.evaluate(table.initializer)
961
962      # Initialize the model feeding the vocabulary file.
963      filenames = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
964      table.initializer.run(feed_dict={filenames[0]: vocabulary_file})
965
966      input_string = constant_op.constant(["brain", "salad", "tank"])
967      output = table.lookup(input_string)
968
969      result = self.evaluate(output)
970      self.assertAllEqual([0, 1, -1], result)
971
972  def testInvalidFilenames(self, is_anonymous):
973    vocabulary_file = self._createVocabFile("filename_shape.txt")
974
975    with self.cached_session():
976      default_value = -1
977
978      # Invalid data type
979      other_type = constant_op.constant(1)
980      with self.assertRaises(Exception) as cm:
981        self.getHashTable()(
982            lookup_ops.TextFileInitializer(
983                other_type, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
984                dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
985            default_value,
986            experimental_is_anonymous=is_anonymous)
987      self.assertIsInstance(cm.exception, (ValueError, TypeError))
988
989      # Non-scalar filename
990      filenames = constant_op.constant([vocabulary_file, vocabulary_file])
991      if not context.executing_eagerly():
992        with self.assertRaises(Exception) as cm:
993          self.getHashTable()(
994              lookup_ops.TextFileInitializer(
995                  filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
996                  dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
997              default_value,
998              experimental_is_anonymous=is_anonymous)
999        self.assertIsInstance(cm.exception, (ValueError, TypeError))
1000      else:
1001        with self.assertRaises(errors_impl.InvalidArgumentError):
1002          self.getHashTable()(
1003              lookup_ops.TextFileInitializer(
1004                  filenames, dtypes.string, lookup_ops.TextFileIndex.WHOLE_LINE,
1005                  dtypes.int64, lookup_ops.TextFileIndex.LINE_NUMBER),
1006              default_value,
1007              experimental_is_anonymous=is_anonymous)
1008
1009  def testIdToStringTable(self, is_anonymous):
1010    if is_anonymous and not tf2.enabled():
1011      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1012    vocab_file = self._createVocabFile("feat_to_id_1.txt")
1013    with self.cached_session():
1014      default_value = "UNK"
1015      vocab_size = 3
1016      init = lookup_ops.TextFileStringTableInitializer(
1017          vocab_file, vocab_size=vocab_size)
1018      self.assertTrue("feat_to_id_1.txt_3_-1_-2", init._shared_name)
1019      table = self.getHashTable()(
1020          init, default_value, experimental_is_anonymous=is_anonymous)
1021
1022      self.initialize_table(table)
1023
1024      input_values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
1025
1026      out = table.lookup(input_values)
1027      self.assertAllEqual([b"brain", b"salad", b"surgery", b"UNK"],
1028                          self.evaluate(out))
1029      self.assertEqual(vocab_size, self.evaluate(table.size()))
1030
1031  def testStringToIdTable(self, is_anonymous):
1032    if is_anonymous and not tf2.enabled():
1033      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1034    vocab_file = self._createVocabFile("feat_to_id_2.txt")
1035    with self.cached_session():
1036      default_value = -1
1037      vocab_size = 3
1038      init = lookup_ops.TextFileIdTableInitializer(
1039          vocab_file, vocab_size=vocab_size)
1040      self.assertTrue("feat_to_id_2.txt_3_-1_-2", init._shared_name)
1041      table = self.getHashTable()(
1042          init, default_value, experimental_is_anonymous=is_anonymous)
1043      self.initialize_table(table)
1044
1045      input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
1046
1047      out = table.lookup(input_string)
1048      self.assertAllEqual([0, 1, 2, -1], self.evaluate(out))
1049      self.assertEqual(vocab_size, self.evaluate(table.size()))
1050
1051  def testInt64ToIdTable(self, is_anonymous):
1052    if is_anonymous and not tf2.enabled():
1053      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1054    vocab_file = self._createVocabFile(
1055        "feat_to_id_3.txt", values=("42", "1", "-1000"))
1056    with self.cached_session():
1057      default_value = -1
1058      vocab_size = 3
1059      init = lookup_ops.TextFileIdTableInitializer(
1060          vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64)
1061      self.assertTrue("feat_to_id_3.txt_3_-1_-2", init._shared_name)
1062      table = self.getHashTable()(
1063          init, default_value, experimental_is_anonymous=is_anonymous)
1064      self.initialize_table(table)
1065
1066      out = table.lookup(
1067          constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64))
1068      self.assertAllEqual((0, 1, 2, -1), self.evaluate(out))
1069      self.assertEqual(vocab_size, self.evaluate(table.size()))
1070
1071
1072@parameterized.named_parameters(
1073    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
1074class StaticVocabularyTableTest(BaseLookupTableTest):
1075
1076  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
1077    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
1078    with open(vocabulary_file, "w") as f:
1079      f.write("\n".join(values) + "\n")
1080    return vocabulary_file
1081
1082  def testStringStaticVocabularyTable(self, is_anonymous):
1083    if is_anonymous and not tf2.enabled():
1084      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1085    vocab_file = self._createVocabFile("feat_to_id_1.txt")
1086    vocab_size = 3
1087    oov_buckets = 1
1088    table = self.getVocabularyTable()(
1089        lookup_ops.TextFileIdTableInitializer(
1090            vocab_file, vocab_size=vocab_size),
1091        oov_buckets,
1092        experimental_is_anonymous=is_anonymous)
1093
1094    self.initialize_table(table)
1095
1096    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
1097
1098    out = table.lookup(input_string)
1099    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
1100    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
1101
1102  def testStaticVocabularyTableGetItem(self, is_anonymous):
1103    if is_anonymous and not tf2.enabled():
1104      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1105    vocab_file = self._createVocabFile("feat_to_id_1.txt")
1106    vocab_size = 3
1107    oov_buckets = 1
1108    table = self.getVocabularyTable()(
1109        lookup_ops.TextFileIdTableInitializer(
1110            vocab_file, vocab_size=vocab_size),
1111        oov_buckets,
1112        experimental_is_anonymous=is_anonymous)
1113
1114    self.initialize_table(table)
1115
1116    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
1117
1118    out = table[input_string]
1119    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
1120    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
1121
1122  def testInt32StaticVocabularyTable(self, is_anonymous):
1123    if is_anonymous and not tf2.enabled():
1124      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1125    vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
1126    vocab_size = 3
1127    oov_buckets = 1
1128    table = self.getVocabularyTable()(
1129        lookup_ops.TextFileIdTableInitializer(
1130            vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
1131        oov_buckets,
1132        lookup_key_dtype=dtypes.int32,
1133        experimental_is_anonymous=is_anonymous)
1134
1135    self.initialize_table(table)
1136
1137    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
1138
1139    out = table.lookup(values)
1140    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
1141    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
1142
1143  def testInt64StaticVocabularyTable(self, is_anonymous):
1144    if is_anonymous and not tf2.enabled():
1145      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1146    vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
1147    vocab_size = 3
1148    oov_buckets = 1
1149    table = self.getVocabularyTable()(
1150        lookup_ops.TextFileIdTableInitializer(
1151            vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
1152        oov_buckets,
1153        experimental_is_anonymous=is_anonymous)
1154
1155    self.initialize_table(table)
1156
1157    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
1158
1159    out = table.lookup(values)
1160    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
1161    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
1162
1163  def testStringStaticVocabularyTableNoInitializer(self, is_anonymous):
1164    oov_buckets = 5
1165
1166    # Set a table that only uses hash buckets, for each input value returns
1167    # an id calculated by fingerprint("input") mod oov_buckets.
1168    table = self.getVocabularyTable()(
1169        None, oov_buckets, experimental_is_anonymous=is_anonymous)
1170    self.initialize_table(table)
1171
1172    values = constant_op.constant(("brain", "salad", "surgery"))
1173
1174    out = table.lookup(values)
1175    self.assertAllEqual(
1176        [
1177            3,  # fingerprint("brain") mod 5.
1178            1,  # fingerprint("salad") mod 5.
1179            4  # fingerprint("surgery") mod 5
1180        ],
1181        self.evaluate(out))
1182    self.assertEqual(oov_buckets, self.evaluate(table.size()))
1183
1184  def testStaticVocabularyTableWithMultipleInitializers(self, is_anonymous):
1185    if is_anonymous and not tf2.enabled():
1186      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1187    vocab_file = self._createVocabFile("feat_to_id_4.txt")
1188    vocab_size = 3
1189    oov_buckets = 3
1190
1191    init = lookup_ops.TextFileIdTableInitializer(
1192        vocab_file, vocab_size=vocab_size)
1193    table1 = self.getVocabularyTable()(
1194        init,
1195        oov_buckets,
1196        name="table1",
1197        experimental_is_anonymous=is_anonymous)
1198
1199    table2 = self.getVocabularyTable()(
1200        init,
1201        oov_buckets,
1202        name="table2",
1203        experimental_is_anonymous=is_anonymous)
1204
1205    self.evaluate(lookup_ops.tables_initializer())
1206
1207    input_string = constant_op.constant(
1208        ["fruit", "brain", "salad", "surgery", "UNK"])
1209
1210    out1 = table1.lookup(input_string)
1211    out2 = table2.lookup(input_string)
1212
1213    out1, out2 = self.evaluate([out1, out2])
1214    self.assertAllEqual([5, 0, 1, 2, 5], out1)
1215    self.assertAllEqual([5, 0, 1, 2, 5], out2)
1216    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
1217    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
1218
1219  def testStaticVocabularyTableInitializationAcrossSessions(self, is_anonymous):
1220    if is_anonymous and not tf2.enabled():
1221      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1222    vocab_file = self._createVocabFile("feat_to_id_5.txt")
1223    with self.cached_session():
1224      vocab_size = 3
1225      oov_buckets = 1
1226      table1 = self.getVocabularyTable()(
1227          lookup_ops.TextFileIdTableInitializer(
1228              vocab_file, vocab_size=vocab_size),
1229          oov_buckets,
1230          experimental_is_anonymous=is_anonymous)
1231
1232      self.initialize_table(table1)
1233
1234      input_string_1 = constant_op.constant(
1235          ["brain", "salad", "surgery", "UNK"])
1236
1237      out1 = table1.lookup(input_string_1)
1238
1239      self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1))
1240      self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
1241
1242    with self.cached_session():
1243      vocab_size = 3
1244      oov_buckets = 1
1245
1246      # Underlying lookup table already initialized in previous session.
1247      # No need to initialize table2
1248      table2 = self.getVocabularyTable()(
1249          lookup_ops.TextFileIdTableInitializer(
1250              vocab_file, vocab_size=vocab_size),
1251          oov_buckets,
1252          experimental_is_anonymous=is_anonymous)
1253
1254      input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
1255
1256      out2 = table2.lookup(input_string_2)
1257
1258      self.assertAllEqual([3, 1, 3], self.evaluate(out2))
1259      self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
1260
1261  def testStaticVocabularyTableAssetTracking(self, is_anonymous):
1262    vocab_file = self._createVocabFile("vocab.txt")
1263    vocab_size = 3
1264    oov_buckets = 1
1265    table = self.getVocabularyTable()(
1266        lookup_ops.TextFileIdTableInitializer(
1267            vocab_file, vocab_size=vocab_size),
1268        oov_buckets,
1269        experimental_is_anonymous=is_anonymous)
1270    objects = checkpoint_util.list_objects(graph_view.ObjectGraphView(table))
1271    assets = list(filter(lambda obj: isinstance(obj, asset.Asset), objects))
1272    self.assertLen(assets, 1)
1273    self.assertEqual(
1274        self.evaluate(assets[0].asset_path), compat.as_bytes(vocab_file))
1275
1276  def testSparseTensor(self, is_anonymous):
1277    if is_anonymous and not tf2.enabled():
1278      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1279    vocab_file = self._createVocabFile("feat_to_id_7.txt")
1280    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
1281    input_shape = [4, 4]
1282    sp_features = sparse_tensor.SparseTensor(
1283        constant_op.constant(input_indices, dtypes.int64),
1284        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
1285                             dtypes.string),
1286        constant_op.constant(input_shape, dtypes.int64))
1287
1288    table = self.getVocabularyTable()(
1289        lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
1290        1,
1291        experimental_is_anonymous=is_anonymous)
1292    self.initialize_table(table)
1293
1294    sp_ids = table.lookup(sp_features)
1295
1296    self.assertAllEqual([5], sp_ids.values._shape_as_list())
1297
1298    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
1299        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
1300
1301    self.assertAllEqual(input_indices, sp_ids_ind)
1302    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
1303    self.assertAllEqual(input_shape, sp_ids_shape)
1304
1305  def testRaggedTensor(self, is_anonymous):
1306    if is_anonymous and not tf2.enabled():
1307      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1308    vocab_file = self._createVocabFile("feat_to_id_7.txt")
1309    input_row_splits = [0, 2, 4, 5]
1310    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
1311        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
1312                             dtypes.string),
1313        constant_op.constant(input_row_splits, dtypes.int64))
1314
1315    table = self.getVocabularyTable()(
1316        lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
1317        1,
1318        experimental_is_anonymous=is_anonymous)
1319    self.initialize_table(table)
1320
1321    ragged_ids = table.lookup(ragged_features)
1322
1323    self.assertAllEqual([5], ragged_ids.values._shape_as_list())
1324
1325    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
1326        [ragged_ids.values, ragged_ids.row_splits])
1327
1328    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
1329    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)
1330
1331  def testInt32SparseTensor(self, is_anonymous):
1332    if is_anonymous and not tf2.enabled():
1333      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1334    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
1335    input_shape = [4, 4]
1336    sp_features = sparse_tensor.SparseTensor(
1337        constant_op.constant(input_indices, dtypes.int64),
1338        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
1339        constant_op.constant(input_shape, dtypes.int64))
1340
1341    table = self.getVocabularyTable()(
1342        lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
1343                                             dtypes.int64, dtypes.int64),
1344        1,
1345        lookup_key_dtype=dtypes.int32,
1346        experimental_is_anonymous=is_anonymous)
1347    self.initialize_table(table)
1348
1349    sp_ids = table.lookup(sp_features)
1350
1351    self.assertAllEqual([5], sp_ids.values._shape_as_list())
1352
1353    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
1354        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
1355
1356    self.assertAllEqual(input_indices, sp_ids_ind)
1357    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
1358    self.assertAllEqual(input_shape, sp_ids_shape)
1359
1360  def testInt32RaggedTensor(self, is_anonymous):
1361    if is_anonymous and not tf2.enabled():
1362      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1363    input_row_splits = [0, 2, 4, 5]
1364    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
1365        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
1366        constant_op.constant(input_row_splits, dtypes.int64))
1367
1368    table = self.getVocabularyTable()(
1369        lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
1370                                             dtypes.int64, dtypes.int64),
1371        1,
1372        lookup_key_dtype=dtypes.int32,
1373        experimental_is_anonymous=is_anonymous)
1374    self.initialize_table(table)
1375
1376    ragged_ids = table.lookup(ragged_features)
1377
1378    self.assertAllEqual([5], ragged_ids.values._shape_as_list())
1379
1380    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
1381        [ragged_ids.values, ragged_ids.row_splits])
1382
1383    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
1384    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)
1385
1386  def testInt64SparseTensor(self, is_anonymous):
1387    if is_anonymous and not tf2.enabled():
1388      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1389    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
1390    input_shape = [4, 4]
1391    sp_features = sparse_tensor.SparseTensor(
1392        constant_op.constant(input_indices, dtypes.int64),
1393        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
1394        constant_op.constant(input_shape, dtypes.int64))
1395
1396    table = self.getVocabularyTable()(
1397        lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
1398                                             dtypes.int64, dtypes.int64),
1399        1,
1400        experimental_is_anonymous=is_anonymous)
1401    self.initialize_table(table)
1402
1403    sp_ids = table.lookup(sp_features)
1404
1405    self.assertAllEqual([5], sp_ids.values._shape_as_list())
1406
1407    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
1408        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
1409
1410    self.assertAllEqual(input_indices, sp_ids_ind)
1411    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
1412    self.assertAllEqual(input_shape, sp_ids_shape)
1413
1414  def testInt64RaggedTensor(self, is_anonymous):
1415    if is_anonymous and not tf2.enabled():
1416      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1417    input_row_splits = [0, 2, 4, 5]
1418    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
1419        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
1420        constant_op.constant(input_row_splits, dtypes.int64))
1421
1422    table = self.getVocabularyTable()(
1423        lookup_ops.KeyValueTensorInitializer((42, 1, -1000), (0, 1, 2),
1424                                             dtypes.int64, dtypes.int64),
1425        1,
1426        experimental_is_anonymous=is_anonymous)
1427    self.initialize_table(table)
1428
1429    ragged_ids = table.lookup(ragged_features)
1430
1431    self.assertAllEqual([5], ragged_ids.values._shape_as_list())
1432
1433    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
1434        [ragged_ids.values, ragged_ids.row_splits])
1435
1436    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
1437    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)
1438
1439  def testStaticVocabularyTableNoInnerTable(self, is_anonymous):
1440    table = self.getVocabularyTable()(
1441        None, num_oov_buckets=1, experimental_is_anonymous=is_anonymous)
1442    self.assertIsNone(table.resource_handle)
1443
1444  @test_util.run_v2_only
1445  def testSavedModelSaveRestore(self, is_anonymous):
1446    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
1447    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
1448
1449    root = autotrackable.AutoTrackable()
1450
1451    vocab_file = self._createVocabFile("feat_to_id_3.txt", ("11", "12", "13"))
1452    vocab_size = 3
1453    oov_buckets = 1
1454    root.table = self.getVocabularyTable()(
1455        lookup_ops.TextFileIdTableInitializer(
1456            vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
1457        oov_buckets,
1458        experimental_is_anonymous=is_anonymous)
1459
1460    @def_function.function(
1461        input_signature=[tensor_spec.TensorSpec((), dtypes.int64)])
1462    def lookup(key):
1463      return root.table.lookup(key)
1464
1465    @def_function.function(input_signature=[])
1466    def size():
1467      return root.table.size()
1468
1469    @def_function.function(input_signature=[])
1470    def is_ref_counting():
1471      return test_ops.is_resource_handle_ref_counting(
1472          root.table.resource_handle)
1473
1474    root.lookup = lookup
1475    root.size = size
1476    root.is_ref_counting = is_ref_counting
1477
1478    self.assertEqual(root.table.size(), 4)
1479    self.assertEqual(root.lookup(12), 1)
1480    self.assertEqual(root.lookup(10), 3)
1481    self.assertEqual(root.is_ref_counting(), is_anonymous)
1482
1483    saved_model_save.save(root, save_path)
1484
1485    del root
1486    loaded = saved_model_load.load(save_path)
1487    self.assertEqual(loaded.size(), 4)
1488    self.assertEqual(loaded.lookup(12), 1)
1489    self.assertEqual(loaded.lookup(10), 3)
1490    self.assertEqual(loaded.is_ref_counting(), is_anonymous)
1491
1492
1493@parameterized.named_parameters(
1494    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
1495class DenseHashTableOpTest(test.TestCase):
1496
1497  def testBasic(self, is_anonymous):
1498    if is_anonymous and not tf2.enabled():
1499      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1500    keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1501    values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
1502    table = lookup_ops.DenseHashTable(
1503        dtypes.int64,
1504        dtypes.int64,
1505        default_value=-1,
1506        empty_key=0,
1507        deleted_key=-1,
1508        experimental_is_anonymous=is_anonymous)
1509    self.assertAllEqual(0, self.evaluate(table.size()))
1510
1511    self.evaluate(table.insert(keys, values))
1512    self.assertAllEqual(4, self.evaluate(table.size()))
1513
1514    remove_string = constant_op.constant([12, 15], dtypes.int64)
1515    self.evaluate(table.remove(remove_string))
1516    self.assertAllEqual(3, self.evaluate(table.size()))
1517
1518    input_string = constant_op.constant([11, 12, 15], dtypes.int64)
1519    output = table.lookup(input_string)
1520    self.assertAllEqual([3], output.get_shape())
1521
1522    result = self.evaluate(output)
1523    self.assertAllEqual([0, -1, -1], result)
1524
1525  def testGetItem(self, is_anonymous):
1526    if is_anonymous and not tf2.enabled():
1527      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1528    keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1529    values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
1530    table = lookup_ops.DenseHashTable(
1531        dtypes.int64,
1532        dtypes.int64,
1533        default_value=-1,
1534        empty_key=0,
1535        deleted_key=-1,
1536        experimental_is_anonymous=is_anonymous)
1537
1538    self.evaluate(table.insert(keys, values))
1539
1540    input_string = constant_op.constant([11, 12, 15], dtypes.int64)
1541    output = table[input_string]
1542    self.assertAllEqual([3], output.get_shape())
1543
1544    result = self.evaluate(output)
1545    self.assertAllEqual([0, 1, -1], result)
1546
1547  def testBasicBool(self, is_anonymous):
1548    if is_anonymous and not tf2.enabled():
1549      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1550    keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1551    values = constant_op.constant([True, True, True, True], dtypes.bool)
1552    table = lookup_ops.DenseHashTable(
1553        dtypes.int64,
1554        dtypes.bool,
1555        default_value=False,
1556        empty_key=0,
1557        deleted_key=-1,
1558        experimental_is_anonymous=is_anonymous)
1559    self.assertAllEqual(0, self.evaluate(table.size()))
1560
1561    self.evaluate(table.insert(keys, values))
1562    self.assertAllEqual(4, self.evaluate(table.size()))
1563
1564    remove_string = constant_op.constant([11, 15], dtypes.int64)
1565    self.evaluate(table.remove(remove_string))
1566    self.assertAllEqual(3, self.evaluate(table.size()))
1567
1568    input_string = constant_op.constant([11, 12, 15], dtypes.int64)
1569    output = table.lookup(input_string)
1570    self.assertAllEqual([3], output.get_shape())
1571
1572    result = self.evaluate(output)
1573    self.assertAllEqual([False, True, False], result)
1574
1575  def testSameEmptyAndDeletedKey(self, is_anonymous):
1576    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
1577                                "Empty and deleted keys"):
1578      table = lookup_ops.DenseHashTable(
1579          dtypes.int64,
1580          dtypes.int64,
1581          default_value=-1,
1582          empty_key=42,
1583          deleted_key=42,
1584          experimental_is_anonymous=is_anonymous)
1585      self.assertAllEqual(0, self.evaluate(table.size()))
1586
1587  @test_util.run_v1_only("uses placeholders")
1588  def testLookupUnknownShape(self, is_anonymous):
1589    if is_anonymous and not tf2.enabled():
1590      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1591    with self.cached_session():
1592      keys = constant_op.constant([11, 12, 13], dtypes.int64)
1593      values = constant_op.constant([0, 1, 2], dtypes.int64)
1594      table = lookup_ops.DenseHashTable(
1595          dtypes.int64,
1596          dtypes.int64,
1597          default_value=-1,
1598          empty_key=0,
1599          deleted_key=-1,
1600          experimental_is_anonymous=is_anonymous)
1601
1602      self.evaluate(table.insert(keys, values))
1603      self.assertAllEqual(3, self.evaluate(table.size()))
1604
1605      placeholder_keys = array_ops.placeholder(dtypes.int64)
1606      output = table.lookup(placeholder_keys)
1607      self.assertAllEqual(None, output.get_shape())
1608      result = output.eval({placeholder_keys: [11, 12, 15]})
1609      self.assertAllEqual([0, 1, -1], result)
1610
1611  def testMapStringToFloat(self, is_anonymous):
1612    if is_anonymous and not tf2.enabled():
1613      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1614    keys = constant_op.constant(["a", "b", "c", "d"], dtypes.string)
1615    values = constant_op.constant([0.0, 1.1, 2.2, 3.3], dtypes.float32)
1616    default_value = constant_op.constant(-1.5, dtypes.float32)
1617    table = lookup_ops.DenseHashTable(
1618        dtypes.string,
1619        dtypes.float32,
1620        default_value=default_value,
1621        empty_key="",
1622        deleted_key="$",
1623        experimental_is_anonymous=is_anonymous)
1624    self.assertAllEqual(0, self.evaluate(table.size()))
1625
1626    self.evaluate(table.insert(keys, values))
1627    self.assertAllEqual(4, self.evaluate(table.size()))
1628
1629    remove_string = constant_op.constant(["b", "e"])
1630    self.evaluate(table.remove(remove_string))
1631    self.assertAllEqual(3, self.evaluate(table.size()))
1632
1633    input_string = constant_op.constant(["a", "b", "d", "e"], dtypes.string)
1634    output = table.lookup(input_string)
1635    self.assertAllEqual([4], output.get_shape())
1636
1637    result = self.evaluate(output)
1638    self.assertAllClose([0, -1.5, 3.3, -1.5], result)
1639
1640  def testMapInt64ToFloat(self, is_anonymous):
1641    if is_anonymous and not tf2.enabled():
1642      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1643    for float_dtype in [dtypes.float32, dtypes.float64]:
1644      keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1645      values = constant_op.constant([0.0, 1.1, 2.2, 3.3], float_dtype)
1646      default_value = constant_op.constant(-1.5, float_dtype)
1647      table = lookup_ops.DenseHashTable(
1648          dtypes.int64,
1649          float_dtype,
1650          default_value=default_value,
1651          empty_key=0,
1652          deleted_key=-1,
1653          experimental_is_anonymous=is_anonymous)
1654      self.assertAllEqual(0, self.evaluate(table.size()))
1655
1656      self.evaluate(table.insert(keys, values))
1657      self.assertAllEqual(4, self.evaluate(table.size()))
1658
1659      remove_string = constant_op.constant([12, 15], dtypes.int64)
1660      self.evaluate(table.remove(remove_string))
1661      self.assertAllEqual(3, self.evaluate(table.size()))
1662
1663      input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
1664      output = table.lookup(input_string)
1665      self.assertAllEqual([4], output.get_shape())
1666
1667      result = self.evaluate(output)
1668      self.assertAllClose([0, -1.5, 3.3, -1.5], result)
1669
1670  def testVectorValues(self, is_anonymous):
1671    if is_anonymous and not tf2.enabled():
1672      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1673    keys = constant_op.constant([11, 12, 13], dtypes.int64)
1674    values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
1675                                  dtypes.int64)
1676    default_value = constant_op.constant([-1, -2, -3, -4], dtypes.int64)
1677    table = lookup_ops.DenseHashTable(
1678        dtypes.int64,
1679        dtypes.int64,
1680        default_value=default_value,
1681        empty_key=0,
1682        deleted_key=-1,
1683        initial_num_buckets=4,
1684        experimental_is_anonymous=is_anonymous)
1685    self.assertAllEqual(0, self.evaluate(table.size()))
1686
1687    self.evaluate(table.insert(keys, values))
1688    self.assertAllEqual(3, self.evaluate(table.size()))
1689    self.assertAllEqual(4, len(self.evaluate(table.export()[0])))
1690
1691    self.evaluate(
1692        table.insert(
1693            constant_op.constant([14], dtypes.int64),
1694            constant_op.constant([[2, 3, 4, 5]], dtypes.int64)))
1695    self.assertAllEqual(4, self.evaluate(table.size()))
1696    self.assertAllEqual(8, len(self.evaluate(table.export()[0])))
1697
1698    remove_string = constant_op.constant([12, 16], dtypes.int64)
1699    self.evaluate(table.remove(remove_string))
1700    self.assertAllEqual(3, self.evaluate(table.size()))
1701    self.assertAllEqual(8, len(self.evaluate(table.export()[0])))
1702
1703    input_string = constant_op.constant([11, 12, 14, 15], dtypes.int64)
1704    output = table.lookup(input_string)
1705    self.assertAllEqual([4, 4],
1706                        output.shape,
1707                        msg="Saw shape: %s" % output.shape)
1708
1709    result = self.evaluate(output)
1710    self.assertAllEqual(
1711        [[0, 1, 2, 3], [-1, -2, -3, -4], [2, 3, 4, 5], [-1, -2, -3, -4]],
1712        result)
1713
1714  def testVectorKeys(self, is_anonymous):
1715    if is_anonymous and not tf2.enabled():
1716      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1717    keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
1718    values = constant_op.constant([10, 11, 12], dtypes.int64)
1719    empty_key = constant_op.constant([0, 3], dtypes.int64)
1720    deleted_key = constant_op.constant([-1, -1], dtypes.int64)
1721    default_value = constant_op.constant(-1, dtypes.int64)
1722    table = lookup_ops.DenseHashTable(
1723        dtypes.int64,
1724        dtypes.int64,
1725        default_value=default_value,
1726        empty_key=empty_key,
1727        deleted_key=deleted_key,
1728        initial_num_buckets=8,
1729        experimental_is_anonymous=is_anonymous)
1730    self.assertAllEqual(0, self.evaluate(table.size()))
1731
1732    self.evaluate(table.insert(keys, values))
1733    self.assertAllEqual(3, self.evaluate(table.size()))
1734
1735    self.evaluate(
1736        table.insert(
1737            constant_op.constant([[0, 0]], dtypes.int64),
1738            constant_op.constant([13], dtypes.int64)))
1739    self.assertAllEqual(4, self.evaluate(table.size()))
1740    self.assertAllEqual(8, len(self.evaluate(table.export()[0])))
1741
1742    remove_string = constant_op.constant([[1, 2], [7, 8]], dtypes.int64)
1743    self.evaluate(table.remove(remove_string))
1744    self.assertAllEqual(3, self.evaluate(table.size()))
1745    self.assertAllEqual(8, len(self.evaluate(table.export()[0])))
1746
1747    input_string = constant_op.constant([[0, 1], [1, 2], [1, 3], [0, 2]],
1748                                        dtypes.int64)
1749    output = table.lookup(input_string)
1750    self.assertAllEqual([4], output.get_shape())
1751
1752    result = self.evaluate(output)
1753    self.assertAllEqual([10, -1, 12, -1], result)
1754
1755  def testResize(self, is_anonymous):
1756    if is_anonymous and not tf2.enabled():
1757      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1758    keys = constant_op.constant([11, 12, 13], dtypes.int64)
1759    values = constant_op.constant([0, 1, 2], dtypes.int64)
1760    table = lookup_ops.DenseHashTable(
1761        dtypes.int64,
1762        dtypes.int64,
1763        default_value=-1,
1764        empty_key=0,
1765        deleted_key=-1,
1766        initial_num_buckets=4,
1767        experimental_is_anonymous=is_anonymous)
1768    self.assertAllEqual(0, self.evaluate(table.size()))
1769
1770    self.evaluate(table.insert(keys, values))
1771    self.assertAllEqual(3, self.evaluate(table.size()))
1772    self.assertAllEqual(4, len(self.evaluate(table.export()[0])))
1773
1774    keys2 = constant_op.constant([12, 99], dtypes.int64)
1775    self.evaluate(table.remove(keys2))
1776    self.assertAllEqual(2, self.evaluate(table.size()))
1777    self.assertAllEqual(4, len(self.evaluate(table.export()[0])))
1778
1779    keys3 = constant_op.constant([13, 14, 15, 16, 17], dtypes.int64)
1780    values3 = constant_op.constant([3, 4, 5, 6, 7], dtypes.int64)
1781
1782    self.evaluate(table.insert(keys3, values3))
1783    self.assertAllEqual(6, self.evaluate(table.size()))
1784    self.assertAllEqual(16, len(self.evaluate(table.export()[0])))
1785
1786    keys4 = constant_op.constant([10, 11, 12, 13, 14, 15, 16, 17, 18],
1787                                 dtypes.int64)
1788    output = table.lookup(keys4)
1789    self.assertAllEqual([-1, 0, -1, 3, 4, 5, 6, 7, -1], self.evaluate(output))
1790
1791  def testExport(self, is_anonymous):
1792    if is_anonymous and not tf2.enabled():
1793      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1794    keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1795    values = constant_op.constant([1, 2, 3, 4], dtypes.int64)
1796    table = lookup_ops.DenseHashTable(
1797        dtypes.int64,
1798        dtypes.int64,
1799        default_value=-1,
1800        empty_key=100,
1801        deleted_key=200,
1802        initial_num_buckets=8,
1803        experimental_is_anonymous=is_anonymous)
1804    self.assertAllEqual(0, self.evaluate(table.size()))
1805
1806    self.evaluate(table.insert(keys, values))
1807    self.assertAllEqual(4, self.evaluate(table.size()))
1808
1809    keys2 = constant_op.constant([12, 15], dtypes.int64)
1810    self.evaluate(table.remove(keys2))
1811    self.assertAllEqual(3, self.evaluate(table.size()))
1812
1813    exported_keys, exported_values = table.export()
1814
1815    np_keys = self.evaluate(exported_keys)
1816    np_values = self.evaluate(exported_values)
1817
1818    self.assertAllEqual(8, len(np_keys))
1819    self.assertAllEqual(8, len(np_values))
1820
1821    # pair up keys and values, drop extra added dimension
1822    pairs = np.dstack((np_keys.flatten(), np_values.flatten()))[0]
1823    # sort by key
1824    pairs = pairs[pairs[:, 0].argsort()]
1825    self.assertAllEqual([[11, 1], [13, 3], [14, 4], [100, 0], [100, 0],
1826                         [100, 0], [100, 0], [200, 2]], pairs)
1827
1828  @test_util.run_v1_only("Saver V1 only")
1829  def testSaveRestore(self, is_anonymous):
1830    if is_anonymous and not tf2.enabled():
1831      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1832    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
1833    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
1834
1835    with self.session(graph=ops.Graph()) as sess:
1836      default_value = -1
1837      empty_key = 0
1838      deleted_key = -1
1839      keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1840      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
1841      table = lookup_ops.DenseHashTable(
1842          dtypes.int64,
1843          dtypes.int64,
1844          default_value=default_value,
1845          empty_key=empty_key,
1846          deleted_key=deleted_key,
1847          name="t1",
1848          checkpoint=True,
1849          initial_num_buckets=32,
1850          experimental_is_anonymous=is_anonymous)
1851
1852      save = saver.Saver()
1853
1854      self.assertAllEqual(0, table.size())
1855      table.insert(keys, values).run()
1856      self.assertAllEqual(4, table.size())
1857      self.assertAllEqual(32, len(table.export()[0].eval()))
1858
1859      keys2 = constant_op.constant([12, 15], dtypes.int64)
1860      table.remove(keys2).run()
1861      self.assertAllEqual(3, table.size())
1862      self.assertAllEqual(32, len(table.export()[0].eval()))
1863
1864      val = save.save(sess, save_path)
1865      self.assertIsInstance(val, str)
1866      self.assertEqual(save_path, val)
1867
1868    with self.session(graph=ops.Graph()) as sess:
1869      table = lookup_ops.DenseHashTable(
1870          dtypes.int64,
1871          dtypes.int64,
1872          default_value=default_value,
1873          empty_key=empty_key,
1874          deleted_key=deleted_key,
1875          name="t1",
1876          checkpoint=True,
1877          initial_num_buckets=64,
1878          experimental_is_anonymous=is_anonymous)
1879      table.insert(
1880          constant_op.constant([11, 14], dtypes.int64),
1881          constant_op.constant([12, 24], dtypes.int64)).run()
1882      self.assertAllEqual(2, table.size())
1883      self.assertAllEqual(64, len(table.export()[0].eval()))
1884
1885      save = saver.Saver()
1886
1887      # Restore the saved values in the parameter nodes.
1888      save.restore(sess, save_path)
1889
1890      self.assertAllEqual(3, table.size())
1891      self.assertAllEqual(32, len(table.export()[0].eval()))
1892
1893      input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
1894      output = table.lookup(input_string)
1895      self.assertAllEqual([-1, 0, -1, 2, 3], output)
1896
1897  @test_util.run_v1_only("Saver V1 only")
1898  def testSaveRestoreOnlyTable(self, is_anonymous):
1899    if is_anonymous and not tf2.enabled():
1900      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1901    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
1902    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
1903
1904    with self.session(graph=ops.Graph()) as sess:
1905      default_value = -1
1906      empty_key = 0
1907      deleted_key = -1
1908      keys = constant_op.constant([11, 12, 13, 14], dtypes.int64)
1909      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
1910      table = lookup_ops.DenseHashTable(
1911          dtypes.int64,
1912          dtypes.int64,
1913          default_value=default_value,
1914          empty_key=empty_key,
1915          deleted_key=deleted_key,
1916          name="t1",
1917          checkpoint=True,
1918          initial_num_buckets=32,
1919          experimental_is_anonymous=is_anonymous)
1920
1921      save = saver.Saver([table])
1922
1923      self.assertAllEqual(0, table.size())
1924      table.insert(keys, values).run()
1925      self.assertAllEqual(4, table.size())
1926      self.assertAllEqual(32, len(table.export()[0].eval()))
1927
1928      keys2 = constant_op.constant([12, 15], dtypes.int64)
1929      table.remove(keys2).run()
1930      self.assertAllEqual(3, table.size())
1931      self.assertAllEqual(32, len(table.export()[0].eval()))
1932
1933      val = save.save(sess, save_path)
1934      self.assertIsInstance(val, str)
1935      self.assertEqual(save_path, val)
1936
1937    with self.session(graph=ops.Graph()) as sess:
1938      table = lookup_ops.DenseHashTable(
1939          dtypes.int64,
1940          dtypes.int64,
1941          default_value=default_value,
1942          empty_key=empty_key,
1943          deleted_key=deleted_key,
1944          name="t1",
1945          checkpoint=True,
1946          initial_num_buckets=64,
1947          experimental_is_anonymous=is_anonymous)
1948      table.insert(
1949          constant_op.constant([11, 14], dtypes.int64),
1950          constant_op.constant([12, 24], dtypes.int64)).run()
1951      self.assertAllEqual(2, table.size())
1952      self.assertAllEqual(64, len(table.export()[0].eval()))
1953
1954      save = saver.Saver([table])
1955
1956      # Restore the saved values in the parameter nodes.
1957      save.restore(sess, save_path)
1958
1959      self.assertAllEqual(3, table.size())
1960      self.assertAllEqual(32, len(table.export()[0].eval()))
1961
1962      input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
1963      output = table.lookup(input_string)
1964      self.assertAllEqual([-1, 0, -1, 2, 3], output)
1965
1966  @test_util.run_in_graph_and_eager_modes
1967  def testObjectSaveRestore(self, is_anonymous):
1968    if is_anonymous and not context.executing_eagerly():
1969      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
1970    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
1971    save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
1972
1973    default_value = -1
1974    empty_key = 0
1975    deleted_key = -1
1976    keys = constant_op.constant([11, 12, 13], dtypes.int64)
1977    values = constant_op.constant([0, 1, 2], dtypes.int64)
1978    save_table = lookup_ops.DenseHashTable(
1979        dtypes.int64,
1980        dtypes.int64,
1981        default_value=default_value,
1982        empty_key=empty_key,
1983        deleted_key=deleted_key,
1984        name="t1",
1985        checkpoint=True,
1986        initial_num_buckets=32,
1987        experimental_is_anonymous=is_anonymous)
1988
1989    save_checkpoint = trackable.Checkpoint(table=save_table)
1990
1991    self.assertAllEqual(0, self.evaluate(save_table.size()))
1992    self.evaluate(save_table.insert(keys, values))
1993    self.assertAllEqual(3, self.evaluate(save_table.size()))
1994    self.assertAllEqual(32, len(self.evaluate(save_table.export()[0])))
1995
1996    save_path = save_checkpoint.save(save_prefix)
1997    del save_table, save_checkpoint
1998
1999    load_table = lookup_ops.DenseHashTable(
2000        dtypes.int64,
2001        dtypes.int64,
2002        default_value=default_value,
2003        empty_key=empty_key,
2004        deleted_key=deleted_key,
2005        name="t1",
2006        checkpoint=True,
2007        initial_num_buckets=64,
2008        experimental_is_anonymous=is_anonymous)
2009    self.evaluate(
2010        load_table.insert(
2011            constant_op.constant([11, 14], dtypes.int64),
2012            constant_op.constant([12, 24], dtypes.int64)))
2013    self.assertAllEqual(2, self.evaluate(load_table.size()))
2014    self.assertAllEqual(64, len(self.evaluate(load_table.export()[0])))
2015
2016    restore_checkpoint = trackable.Checkpoint(table=load_table)
2017
2018    # Restore the saved values in the parameter nodes.
2019    restore_checkpoint.restore(save_path).run_restore_ops()
2020
2021    self.assertAllEqual(3, self.evaluate(load_table.size()))
2022    self.assertAllEqual(32, len(self.evaluate(load_table.export()[0])))
2023
2024    input_string = constant_op.constant([10, 11, 12, 13, 14], dtypes.int64)
2025    output = load_table.lookup(input_string)
2026    self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
2027
2028  @test_util.run_v2_only
2029  def testSavedModelSaveRestore(self, is_anonymous):
2030    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
2031    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
2032
2033    root = autotrackable.AutoTrackable()
2034
2035    default_value = -1
2036    empty_key = 0
2037    deleted_key = -1
2038    keys = constant_op.constant([11, 12, 13], dtypes.int64)
2039    values = constant_op.constant([0, 1, 2], dtypes.int64)
2040    root.table = lookup_ops.DenseHashTable(
2041        dtypes.int64,
2042        dtypes.int64,
2043        default_value=default_value,
2044        empty_key=empty_key,
2045        deleted_key=deleted_key,
2046        name="t1",
2047        checkpoint=True,
2048        initial_num_buckets=32,
2049        experimental_is_anonymous=is_anonymous)
2050
2051    @def_function.function(
2052        input_signature=[tensor_spec.TensorSpec((), dtypes.int64)])
2053    def lookup(key):
2054      return root.table.lookup(key)
2055
2056    @def_function.function(input_signature=[])
2057    def size():
2058      return root.table.size()
2059
2060    @def_function.function(input_signature=[])
2061    def is_ref_counting():
2062      return test_ops.is_resource_handle_ref_counting(
2063          root.table.resource_handle)
2064
2065    root.lookup = lookup
2066    root.size = size
2067    root.is_ref_counting = is_ref_counting
2068
2069    self.assertEqual(root.table.size(), 0)
2070    root.table.insert(keys, values)
2071    self.assertEqual(root.table.size(), 3)
2072    self.assertEqual(root.table.lookup(12), 1)
2073    self.assertEqual(root.table.lookup(10), -1)
2074    self.assertEqual(len(root.table.export()[0]), 32)
2075    self.assertEqual(root.is_ref_counting(), is_anonymous)
2076
2077    saved_model_save.save(root, save_path)
2078
2079    del root
2080    loaded = saved_model_load.load(save_path)
2081    self.assertEqual(loaded.size(), 3)
2082    self.assertEqual(loaded.lookup(12), 1)
2083    self.assertEqual(loaded.lookup(10), -1)
2084    self.assertEqual(loaded.is_ref_counting(), is_anonymous)
2085
2086  @test_util.run_v1_only("Saver V1 only")
2087  def testVectorSaveRestore(self, is_anonymous):
2088    if is_anonymous and not tf2.enabled():
2089      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
2090    save_dir = os.path.join(self.get_temp_dir(), "vector_save_restore")
2091    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
2092
2093    with self.session(graph=ops.Graph()) as sess:
2094      empty_key = constant_op.constant([11, 13], dtypes.int64)
2095      deleted_key = constant_op.constant([-2, -3], dtypes.int64)
2096      default_value = constant_op.constant([-1, -2], dtypes.int64)
2097      keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
2098                                  dtypes.int64)
2099      values = constant_op.constant([[0, 1], [2, 3], [2, 4], [4, 5]],
2100                                    dtypes.int64)
2101      table = lookup_ops.DenseHashTable(
2102          dtypes.int64,
2103          dtypes.int64,
2104          default_value=default_value,
2105          empty_key=empty_key,
2106          deleted_key=deleted_key,
2107          name="t1",
2108          checkpoint=True,
2109          initial_num_buckets=32,
2110          experimental_is_anonymous=is_anonymous)
2111
2112      save = saver.Saver()
2113
2114      self.assertAllEqual(0, table.size())
2115      table.insert(keys, values).run()
2116      self.assertAllEqual(4, table.size())
2117      self.assertAllEqual(32, len(table.export()[0].eval()))
2118
2119      keys2 = constant_op.constant([[12, 13], [16, 17]], dtypes.int64)
2120      table.remove(keys2).run()
2121      self.assertAllEqual(3, table.size())
2122      self.assertAllEqual(32, len(table.export()[0].eval()))
2123
2124      val = save.save(sess, save_path)
2125      self.assertIsInstance(val, str)
2126      self.assertEqual(save_path, val)
2127
2128    with self.session(graph=ops.Graph()) as sess:
2129      empty_key = constant_op.constant([11, 13], dtypes.int64)
2130      deleted_key = constant_op.constant([-2, -3], dtypes.int64)
2131      default_value = constant_op.constant([-1, -2], dtypes.int64)
2132      table = lookup_ops.DenseHashTable(
2133          dtypes.int64,
2134          dtypes.int64,
2135          default_value=default_value,
2136          empty_key=empty_key,
2137          deleted_key=deleted_key,
2138          name="t1",
2139          checkpoint=True,
2140          initial_num_buckets=64,
2141          experimental_is_anonymous=is_anonymous)
2142      table.insert(
2143          constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
2144          constant_op.constant([[21, 22], [23, 24]], dtypes.int64)).run()
2145      self.assertAllEqual(2, table.size())
2146      self.assertAllEqual(64, len(table.export()[0].eval()))
2147
2148      save = saver.Saver()
2149
2150      # Restore the saved values in the parameter nodes.
2151      save.restore(sess, save_path)
2152
2153      self.assertAllEqual(3, table.size())
2154      self.assertAllEqual(32, len(table.export()[0].eval()))
2155
2156      input_string = constant_op.constant(
2157          [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
2158      output = table.lookup(input_string)
2159      self.assertAllEqual([[0, 1], [2, 3], [-1, -2], [4, 5], [-1, -2]],
2160                          self.evaluate(output))
2161
2162  @test_util.run_v1_only("Saver V1 only")
2163  def testVectorScalarSaveRestore(self, is_anonymous):
2164    if is_anonymous and not tf2.enabled():
2165      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
2166    save_dir = os.path.join(self.get_temp_dir(), "vector_scalar_save_restore")
2167    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
2168
2169    with self.session(graph=ops.Graph()) as sess:
2170      empty_key = constant_op.constant([11, 13], dtypes.int64)
2171      deleted_key = constant_op.constant([-1, -1], dtypes.int64)
2172      default_value = constant_op.constant(-1, dtypes.int64)
2173      keys = constant_op.constant([[11, 12], [11, 14], [12, 13], [13, 14]],
2174                                  dtypes.int64)
2175      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
2176      table = lookup_ops.DenseHashTable(
2177          dtypes.int64,
2178          dtypes.int64,
2179          default_value=default_value,
2180          empty_key=empty_key,
2181          deleted_key=deleted_key,
2182          name="t2",
2183          checkpoint=True,
2184          initial_num_buckets=32,
2185          experimental_is_anonymous=is_anonymous)
2186
2187      save = saver.Saver()
2188
2189      self.assertAllEqual(0, table.size())
2190      table.insert(keys, values).run()
2191      self.assertAllEqual(4, table.size())
2192      self.assertAllEqual(32, len(table.export()[0].eval()))
2193
2194      keys2 = constant_op.constant([[12, 13], [15, 16]], dtypes.int64)
2195      table.remove(keys2).run()
2196      self.assertAllEqual(3, table.size())
2197      self.assertAllEqual(32, len(table.export()[0].eval()))
2198
2199      val = save.save(sess, save_path)
2200      self.assertIsInstance(val, str)
2201      self.assertEqual(save_path, val)
2202
2203    with self.session(graph=ops.Graph()) as sess:
2204      empty_key = constant_op.constant([11, 13], dtypes.int64)
2205      deleted_key = constant_op.constant([-1, -1], dtypes.int64)
2206      default_value = constant_op.constant(-1, dtypes.int64)
2207      table = lookup_ops.DenseHashTable(
2208          dtypes.int64,
2209          dtypes.int64,
2210          default_value=default_value,
2211          empty_key=empty_key,
2212          deleted_key=deleted_key,
2213          name="t2",
2214          checkpoint=True,
2215          initial_num_buckets=64,
2216          experimental_is_anonymous=is_anonymous)
2217      table.insert(
2218          constant_op.constant([[11, 12], [13, 15]], dtypes.int64),
2219          constant_op.constant([3, 4], dtypes.int64)).run()
2220      self.assertAllEqual(2, table.size())
2221      self.assertAllEqual(64, len(table.export()[0].eval()))
2222
2223      save = saver.Saver()
2224
2225      # Restore the saved values in the parameter nodes.
2226      save.restore(sess, save_path)
2227
2228      self.assertAllEqual(3, table.size())
2229      self.assertAllEqual(32, len(table.export()[0].eval()))
2230
2231      input_string = constant_op.constant(
2232          [[11, 12], [11, 14], [11, 15], [13, 14], [13, 15]], dtypes.int64)
2233      output = table.lookup(input_string)
2234      self.assertAllEqual([0, 1, -1, 3, -1], output)
2235
2236  def testReprobe(self, is_anonymous):
2237    if is_anonymous and not tf2.enabled():
2238      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
2239    # Insert 6 keys into a table with 8 buckets.
2240    # The values are chosen to make sure collisions occur when using GCC STL
2241    keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
2242    values = constant_op.constant([51, 52, 53, 54, 55, 56], dtypes.int64)
2243    table = lookup_ops.DenseHashTable(
2244        dtypes.int64,
2245        dtypes.int64,
2246        default_value=-1,
2247        empty_key=0,
2248        deleted_key=-1,
2249        initial_num_buckets=8,
2250        experimental_is_anonymous=is_anonymous)
2251    self.assertAllEqual(0, self.evaluate(table.size()))
2252
2253    self.evaluate(table.insert(keys, values))
2254    self.assertAllEqual(6, self.evaluate(table.size()))
2255
2256    input_string = constant_op.constant([10, 11, 12, 13, 14, 19, 20, 21, 22],
2257                                        dtypes.int64)
2258    output = table.lookup(input_string)
2259    self.assertAllEqual([9], output.get_shape())
2260
2261    result = self.evaluate(output)
2262    self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
2263
2264  def testCustomEmptyKey(self, is_anonymous):
2265    if is_anonymous and not tf2.enabled():
2266      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
2267    keys = constant_op.constant([11, 0, 13], dtypes.int64)
2268    values = constant_op.constant([0, 1, 2], dtypes.int64)
2269    table = lookup_ops.DenseHashTable(
2270        dtypes.int64,
2271        dtypes.int64,
2272        default_value=-1,
2273        empty_key=12,
2274        deleted_key=-1,
2275        experimental_is_anonymous=is_anonymous)
2276    self.assertAllEqual(0, self.evaluate(table.size()))
2277
2278    self.evaluate(table.insert(keys, values))
2279    self.assertAllEqual(3, self.evaluate(table.size()))
2280
2281    input_string = constant_op.constant([11, 0, 15], dtypes.int64)
2282    output = table.lookup(input_string)
2283    self.assertAllEqual([3], output.get_shape())
2284
2285    result = self.evaluate(output)
2286    self.assertAllEqual([0, 1, -1], result)
2287
2288  def testErrors(self, is_anonymous):
2289    table = lookup_ops.DenseHashTable(
2290        dtypes.int64,
2291        dtypes.int64,
2292        default_value=-1,
2293        empty_key=0,
2294        deleted_key=-1,
2295        experimental_is_anonymous=is_anonymous)
2296
2297    # Inserting the empty key returns an error
2298    keys1 = constant_op.constant([11, 0], dtypes.int64)
2299    values1 = constant_op.constant([0, 1], dtypes.int64)
2300    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2301                                "empty_key"):
2302      self.evaluate(table.insert(keys1, values1))
2303
2304    # Looking up the empty key returns an error
2305    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2306                                "empty_key"):
2307      self.evaluate(table.lookup(keys1))
2308
2309    # Inserting the deleted key returns an error
2310    keys2 = constant_op.constant([11, -1], dtypes.int64)
2311    values2 = constant_op.constant([0, 1], dtypes.int64)
2312    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2313                                "deleted_key"):
2314      self.evaluate(table.insert(keys2, values2))
2315
2316    # Looking up the empty key returns an error
2317    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2318                                "deleted_key"):
2319      self.evaluate(table.lookup(keys2))
2320
2321    # Arbitrary tensors of keys are not supported
2322    keys = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
2323    values = constant_op.constant([[11, 0], [12, 1]], dtypes.int64)
2324    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2325                                "Expected key shape"):
2326      self.evaluate(table.lookup(keys))
2327    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2328                                "Expected key shape"):
2329      self.evaluate(table.insert(keys, values))
2330
2331    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2332                                "Number of buckets must be"):
2333      table2 = lookup_ops.DenseHashTable(
2334          dtypes.int64,
2335          dtypes.int64,
2336          default_value=-1,
2337          empty_key=17,
2338          deleted_key=-1,
2339          initial_num_buckets=12,
2340          experimental_is_anonymous=is_anonymous)
2341      self.assertAllEqual(0, self.evaluate(table2.size()))
2342
2343    with self.assertRaisesRegex(
2344        errors_impl.InvalidArgumentError,
2345        "Empty and deleted keys must have same shape"):
2346      table3 = lookup_ops.DenseHashTable(
2347          dtypes.int64,
2348          dtypes.int64,
2349          default_value=-1,
2350          empty_key=42,
2351          deleted_key=[1, 2],
2352          experimental_is_anonymous=is_anonymous)
2353      self.assertAllEqual(0, self.evaluate(table3.size()))
2354
2355    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2356                                "Empty and deleted keys cannot be equal"):
2357      table4 = lookup_ops.DenseHashTable(
2358          dtypes.int64,
2359          dtypes.int64,
2360          default_value=-1,
2361          empty_key=42,
2362          deleted_key=42,
2363          experimental_is_anonymous=is_anonymous)
2364      self.assertAllEqual(0, self.evaluate(table4.size()))
2365
2366    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2367                                "Empty and deleted keys cannot be equal"):
2368      table5 = lookup_ops.DenseHashTable(
2369          dtypes.int64,
2370          dtypes.int64,
2371          default_value=-1,
2372          empty_key=[1, 2, 3],
2373          deleted_key=[1, 2, 3],
2374          experimental_is_anonymous=is_anonymous)
2375      self.assertAllEqual(0, self.evaluate(table5.size()))
2376
2377  @test_util.run_in_graph_and_eager_modes
2378  def testStringToResource(self, is_anonymous):
2379    v = variables.Variable(1.)
2380    v1 = variables.Variable(1.)
2381    table = lookup_ops.DenseHashTable(
2382        dtypes.string,
2383        dtypes.resource,
2384        default_value=v.handle,
2385        empty_key="<empty>",
2386        deleted_key="<deleted>",
2387        experimental_is_anonymous=is_anonymous)
2388    self.assertEqual([], table.lookup("not_found").shape)
2389    table.insert("v1", v1.handle)
2390    self.assertEqual([], table.lookup("v1").shape)
2391
2392  def testExportShapeInference(self, is_anonymous):
2393    default_value = -1
2394    empty_key = 0
2395    deleted_key = -1
2396    table = lookup_ops.DenseHashTable(
2397        dtypes.int64,
2398        dtypes.int64,
2399        default_value=default_value,
2400        empty_key=empty_key,
2401        deleted_key=deleted_key,
2402        experimental_is_anonymous=is_anonymous)
2403    actual_shapes = [t.shape for t in table.export()]
2404    inferred_shapes = []
2405
2406    @def_function.function
2407    def f():
2408      for t in table.export():
2409        inferred_shapes.append(t.shape)
2410
2411    f()
2412    self.assertLen(actual_shapes, 2)
2413    self.assertLen(inferred_shapes, 2)
2414    self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0]))
2415    self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1]))
2416
2417
2418class IndexTableFromFile(test.TestCase):
2419
2420  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
2421    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
2422    with open(vocabulary_file, "w") as f:
2423      f.write("\n".join(values) + "\n")
2424    return vocabulary_file
2425
2426  def test_string_index_table_from_file(self):
2427    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
2428
2429    table = lookup_ops.index_table_from_file(
2430        vocabulary_file=vocabulary_file, num_oov_buckets=1)
2431    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2432
2433    if not context.executing_eagerly():
2434      with self.assertRaises(errors_impl.OpError):
2435        self.evaluate(ids)
2436    self.evaluate(lookup_ops.tables_initializer())
2437    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2438
2439  def test_string_index_table_from_multicolumn_file(self):
2440    vocabulary_file = self._createVocabFile(
2441        "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
2442    table = lookup_ops.index_table_from_file(
2443        vocabulary_file=vocabulary_file,
2444        num_oov_buckets=1,
2445        key_column_index=0,
2446        value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER)
2447    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2448
2449    if not context.executing_eagerly():
2450      with self.assertRaises(errors_impl.OpError):
2451        self.evaluate(ids)
2452    self.evaluate(lookup_ops.tables_initializer())
2453    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2454
2455  def test_string_index_table_from_multicolumn_file_custom_delimiter(self):
2456    vocabulary_file = self._createVocabFile(
2457        "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
2458    table = lookup_ops.index_table_from_file(
2459        vocabulary_file=vocabulary_file,
2460        num_oov_buckets=1,
2461        key_column_index=0,
2462        value_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
2463        delimiter=" ")
2464    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2465
2466    if not context.executing_eagerly():
2467      with self.assertRaises(errors_impl.OpError):
2468        self.evaluate(ids)
2469    self.evaluate(lookup_ops.tables_initializer())
2470    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2471
2472  def test_string_index_table_from_file_tensor_filename(self):
2473    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
2474    vocabulary_file = constant_op.constant(vocabulary_file)
2475    table = lookup_ops.index_table_from_file(
2476        vocabulary_file=vocabulary_file, num_oov_buckets=1)
2477    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2478
2479    if not context.executing_eagerly():
2480      with self.assertRaises(errors_impl.OpError):
2481        self.evaluate(ids)
2482    self.evaluate(lookup_ops.tables_initializer())
2483    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2484    if not context.executing_eagerly():
2485      self.assertEqual(1,
2486                       len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
2487
2488  @test_util.run_v1_only("placeholder usage")
2489  def test_string_index_table_from_file_placeholder_filename(self):
2490    vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
2491    with self.cached_session():
2492      vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
2493      table = lookup_ops.index_table_from_file(
2494          vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
2495      ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2496
2497      with self.assertRaises(errors_impl.OpError):
2498        self.evaluate(ids)
2499
2500      feed_dict = {vocabulary_placeholder.name: vocabulary_file}
2501      lookup_ops.tables_initializer().run(feed_dict=feed_dict)
2502      self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2503      self.assertEqual(0,
2504                       len(ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)))
2505
2506  def test_int32_index_table_from_file(self):
2507    vocabulary_file = self._createVocabFile(
2508        "f2i_vocab2.txt", values=("42", "1", "-1000"))
2509    table = lookup_ops.index_table_from_file(
2510        vocabulary_file=vocabulary_file,
2511        num_oov_buckets=1,
2512        key_dtype=dtypes.int32)
2513    ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
2514
2515    if not context.executing_eagerly():
2516      with self.assertRaises(errors_impl.OpError):
2517        self.evaluate(ids)
2518    self.evaluate(lookup_ops.tables_initializer())
2519    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2520
2521  def test_int64_index_table_from_file(self):
2522    vocabulary_file = self._createVocabFile(
2523        "f2i_vocab3.txt", values=("42", "1", "-1000"))
2524    table = lookup_ops.index_table_from_file(
2525        vocabulary_file=vocabulary_file,
2526        num_oov_buckets=1,
2527        key_dtype=dtypes.int64)
2528    ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
2529
2530    if not context.executing_eagerly():
2531      with self.assertRaises(errors_impl.OpError):
2532        self.evaluate(ids)
2533    self.evaluate(lookup_ops.tables_initializer())
2534    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2535
2536  def test_index_table_from_file_with_default_value(self):
2537    default_value = -42
2538    vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
2539    table = lookup_ops.index_table_from_file(
2540        vocabulary_file=vocabulary_file, default_value=default_value)
2541    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2542
2543    if not context.executing_eagerly():
2544      with self.assertRaises(errors_impl.OpError):
2545        self.evaluate(ids)
2546    self.evaluate(lookup_ops.tables_initializer())
2547    self.assertAllEqual((1, 2, default_value), self.evaluate(ids))
2548
2549  def test_index_table_from_file_with_oov_buckets(self):
2550    vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
2551    table = lookup_ops.index_table_from_file(
2552        vocabulary_file=vocabulary_file, num_oov_buckets=1000)
2553    ids = table.lookup(
2554        constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
2555
2556    if not context.executing_eagerly():
2557      with self.assertRaises(errors_impl.OpError):
2558        self.evaluate(ids)
2559    self.evaluate(lookup_ops.tables_initializer())
2560    self.assertAllEqual(
2561        (
2562            1,  # From vocabulary file.
2563            2,  # From vocabulary file.
2564            867,  # 3 + fingerprint("tarkus") mod 300.
2565            860),  # 3 + fingerprint("toccata") mod 300.
2566        self.evaluate(ids))
2567
2568  def test_index_table_from_file_fails_with_empty_vocabulary_file_name(self):
2569    self.assertRaises(
2570        ValueError, lookup_ops.index_table_from_file, vocabulary_file="")
2571
2572  def test_index_table_from_file_fails_with_empty_vocabulary(self):
2573    self.assertRaises(
2574        ValueError, lookup_ops.index_table_from_file, vocabulary_file=None)
2575
2576  def test_index_table_from_file_str_fails_with_zero_size_vocabulary(self):
2577    vocabulary_file = self._createVocabFile("zero_vocab_str.txt")
2578    self.assertRaisesRegex(
2579        ValueError, "`vocab_size` must be greater than 0, got 0 for "
2580        "vocabulary_file: .*zero_vocab_str.txt",
2581        lookup_ops.index_table_from_file,
2582        vocabulary_file=vocabulary_file,
2583        vocab_size=0)
2584
2585  def test_index_table_from_file_tensor_fails_with_zero_size_vocabulary(self):
2586    vocabulary_file = constant_op.constant(
2587        self._createVocabFile("zero_vocab_tensor.txt"))
2588    self.assertRaisesRegex(
2589        ValueError, "`vocab_size` must be greater than 0, got 0 for "
2590        "vocabulary_file: .*zero_vocab_tensor.txt",
2591        lookup_ops.index_table_from_file,
2592        vocabulary_file=vocabulary_file,
2593        vocab_size=0)
2594
2595  def test_index_table_from_file_with_vocab_size_too_small(self):
2596    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
2597    table = lookup_ops.index_table_from_file(
2598        vocabulary_file=vocabulary_file, vocab_size=2)
2599    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2600
2601    if not context.executing_eagerly():
2602      with self.assertRaises(errors_impl.OpError):
2603        self.evaluate(ids)
2604    self.evaluate(lookup_ops.tables_initializer())
2605    self.assertAllEqual((1, -1, -1), self.evaluate(ids))
2606    self.assertEqual(2, self.evaluate(table.size()))
2607
2608  def test_index_table_from_file_with_vocab_size_too_large(self):
2609    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
2610    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2611                                "Invalid vocab_size"):
2612      table = lookup_ops.index_table_from_file(
2613          vocabulary_file=vocabulary_file, vocab_size=4)
2614      self.evaluate(table.initializer)
2615
2616  def test_index_table_from_file_with_vocab_size(self):
2617    vocabulary_file = self._createVocabFile("f2i_vocab8.txt")
2618
2619    self.assertRaises(
2620        ValueError,
2621        lookup_ops.index_table_from_file,
2622        vocabulary_file=vocabulary_file,
2623        vocab_size=0)
2624
2625    table = lookup_ops.index_table_from_file(
2626        vocabulary_file=vocabulary_file, vocab_size=3)
2627    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2628
2629    if not context.executing_eagerly():
2630      with self.assertRaises(errors_impl.OpError):
2631        self.evaluate(ids)
2632    self.evaluate(lookup_ops.tables_initializer())
2633    self.assertAllEqual((1, 2, -1), self.evaluate(ids))
2634    self.assertEqual(3, self.evaluate(table.size()))
2635
2636  def test_index_table_from_file_with_invalid_hashers(self):
2637    vocabulary_file = self._createVocabFile("invalid_hasher.txt")
2638    with self.assertRaises(TypeError):
2639      lookup_ops.index_table_from_file(
2640          vocabulary_file=vocabulary_file,
2641          vocab_size=3,
2642          num_oov_buckets=1,
2643          hasher_spec=1)
2644
2645    table = lookup_ops.index_table_from_file(
2646        vocabulary_file=vocabulary_file,
2647        vocab_size=3,
2648        num_oov_buckets=1,
2649        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))
2650
2651    self.assertRaises(ValueError, table.lookup,
2652                      constant_op.constant(["salad", "surgery", "tarkus"]))
2653
2654  def test_index_table_from_file_table_ref_with_oov_buckets(self):
2655    vocabulary_file = self._createVocabFile("f2i_vocab9.txt")
2656    table = lookup_ops.index_table_from_file(
2657        vocabulary_file=vocabulary_file, num_oov_buckets=1)
2658    self.assertIsNotNone(table.resource_handle)
2659
2660  def test_index_table_from_file_table_ref_without_oov_buckets(self):
2661    vocabulary_file = self._createVocabFile("f2i_vocab10.txt")
2662    table = lookup_ops.index_table_from_file(
2663        vocabulary_file=vocabulary_file, num_oov_buckets=0)
2664    self.assertIsNotNone(table.resource_handle)
2665
2666
2667class IndexTableFromTensor(test.TestCase):
2668
2669  @test_util.run_in_graph_and_eager_modes
2670  def test_index_table_from_tensor_with_tensor_init(self):
2671    table = lookup_ops.index_table_from_tensor(
2672        vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
2673
2674    if not context.executing_eagerly():
2675      with self.assertRaises(errors_impl.OpError):
2676        self.evaluate(
2677            table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))))
2678    else:
2679      # Reinitializing a table in eager should work.
2680      table = lookup_ops.index_table_from_tensor(
2681          vocabulary_list=("brain", "salad", "surgery"), num_oov_buckets=1)
2682    self.evaluate(lookup_ops.tables_initializer())
2683    ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
2684    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2685
2686  def test_int32_index_table_from_tensor_with_tensor_init(self):
2687    table = lookup_ops.index_table_from_tensor(
2688        vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
2689    ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
2690
2691    if not context.executing_eagerly():
2692      with self.assertRaises(errors_impl.FailedPreconditionError):
2693        self.evaluate(ids)
2694    self.evaluate(lookup_ops.tables_initializer())
2695    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2696
2697  def test_int64_index_table_from_tensor_with_tensor_init(self):
2698    table = lookup_ops.index_table_from_tensor(
2699        vocabulary_list=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
2700    ids = table.lookup(constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
2701
2702    if not context.executing_eagerly():
2703      with self.assertRaises(errors_impl.FailedPreconditionError):
2704        self.evaluate(ids)
2705    self.evaluate(lookup_ops.tables_initializer())
2706    self.assertAllEqual((1, 2, 3), self.evaluate(ids))
2707
2708  def test_index_table_from_tensor_with_default_value(self):
2709    default_value = -42
2710    table = lookup_ops.index_table_from_tensor(
2711        vocabulary_list=["brain", "salad", "surgery"],
2712        default_value=default_value)
2713    ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
2714
2715    if not context.executing_eagerly():
2716      with self.assertRaises(errors_impl.FailedPreconditionError):
2717        self.evaluate(ids)
2718    self.evaluate(lookup_ops.tables_initializer())
2719    self.assertAllEqual((1, 2, default_value), self.evaluate(ids))
2720
2721  def test_index_table_from_tensor_missing_vocabulary_list(self):
2722    with self.assertRaisesRegex(ValueError,
2723                                "`vocabulary_list` must be specified"):
2724      lookup_ops.index_table_from_tensor(
2725          vocabulary_list=None, num_oov_buckets=1)
2726
2727  def test_index_table_from_tensor_empty_vocabulary_list(self):
2728    with self.assertRaisesRegex(errors_impl.OpError,
2729                                "keys and values cannot be empty"):
2730      _ = lookup_ops.index_table_from_tensor(
2731          vocabulary_list=np.array([], dtype=np.str_), num_oov_buckets=1)
2732      self.evaluate(lookup_ops.tables_initializer())
2733
2734  def test_index_table_from_tensor_with_invalid_hashers(self):
2735    with self.assertRaises(TypeError):
2736      lookup_ops.index_table_from_tensor(
2737          vocabulary_list=["brain", "salad", "surgery"],
2738          num_oov_buckets=1,
2739          hasher_spec=1)
2740
2741    table = lookup_ops.index_table_from_tensor(
2742        vocabulary_list=["brain", "salad", "surgery"],
2743        num_oov_buckets=1,
2744        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))
2745
2746    self.assertRaises(ValueError, table.lookup,
2747                      constant_op.constant(["salad", "surgery", "tarkus"]))
2748
2749
2750class IndexToStringTableFromFileTest(test.TestCase):
2751
2752  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
2753    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
2754    with open(vocabulary_file, "w") as f:
2755      f.write("\n".join(values) + "\n")
2756    return vocabulary_file
2757
2758  def test_index_to_string_table(self):
2759    vocabulary_path = self._createVocabFile("i2f_vocab1.txt")
2760    # vocabulary_file supports string and tensor
2761    type_funcs = [str, constant_op.constant]
2762    for type_func in type_funcs:
2763      vocabulary_file = type_func(vocabulary_path)
2764      table = lookup_ops.index_to_string_table_from_file(
2765          vocabulary_file=vocabulary_file)
2766      features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
2767      if not context.executing_eagerly():
2768        with self.assertRaises(errors_impl.OpError):
2769          self.evaluate(features)
2770      self.evaluate(lookup_ops.tables_initializer())
2771      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
2772                          self.evaluate(features))
2773
2774  def test_index_to_string_table_from_multicolumn_file(self):
2775    vocabulary_file = self._createVocabFile(
2776        "f2i_vocab1.txt", values=("brain\t300", "salad\t20", "surgery\t1"))
2777    table = lookup_ops.index_to_string_table_from_file(
2778        vocabulary_file=vocabulary_file,
2779        key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
2780        value_column_index=0)
2781    features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
2782    if not context.executing_eagerly():
2783      with self.assertRaises(errors_impl.OpError):
2784        self.evaluate(features)
2785    self.evaluate(lookup_ops.tables_initializer())
2786    self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
2787                        self.evaluate(features))
2788
2789  def test_index_to_string_table_from_multicolumn_file_custom_delimiter(self):
2790    vocabulary_file = self._createVocabFile(
2791        "f2i_vocab1.txt", values=("brain 300", "salad 20", "surgery 1"))
2792    table = lookup_ops.index_to_string_table_from_file(
2793        vocabulary_file=vocabulary_file,
2794        key_column_index=lookup_ops.TextFileIndex.LINE_NUMBER,
2795        value_column_index=0,
2796        delimiter=" ")
2797    features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
2798    if not context.executing_eagerly():
2799      with self.assertRaises(errors_impl.OpError):
2800        self.evaluate(features)
2801    self.evaluate(lookup_ops.tables_initializer())
2802    self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
2803                        self.evaluate(features))
2804
2805  def test_index_to_string_table_with_default_value(self):
2806    default_value = b"NONE"
2807    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
2808    table = lookup_ops.index_to_string_table_from_file(
2809        vocabulary_file=vocabulary_file, default_value=default_value)
2810    features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
2811    if not context.executing_eagerly():
2812      with self.assertRaises(errors_impl.OpError):
2813        self.evaluate(features)
2814    self.evaluate(lookup_ops.tables_initializer())
2815    self.assertAllEqual((b"salad", b"surgery", default_value),
2816                        self.evaluate(features))
2817
2818  def test_index_to_string_table_with_vocab_size_too_small(self):
2819    default_value = b"NONE"
2820    vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
2821    table = lookup_ops.index_to_string_table_from_file(
2822        vocabulary_file=vocabulary_file,
2823        vocab_size=2,
2824        default_value=default_value)
2825    features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
2826    if not context.executing_eagerly():
2827      with self.assertRaises(errors_impl.OpError):
2828        self.evaluate(features)
2829    self.evaluate(lookup_ops.tables_initializer())
2830    self.assertAllEqual((b"salad", default_value, default_value),
2831                        self.evaluate(features))
2832
2833  def test_index_to_string_table_with_vocab_size_too_large(self):
2834    vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
2835    with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
2836                                "Invalid vocab_size"):
2837      _ = lookup_ops.index_to_string_table_from_file(
2838          vocabulary_file=vocabulary_file, vocab_size=4)
2839      self.evaluate(lookup_ops.tables_initializer())
2840
2841  def test_index_to_string_table_with_vocab_size(self):
2842    vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
2843    table = lookup_ops.index_to_string_table_from_file(
2844        vocabulary_file=vocabulary_file, vocab_size=3)
2845    features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
2846
2847    if not context.executing_eagerly():
2848      with self.assertRaises(errors_impl.OpError):
2849        self.evaluate(features)
2850    self.evaluate(lookup_ops.tables_initializer())
2851    self.assertAllEqual((b"salad", b"surgery", b"UNK"), self.evaluate(features))
2852
2853
2854class IndexToStringTableFromTensorTest(test.TestCase):
2855
2856  def test_index_to_string_table_from_tensor(self):
2857    vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
2858    table = lookup_ops.index_to_string_table_from_tensor(
2859        vocabulary_list=vocabulary_list)
2860
2861    indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
2862    features = table.lookup(indices)
2863    if not context.executing_eagerly():
2864      with self.assertRaises(errors_impl.OpError):
2865        self.evaluate(features)
2866    self.evaluate(lookup_ops.tables_initializer())
2867
2868    self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
2869                        self.evaluate(features))
2870
2871  def test_duplicate_entries(self):
2872    vocabulary_list = constant_op.constant(["hello", "hello"])
2873    table = lookup_ops.index_to_string_table_from_tensor(
2874        vocabulary_list=vocabulary_list)
2875    indices = constant_op.constant([0, 1, 4], dtypes.int64)
2876    features = table.lookup(indices)
2877    self.evaluate(lookup_ops.tables_initializer())
2878    self.assertAllEqual((b"hello", b"hello", b"UNK"), self.evaluate(features))
2879
2880  def test_index_to_string_with_default_value(self):
2881    default_value = b"NONE"
2882    vocabulary_list = constant_op.constant(["brain", "salad", "surgery"])
2883    table = lookup_ops.index_to_string_table_from_tensor(
2884        vocabulary_list=vocabulary_list, default_value=default_value)
2885    indices = constant_op.constant([1, 2, 4], dtypes.int64)
2886    features = table.lookup(indices)
2887    if not context.executing_eagerly():
2888      with self.assertRaises(errors_impl.OpError):
2889        self.evaluate(features)
2890    self.evaluate(lookup_ops.tables_initializer())
2891    self.assertAllEqual((b"salad", b"surgery", default_value),
2892                        self.evaluate(features))
2893
2894
2895class IdTableWithHashBucketsTest(test.TestCase):
2896
2897  def _createVocabFile(self, basename, values=("brain", "salad", "surgery")):
2898    vocabulary_file = os.path.join(self.get_temp_dir(), basename)
2899    with open(vocabulary_file, "w") as f:
2900      f.write("\n".join(values) + "\n")
2901    return vocabulary_file
2902
2903  def testStringIdTableWithHashBuckets(self):
2904    vocab_file = self._createVocabFile("feat_to_id_1.txt")
2905    default_value = -1
2906    vocab_size = 3
2907    oov_buckets = 1
2908    table = lookup_ops.IdTableWithHashBuckets(
2909        lookup_ops.StaticHashTable(
2910            lookup_ops.TextFileIdTableInitializer(
2911                vocab_file, vocab_size=vocab_size), default_value),
2912        oov_buckets)
2913
2914    self.evaluate(table.initializer)
2915
2916    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
2917
2918    out = table.lookup(input_string)
2919    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
2920    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
2921
2922  def testInt32IdTableWithHashBuckets(self):
2923    vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
2924    default_value = -1
2925    vocab_size = 3
2926    oov_buckets = 1
2927    table = lookup_ops.IdTableWithHashBuckets(
2928        lookup_ops.StaticHashTable(
2929            lookup_ops.TextFileIdTableInitializer(
2930                vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
2931            default_value),
2932        oov_buckets,
2933        key_dtype=dtypes.int32)
2934
2935    self.evaluate(table.initializer)
2936
2937    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int32)
2938
2939    out = table.lookup(values)
2940    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
2941    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
2942
2943  def testInt64IdTableWithHashBuckets(self):
2944    vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
2945    default_value = -1
2946    vocab_size = 3
2947    oov_buckets = 1
2948    table = lookup_ops.IdTableWithHashBuckets(
2949        lookup_ops.StaticHashTable(
2950            lookup_ops.TextFileIdTableInitializer(
2951                vocab_file, vocab_size=vocab_size, key_dtype=dtypes.int64),
2952            default_value), oov_buckets)
2953
2954    self.evaluate(table.initializer)
2955
2956    values = constant_op.constant((42, 1, -1000, 11), dtype=dtypes.int64)
2957
2958    out = table.lookup(values)
2959    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out))
2960    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table.size()))
2961
2962  def testStringIdTableWithOnlyHashBucket(self):
2963    oov_buckets = 5
2964
2965    # Set a table that only uses hash buckets, for each input value returns
2966    # an id calculated by fingerprint("input") mod oov_buckets.
2967    table = lookup_ops.IdTableWithHashBuckets(None, oov_buckets)
2968    self.evaluate(table.initializer)
2969
2970    values = constant_op.constant(("brain", "salad", "surgery"))
2971
2972    out = table.lookup(values)
2973    self.assertAllEqual(
2974        [
2975            3,  # fingerprint("brain") mod 5.
2976            1,  # fingerprint("salad") mod 5.
2977            4  # fingerprint("surgery") mod 5
2978        ],
2979        self.evaluate(out))
2980    self.assertEqual(oov_buckets, self.evaluate(table.size()))
2981
2982  def testInt32IdTableWithOnlyHashBucket(self):
2983    oov_buckets = 5
2984
2985    # Set a table that only uses hash buckets, for each input value returns
2986    # an id calculated by fingerprint("input") mod oov_buckets.
2987    table = lookup_ops.IdTableWithHashBuckets(
2988        None, oov_buckets, key_dtype=dtypes.int32)
2989    self.evaluate(table.initializer)
2990
2991    input_string = constant_op.constant([42, 1, -1000], dtype=dtypes.int32)
2992
2993    out = table.lookup(input_string)
2994    self.assertAllEqual(
2995        [
2996            1,  # fingerprint("42") mod 5.
2997            4,  # fingerprint("1") mod 5.
2998            2  # fingerprint("-1000") mod 5
2999        ],
3000        self.evaluate(out))
3001    self.assertEqual(oov_buckets, self.evaluate(table.size()))
3002
3003  def testFloat64IdTableWithOnlyHashBucket(self):
3004    with self.assertRaisesRegex(TypeError, "Invalid `key_dtype`"):
3005      lookup_ops.IdTableWithHashBuckets(
3006          None, num_oov_buckets=5, key_dtype=dtypes.float64)
3007
3008  def testBoolIdTableWithOnlyHashBucket(self):
3009    with self.assertRaisesRegex(TypeError, "Invalid `key_dtype`"):
3010      lookup_ops.IdTableWithHashBuckets(
3011          None, num_oov_buckets=5, key_dtype=dtypes.bool)
3012
3013  def testIdTableWithHashBucketsWithMultipleInitializers(self):
3014    vocab_file = self._createVocabFile("feat_to_id_4.txt")
3015    default_value = -1
3016    vocab_size = 3
3017    oov_buckets = 3
3018
3019    vocab_table = lookup_ops.StaticHashTable(
3020        lookup_ops.TextFileIdTableInitializer(
3021            vocab_file, vocab_size=vocab_size), default_value)
3022    table1 = lookup_ops.IdTableWithHashBuckets(
3023        vocab_table,
3024        oov_buckets,
3025        hasher_spec=lookup_ops.FastHashSpec,
3026        name="table1")
3027
3028    table2 = lookup_ops.IdTableWithHashBuckets(
3029        vocab_table,
3030        oov_buckets,
3031        hasher_spec=lookup_ops.StrongHashSpec((1, 2)),
3032        name="table2")
3033
3034    self.evaluate(lookup_ops.tables_initializer())
3035
3036    input_string = constant_op.constant(
3037        ["fruit", "brain", "salad", "surgery", "UNK"])
3038
3039    out1 = table1.lookup(input_string)
3040    out2 = table2.lookup(input_string)
3041
3042    out1, out2 = self.evaluate([out1, out2])
3043    self.assertAllEqual([5, 0, 1, 2, 5], out1)
3044    self.assertAllEqual([5, 0, 1, 2, 3], out2)
3045    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
3046    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
3047    if not context.executing_eagerly():
3048      test_util.assert_ops_in_graph({
3049          "table1_Lookup/hash_bucket": "StringToHashBucketFast",
3050          "table2_Lookup/hash_bucket": "StringToHashBucketStrong",
3051      }, ops.get_default_graph())
3052
3053  def testIdTableWithHashBucketsInitializationAcrossSessions(self):
3054    vocab_file = self._createVocabFile("feat_to_id_5.txt")
3055    default_value = -1
3056    vocab_size = 3
3057    oov_buckets = 1
3058    table1 = lookup_ops.IdTableWithHashBuckets(
3059        lookup_ops.StaticHashTable(
3060            lookup_ops.TextFileIdTableInitializer(
3061                vocab_file, vocab_size=vocab_size), default_value), oov_buckets)
3062
3063    self.evaluate(table1.initializer)
3064
3065    input_string_1 = constant_op.constant(["brain", "salad", "surgery", "UNK"])
3066
3067    out1 = table1.lookup(input_string_1)
3068
3069    self.assertAllEqual([0, 1, 2, 3], self.evaluate(out1))
3070    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
3071
3072    default_value = -1
3073    vocab_size = 3
3074    oov_buckets = 1
3075
3076    # Underlying lookup table already initialized in previous session.
3077    # No need to call self.evaluate(table2.initializer)
3078    table2 = lookup_ops.IdTableWithHashBuckets(
3079        lookup_ops.StaticHashTable(
3080            lookup_ops.TextFileIdTableInitializer(
3081                vocab_file, vocab_size=vocab_size), default_value), oov_buckets)
3082
3083    input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
3084
3085    out2 = table2.lookup(input_string_2)
3086
3087    self.assertAllEqual([3, 1, 3], self.evaluate(out2))
3088    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
3089
3090  def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
3091    vocab_file = self._createVocabFile("feat_to_id_6.txt")
3092    default_value1 = -1
3093    vocab_size = 3
3094    oov_buckets = 0
3095    table1 = lookup_ops.IdTableWithHashBuckets(
3096        lookup_ops.StaticHashTable(
3097            lookup_ops.TextFileIdTableInitializer(
3098                vocab_file, vocab_size=vocab_size), default_value1),
3099        oov_buckets)
3100
3101    default_value2 = -2
3102    table2 = lookup_ops.IdTableWithHashBuckets(
3103        lookup_ops.StaticHashTable(
3104            lookup_ops.TextFileIdTableInitializer(
3105                vocab_file, vocab_size=vocab_size), default_value2),
3106        oov_buckets)
3107
3108    self.evaluate(lookup_ops.tables_initializer())
3109
3110    input_string_1 = constant_op.constant(
3111        ["brain", "salad", "surgery", "UNK"])
3112    input_string_2 = constant_op.constant(["fruit", "salad", "UNK"])
3113
3114    out1 = table1.lookup(input_string_1)
3115    out2 = table2.lookup(input_string_2)
3116
3117    out1, out2 = self.evaluate([out1, out2])
3118    self.assertAllEqual([0, 1, 2, -1], out1)
3119    self.assertAllEqual([-2, 1, -2], out2)
3120    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table1.size()))
3121    self.assertEqual(vocab_size + oov_buckets, self.evaluate(table2.size()))
3122
3123  def testSparseTensor(self):
3124    vocab_file = self._createVocabFile("feat_to_id_7.txt")
3125    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
3126    input_shape = [4, 4]
3127    sp_features = sparse_tensor.SparseTensor(
3128        constant_op.constant(input_indices, dtypes.int64),
3129        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
3130                             dtypes.string),
3131        constant_op.constant(input_shape, dtypes.int64))
3132
3133    table = lookup_ops.IdTableWithHashBuckets(
3134        lookup_ops.StaticHashTable(
3135            lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
3136            -1), 1)
3137    self.evaluate(table.initializer)
3138
3139    sp_ids = table.lookup(sp_features)
3140
3141    self.assertAllEqual([5], sp_ids.values._shape_as_list())
3142
3143    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
3144        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
3145
3146    self.assertAllEqual(input_indices, sp_ids_ind)
3147    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
3148    self.assertAllEqual(input_shape, sp_ids_shape)
3149
3150  def testRaggedTensor(self):
3151    vocab_file = self._createVocabFile("feat_to_id_7.txt")
3152    input_row_splits = [0, 2, 4, 5]
3153    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
3154        constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
3155                             dtypes.string),
3156        constant_op.constant(input_row_splits, dtypes.int64))
3157
3158    table = lookup_ops.IdTableWithHashBuckets(
3159        lookup_ops.StaticHashTable(
3160            lookup_ops.TextFileIdTableInitializer(vocab_file, vocab_size=3),
3161            -1), 1)
3162    self.evaluate(table.initializer)
3163
3164    ragged_ids = table.lookup(ragged_features)
3165    self.assertAllEqual([5], ragged_ids.values._shape_as_list())
3166
3167    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
3168        [ragged_ids.values, ragged_ids.row_splits])
3169
3170    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
3171    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)
3172
3173  def testInt32SparseTensor(self):
3174    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
3175    input_shape = [4, 4]
3176    sp_features = sparse_tensor.SparseTensor(
3177        constant_op.constant(input_indices, dtypes.int64),
3178        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
3179        constant_op.constant(input_shape, dtypes.int64))
3180
3181    table = lookup_ops.IdTableWithHashBuckets(
3182        lookup_ops.StaticHashTable(
3183            lookup_ops.KeyValueTensorInitializer(
3184                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
3185        1,
3186        key_dtype=dtypes.int32)
3187    self.evaluate(table.initializer)
3188
3189    sp_ids = table.lookup(sp_features)
3190
3191    self.assertAllEqual([5], sp_ids.values._shape_as_list())
3192
3193    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
3194        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
3195
3196    self.assertAllEqual(input_indices, sp_ids_ind)
3197    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
3198    self.assertAllEqual(input_shape, sp_ids_shape)
3199
3200  def testInt32RaggedTensor(self):
3201    input_row_splits = [0, 2, 4, 5]
3202    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
3203        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
3204        constant_op.constant(input_row_splits, dtypes.int32))
3205
3206    table = lookup_ops.IdTableWithHashBuckets(
3207        lookup_ops.StaticHashTable(
3208            lookup_ops.KeyValueTensorInitializer(
3209                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
3210        1,
3211        key_dtype=dtypes.int32)
3212    self.evaluate(table.initializer)
3213
3214    ragged_ids = table.lookup(ragged_features)
3215
3216    self.assertAllEqual([5], ragged_ids.values._shape_as_list())
3217
3218    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
3219        [ragged_ids.values, ragged_ids.row_splits])
3220
3221    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
3222    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)
3223
3224  def testInt64SparseTensor(self):
3225    input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
3226    input_shape = [4, 4]
3227    sp_features = sparse_tensor.SparseTensor(
3228        constant_op.constant(input_indices, dtypes.int64),
3229        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
3230        constant_op.constant(input_shape, dtypes.int64))
3231
3232    table = lookup_ops.IdTableWithHashBuckets(
3233        lookup_ops.StaticHashTable(
3234            lookup_ops.KeyValueTensorInitializer(
3235                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
3236        1,
3237        key_dtype=dtypes.int64)
3238    self.evaluate(table.initializer)
3239
3240    sp_ids = table.lookup(sp_features)
3241
3242    self.assertAllEqual([5], sp_ids.values._shape_as_list())
3243
3244    sp_ids_ind, sp_ids_val, sp_ids_shape = self.evaluate(
3245        [sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
3246
3247    self.assertAllEqual(input_indices, sp_ids_ind)
3248    self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
3249    self.assertAllEqual(input_shape, sp_ids_shape)
3250
3251  def testInt64RaggedTensor(self):
3252    input_row_splits = [0, 2, 4, 5]
3253    ragged_features = ragged_tensor.RaggedTensor.from_row_splits(
3254        constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
3255        constant_op.constant(input_row_splits, dtypes.int64))
3256
3257    table = lookup_ops.IdTableWithHashBuckets(
3258        lookup_ops.StaticHashTable(
3259            lookup_ops.KeyValueTensorInitializer(
3260                (42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64), -1),
3261        1,
3262        key_dtype=dtypes.int64)
3263    self.evaluate(table.initializer)
3264
3265    ragged_ids = table.lookup(ragged_features)
3266
3267    self.assertAllEqual([5], ragged_ids.values._shape_as_list())
3268
3269    ragged_ids_val, ragged_ids_row_splits = self.evaluate(
3270        [ragged_ids.values, ragged_ids.row_splits])
3271
3272    self.assertAllEqual([0, 1, 0, 2, 3], ragged_ids_val)
3273    self.assertAllEqual(input_row_splits, ragged_ids_row_splits)
3274
3275  def testIdTableWithHashBucketsWithInvalidHashers(self):
3276    vocab_file = self._createVocabFile("feat_to_id_4.txt")
3277    default_value = -1
3278    vocab_size = 3
3279    oov_buckets = 1
3280    lookup_table = lookup_ops.StaticHashTable(
3281        lookup_ops.TextFileIdTableInitializer(
3282            vocab_file, vocab_size=vocab_size), default_value)
3283
3284    with self.assertRaises(TypeError):
3285      lookup_ops.IdTableWithHashBuckets(
3286          lookup_table, oov_buckets, hasher_spec=1)
3287
3288    table = lookup_ops.IdTableWithHashBuckets(
3289        lookup_table,
3290        oov_buckets,
3291        hasher_spec=lookup_ops.HasherSpec("my-awesome-hash", None))
3292
3293    input_string = constant_op.constant(["brain", "salad", "surgery", "UNK"])
3294
3295    with self.assertRaises(ValueError):
3296      table.lookup(input_string)
3297
3298    with self.assertRaises(ValueError):
3299      table = lookup_ops.IdTableWithHashBuckets(
3300          lookup_table, oov_buckets, hasher_spec=lookup_ops.StrongHashSpec([]))
3301
3302    with self.assertRaises(ValueError):
3303      table = lookup_ops.IdTableWithHashBuckets(
3304          lookup_table,
3305          oov_buckets,
3306          hasher_spec=lookup_ops.StrongHashSpec([1, 2, 3]))
3307
3308    with self.assertRaises(TypeError):
3309      table = lookup_ops.IdTableWithHashBuckets(
3310          lookup_table,
3311          oov_buckets,
3312          hasher_spec=lookup_ops.StrongHashSpec([None, 2]))
3313
3314  def testIdTableWithHashBucketsNoInnerTable(self):
3315    table = lookup_ops.IdTableWithHashBuckets(None, num_oov_buckets=1)
3316    self.assertIsNone(table.resource_handle)
3317
3318
3319@parameterized.named_parameters(
3320    (f"_{is_anonymous}", is_anonymous) for is_anonymous in [False, True])
3321class MutableHashTableOpTest(test.TestCase):
3322
3323  def testMutableHashTable(self, is_anonymous):
3324    if is_anonymous and not tf2.enabled():
3325      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3326    default_val = -1
3327    keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
3328    values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
3329    table = lookup_ops.MutableHashTable(
3330        dtypes.string,
3331        dtypes.int64,
3332        default_val,
3333        experimental_is_anonymous=is_anonymous)
3334    self.assertAllEqual(0, self.evaluate(table.size()))
3335
3336    self.evaluate(table.insert(keys, values))
3337    self.assertAllEqual(4, self.evaluate(table.size()))
3338
3339    remove_string = constant_op.constant(["tarkus", "tank"])
3340    self.evaluate(table.remove(remove_string))
3341    self.assertAllEqual(3, self.evaluate(table.size()))
3342
3343    input_string = constant_op.constant(["brain", "salad", "tank"])
3344    output = table.lookup(input_string)
3345    self.assertAllEqual([3], output.get_shape())
3346
3347    result = self.evaluate(output)
3348    self.assertAllEqual([0, 1, -1], result)
3349
3350    exported_keys, exported_values = table.export()
3351
3352    # exported data is in the order of the internal map, i.e. undefined
3353    sorted_keys = np.sort(self.evaluate(exported_keys))
3354    sorted_values = np.sort(self.evaluate(exported_values))
3355    self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
3356    self.assertAllEqual([0, 1, 2], sorted_values)
3357
3358  # TODO(https://github.com/tensorflow/tensorflow/issues/24439): remove exepectedFailure when fixed
3359  @unittest.expectedFailure
3360  @test_util.run_v2_only
3361  def testImportedHashTable(self, is_anonymous):
3362    g = ops.Graph()
3363    with g.as_default():
3364      default_val = -1
3365      keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
3366      values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
3367      table = lookup_ops.MutableHashTable(
3368          dtypes.string,
3369          dtypes.int64,
3370          default_val,
3371          experimental_is_anonymous=is_anonymous)
3372      self.evaluate(table.insert(keys, values))
3373      op = table.lookup(constant_op.constant(["brain", "salad", "tank"]))
3374      meta_graph = saver.export_meta_graph()
3375
3376    def f():
3377      saver.import_meta_graph(meta_graph)
3378      return ops.get_default_graph().get_tensor_by_name(op.name)
3379
3380    wrapped = wrap_function.wrap_function(f, [])
3381    self.assertAllEqual([0, 1, -1], wrapped())
3382
3383  @test_util.run_v1_only("SaverV1")
3384  def testSaveRestore(self, is_anonymous):
3385    if is_anonymous and not tf2.enabled():
3386      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3387    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
3388    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
3389
3390    with self.session(graph=ops.Graph()) as sess:
3391      v0 = variables.Variable(10.0, name="v0")
3392      v1 = variables.Variable(20.0, name="v1")
3393
3394      default_val = -1
3395      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
3396      values = constant_op.constant([0, 1, 2], dtypes.int64)
3397      table = lookup_ops.MutableHashTable(
3398          dtypes.string,
3399          dtypes.int64,
3400          default_val,
3401          name="t1",
3402          checkpoint=True,
3403          experimental_is_anonymous=is_anonymous)
3404
3405      save = saver.Saver()
3406      self.evaluate(variables.global_variables_initializer())
3407
3408      # Check that the parameter nodes have been initialized.
3409      self.assertEqual(10.0, self.evaluate(v0))
3410      self.assertEqual(20.0, self.evaluate(v1))
3411
3412      self.assertAllEqual(0, self.evaluate(table.size()))
3413      self.evaluate(table.insert(keys, values))
3414      self.assertAllEqual(3, self.evaluate(table.size()))
3415
3416      val = save.save(sess, save_path)
3417      self.assertIsInstance(val, str)
3418      self.assertEqual(save_path, val)
3419
3420    with self.session(graph=ops.Graph()) as sess:
3421      v0 = variables.Variable(-1.0, name="v0")
3422      v1 = variables.Variable(-1.0, name="v1")
3423      default_val = -1
3424      table = lookup_ops.MutableHashTable(
3425          dtypes.string,
3426          dtypes.int64,
3427          default_val,
3428          name="t1",
3429          checkpoint=True,
3430          experimental_is_anonymous=is_anonymous)
3431      self.evaluate(
3432          table.insert(
3433              constant_op.constant(["a", "c"], dtypes.string),
3434              constant_op.constant([12, 24], dtypes.int64)))
3435      self.assertAllEqual(2, self.evaluate(table.size()))
3436
3437      save = saver.Saver()
3438
3439      # Restore the saved values in the parameter nodes.
3440      save.restore(sess, save_path)
3441      # Check that the parameter nodes have been restored.
3442      self.assertEqual(10.0, self.evaluate(v0))
3443      self.assertEqual(20.0, self.evaluate(v1))
3444
3445      self.assertAllEqual(3, self.evaluate(table.size()))
3446
3447      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
3448                                          dtypes.string)
3449      output = table.lookup(input_string)
3450      self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
3451
3452  @test_util.run_v1_only("SaverV1")
3453  def testSaveRestoreOnlyTable(self, is_anonymous):
3454    if is_anonymous and not tf2.enabled():
3455      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3456    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
3457    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
3458
3459    with self.session(graph=ops.Graph()) as sess:
3460      v0 = variables.Variable(10.0, name="v0")
3461      v1 = variables.Variable(20.0, name="v1")
3462
3463      default_val = -1
3464      keys = constant_op.constant(["b", "c", "d"], dtypes.string)
3465      values = constant_op.constant([0, 1, 2], dtypes.int64)
3466      table = lookup_ops.MutableHashTable(
3467          dtypes.string,
3468          dtypes.int64,
3469          default_val,
3470          name="t1",
3471          checkpoint=True,
3472          experimental_is_anonymous=is_anonymous)
3473
3474      save = saver.Saver([table])
3475      self.evaluate(variables.global_variables_initializer())
3476
3477      # Check that the parameter nodes have been initialized.
3478      self.assertEqual(10.0, self.evaluate(v0))
3479      self.assertEqual(20.0, self.evaluate(v1))
3480
3481      self.assertAllEqual(0, self.evaluate(table.size()))
3482      self.evaluate(table.insert(keys, values))
3483      self.assertAllEqual(3, self.evaluate(table.size()))
3484
3485      val = save.save(sess, save_path)
3486      self.assertIsInstance(val, str)
3487      self.assertEqual(save_path, val)
3488
3489    with self.session(graph=ops.Graph()) as sess:
3490      default_val = -1
3491      table = lookup_ops.MutableHashTable(
3492          dtypes.string,
3493          dtypes.int64,
3494          default_val,
3495          name="t1",
3496          checkpoint=True,
3497          experimental_is_anonymous=is_anonymous)
3498      self.evaluate(
3499          table.insert(
3500              constant_op.constant(["a", "c"], dtypes.string),
3501              constant_op.constant([12, 24], dtypes.int64)))
3502      self.assertAllEqual(2, self.evaluate(table.size()))
3503
3504      save = saver.Saver([table])
3505
3506      # Restore the saved values in the parameter nodes.
3507      save.restore(sess, save_path)
3508      # Check that the parameter nodes have been restored.
3509
3510      self.assertAllEqual(3, self.evaluate(table.size()))
3511
3512      input_string = constant_op.constant(["a", "b", "c", "d", "e"],
3513                                          dtypes.string)
3514      output = table.lookup(input_string)
3515      self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
3516
3517  @test_util.run_in_graph_and_eager_modes
3518  def testObjectSaveRestore(self, is_anonymous):
3519    if is_anonymous and not context.executing_eagerly():
3520      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3521    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
3522    save_prefix = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
3523
3524    v0 = variables.Variable(10.0, name="v0")
3525    v1 = variables.Variable(20.0, name="v1")
3526
3527    default_val = -1
3528    keys = constant_op.constant(["b", "c", "d"], dtypes.string)
3529    values = constant_op.constant([0, 1, 2], dtypes.int64)
3530    table = lookup_ops.MutableHashTable(
3531        dtypes.string,
3532        dtypes.int64,
3533        default_val,
3534        name="t1",
3535        checkpoint=True,
3536        experimental_is_anonymous=is_anonymous)
3537
3538    checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
3539    self.evaluate([v0.initializer, v1.initializer])
3540
3541    # Check that the parameter nodes have been initialized.
3542    self.assertEqual(10.0, self.evaluate(v0))
3543    self.assertEqual(20.0, self.evaluate(v1))
3544
3545    self.assertAllEqual(0, self.evaluate(table.size()))
3546    self.evaluate(table.insert(keys, values))
3547    self.assertAllEqual(3, self.evaluate(table.size()))
3548
3549    save_path = checkpoint.save(save_prefix)
3550    del table, checkpoint, v0, v1
3551
3552    v0 = variables.Variable(-1.0, name="v0")
3553    v1 = variables.Variable(-1.0, name="v1")
3554    default_val = -1
3555    table = lookup_ops.MutableHashTable(
3556        dtypes.string,
3557        dtypes.int64,
3558        default_val,
3559        name="t1",
3560        checkpoint=True,
3561        experimental_is_anonymous=is_anonymous)
3562    self.evaluate(
3563        table.insert(
3564            constant_op.constant(["a", "c"], dtypes.string),
3565            constant_op.constant([12, 24], dtypes.int64)))
3566    self.assertAllEqual(2, self.evaluate(table.size()))
3567
3568    checkpoint = trackable.Checkpoint(table=table, v0=v0, v1=v1)
3569
3570    # Restore the saved values in the parameter nodes.
3571    checkpoint.restore(save_path).run_restore_ops()
3572    # Check that the parameter nodes have been restored.
3573    self.assertEqual(10.0, self.evaluate(v0))
3574    self.assertEqual(20.0, self.evaluate(v1))
3575
3576    self.assertAllEqual(3, self.evaluate(table.size()))
3577
3578    input_string = constant_op.constant(["a", "b", "c", "d", "e"],
3579                                        dtypes.string)
3580    output = table.lookup(input_string)
3581    self.assertAllEqual([-1, 0, 1, 2, -1], self.evaluate(output))
3582
3583  @test_util.run_v2_only
3584  def testSavedModelSaveRestore(self, is_anonymous):
3585    save_dir = os.path.join(self.get_temp_dir(), "save_restore")
3586    save_path = os.path.join(tempfile.mkdtemp(prefix=save_dir), "hash")
3587
3588    root = autotrackable.AutoTrackable()
3589
3590    default_value = -1
3591    keys = constant_op.constant([11, 12, 13], dtypes.int64)
3592    values = constant_op.constant([0, 1, 2], dtypes.int64)
3593    root.table = lookup_ops.MutableHashTable(
3594        dtypes.int64,
3595        dtypes.int64,
3596        default_value,
3597        experimental_is_anonymous=is_anonymous)
3598
3599    @def_function.function(
3600        input_signature=[tensor_spec.TensorSpec((), dtypes.int64)])
3601    def lookup(key):
3602      return root.table.lookup(key)
3603
3604    @def_function.function(input_signature=[])
3605    def size():
3606      return root.table.size()
3607
3608    @def_function.function(input_signature=[])
3609    def is_ref_counting():
3610      return test_ops.is_resource_handle_ref_counting(
3611          root.table.resource_handle)
3612
3613    root.lookup = lookup
3614    root.size = size
3615    root.is_ref_counting = is_ref_counting
3616
3617    self.assertEqual(root.table.size(), 0)
3618    root.table.insert(keys, values)
3619    self.assertEqual(root.table.size(), 3)
3620    self.assertEqual(root.table.lookup(12), 1)
3621    self.assertEqual(root.table.lookup(10), -1)
3622    self.assertEqual(len(root.table.export()[0]), 3)
3623    self.assertEqual(root.is_ref_counting(), is_anonymous)
3624
3625    saved_model_save.save(root, save_path)
3626
3627    del root
3628    loaded = saved_model_load.load(save_path)
3629    self.assertEqual(loaded.size(), 3)
3630    self.assertEqual(loaded.lookup(12), 1)
3631    self.assertEqual(loaded.lookup(10), -1)
3632    self.assertEqual(loaded.is_ref_counting(), is_anonymous)
3633
3634  @test_util.run_v1_only("Multiple sessions")
3635  def testSharing(self, is_anonymous):
3636    if is_anonymous and not tf2.enabled():
3637      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3638    # Start a server to store the table state
3639    server = server_lib.Server({"local0": ["localhost:0"]},
3640                               protocol="grpc",
3641                               start=True)
3642    # Create two sessions sharing the same state
3643    session1 = session.Session(server.target)
3644    session2 = session.Session(server.target)
3645
3646    table = lookup_ops.MutableHashTable(
3647        dtypes.int64,
3648        dtypes.string,
3649        "-",
3650        name="t1",
3651        experimental_is_anonymous=is_anonymous)
3652
3653    # Populate the table in the first session
3654    with session1:
3655      self.assertAllEqual(0, table.size())
3656
3657      keys = constant_op.constant([11, 12], dtypes.int64)
3658      values = constant_op.constant(["a", "b"])
3659      table.insert(keys, values).run()
3660      self.assertAllEqual(2, table.size())
3661
3662      output = table.lookup(constant_op.constant([11, 12, 13], dtypes.int64))
3663      self.assertAllEqual([b"a", b"b", b"-"], output)
3664
3665    # Verify that we can access the shared data from the second session
3666    with session2:
3667      self.assertAllEqual(2, table.size())
3668
3669      output = table.lookup(constant_op.constant([10, 11, 12], dtypes.int64))
3670      self.assertAllEqual([b"-", b"a", b"b"], output)
3671
3672  def testMutableHashTableOfTensors(self, is_anonymous):
3673    if is_anonymous and not tf2.enabled():
3674      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3675    default_val = constant_op.constant([-1, -1], dtypes.int64)
3676    keys = constant_op.constant(["brain", "salad", "surgery", "tarkus"])
3677    values = constant_op.constant([[0, 1], [2, 3], [4, 5], [6, 7]],
3678                                  dtypes.int64)
3679    table = lookup_ops.MutableHashTable(
3680        dtypes.string,
3681        dtypes.int64,
3682        default_val,
3683        experimental_is_anonymous=is_anonymous)
3684    self.assertAllEqual(0, self.evaluate(table.size()))
3685
3686    self.evaluate(table.insert(keys, values))
3687    self.assertAllEqual(4, self.evaluate(table.size()))
3688
3689    remove_string = constant_op.constant(["tarkus", "tank"])
3690    self.evaluate(table.remove(remove_string))
3691    self.assertAllEqual(3, self.evaluate(table.size()))
3692
3693    input_string = constant_op.constant(["brain", "salad", "tank"])
3694    output = table.lookup(input_string)
3695    self.assertAllEqual([3, 2], output.get_shape())
3696
3697    result = self.evaluate(output)
3698    self.assertAllEqual([[0, 1], [2, 3], [-1, -1]], result)
3699
3700    exported_keys, exported_values = table.export()
3701    # exported data is in the order of the internal map, i.e. undefined
3702    sorted_keys = np.sort(self.evaluate(exported_keys))
3703    sorted_values = np.sort(self.evaluate(exported_values), axis=0)
3704    self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
3705    sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0)
3706    self.assertAllEqual(sorted_expected_values, sorted_values)
3707
3708  def testMutableHashTableExportInsert(self, is_anonymous):
3709    if is_anonymous and not tf2.enabled():
3710      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3711    default_val = constant_op.constant([-1, -1], dtypes.int64)
3712    keys = constant_op.constant(["brain", "salad", "surgery"])
3713    values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
3714    table1 = lookup_ops.MutableHashTable(
3715        dtypes.string,
3716        dtypes.int64,
3717        default_val,
3718        experimental_is_anonymous=is_anonymous)
3719    self.assertAllEqual(0, self.evaluate(table1.size()))
3720    self.evaluate(table1.insert(keys, values))
3721    self.assertAllEqual(3, self.evaluate(table1.size()))
3722
3723    input_string = constant_op.constant(["brain", "salad", "tank"])
3724    expected_output = [[0, 1], [2, 3], [-1, -1]]
3725    output1 = table1.lookup(input_string)
3726    self.assertAllEqual(expected_output, self.evaluate(output1))
3727
3728    exported_keys, exported_values = table1.export()
3729    self.assertAllEqual(3, self.evaluate(exported_keys).size)
3730    self.assertAllEqual(6, self.evaluate(exported_values).size)
3731
3732    # Populate a second table from the exported data
3733    table2 = lookup_ops.MutableHashTable(
3734        dtypes.string,
3735        dtypes.int64,
3736        default_val,
3737        experimental_is_anonymous=is_anonymous)
3738    self.assertAllEqual(0, self.evaluate(table2.size()))
3739    self.evaluate(table2.insert(exported_keys, exported_values))
3740    self.assertAllEqual(3, self.evaluate(table2.size()))
3741
3742    # Verify lookup result is still the same
3743    output2 = table2.lookup(input_string)
3744    self.assertAllEqual(expected_output, self.evaluate(output2))
3745
3746  def testMutableHashTableOfTensorsInvalidShape(self, is_anonymous):
3747    if is_anonymous and not tf2.enabled():
3748      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3749    default_val = constant_op.constant([-1, -1], dtypes.int64)
3750    keys = constant_op.constant(["brain", "salad", "surgery"])
3751    table = lookup_ops.MutableHashTable(
3752        dtypes.string,
3753        dtypes.int64,
3754        default_val,
3755        experimental_is_anonymous=is_anonymous)
3756
3757    # Shape [6] instead of [3, 2]
3758    values = constant_op.constant([0, 1, 2, 3, 4, 5], dtypes.int64)
3759    with self.assertRaisesOpError("Expected shape"):
3760      self.evaluate(table.insert(keys, values))
3761
3762    # Shape [2,3] instead of [3, 2]
3763    values = constant_op.constant([[0, 1, 2], [3, 4, 5]], dtypes.int64)
3764    with self.assertRaisesOpError("Expected shape"):
3765      self.evaluate(table.insert(keys, values))
3766
3767    # Shape [2, 2] instead of [3, 2]
3768    values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
3769    with self.assertRaisesOpError("Expected shape"):
3770      self.evaluate(table.insert(keys, values))
3771
3772    # Shape [3, 1] instead of [3, 2]
3773    values = constant_op.constant([[0], [2], [4]], dtypes.int64)
3774    with self.assertRaisesOpError("Expected shape"):
3775      self.evaluate(table.insert(keys, values))
3776
3777    # Valid Insert
3778    values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
3779    self.evaluate(table.insert(keys, values))
3780    self.assertAllEqual(3, self.evaluate(table.size()))
3781
3782  def testMutableHashTableInvalidDefaultValue(self, is_anonymous):
3783    default_val = constant_op.constant([[-1, -1]], dtypes.int64)
3784    with self.assertRaisesOpError("Default value must be a vector"):
3785      table = lookup_ops.MutableHashTable(
3786          dtypes.string,
3787          dtypes.int64,
3788          default_val,
3789          experimental_is_anonymous=is_anonymous)
3790      self.assertAllEqual(0, self.evaluate(table.size()))
3791
3792  def testMutableHashTableDuplicateInsert(self, is_anonymous):
3793    if is_anonymous and not tf2.enabled():
3794      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3795    default_val = -1
3796    keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
3797    values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
3798    table = lookup_ops.MutableHashTable(
3799        dtypes.string,
3800        dtypes.int64,
3801        default_val,
3802        experimental_is_anonymous=is_anonymous)
3803    self.assertAllEqual(0, self.evaluate(table.size()))
3804
3805    self.evaluate(table.insert(keys, values))
3806    self.assertAllEqual(3, self.evaluate(table.size()))
3807
3808    input_string = constant_op.constant(["brain", "salad", "tank"])
3809    output = table.lookup(input_string)
3810
3811    result = self.evaluate(output)
3812    self.assertAllEqual([3, 1, -1], result)
3813
3814  def testMutableHashTableFindHighRank(self, is_anonymous):
3815    if is_anonymous and not tf2.enabled():
3816      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3817    default_val = -1
3818    keys = constant_op.constant(["brain", "salad", "surgery"])
3819    values = constant_op.constant([0, 1, 2], dtypes.int64)
3820    table = lookup_ops.MutableHashTable(
3821        dtypes.string,
3822        dtypes.int64,
3823        default_val,
3824        experimental_is_anonymous=is_anonymous)
3825
3826    self.evaluate(table.insert(keys, values))
3827    self.assertAllEqual(3, self.evaluate(table.size()))
3828
3829    input_string = constant_op.constant([["brain", "salad"],
3830                                         ["tank", "tarkus"]])
3831    output = table.lookup(input_string)
3832    self.assertAllEqual([2, 2], output.get_shape())
3833
3834    result = self.evaluate(output)
3835    self.assertAllEqual([[0, 1], [-1, -1]], result)
3836
3837  def testMutableHashTableFindWithInvalidShapeDefaultValue(self, is_anonymous):
3838    default_val = [-1, -1]
3839    table = lookup_ops.MutableHashTable(
3840        dtypes.string,
3841        dtypes.int64,
3842        default_val,
3843        experimental_is_anonymous=is_anonymous)
3844
3845    input_string = constant_op.constant([["brain", "salad"], ["tank",
3846                                                              "tarkus"]])
3847
3848    invalid_default_val = constant_op.constant(
3849        [[-2, -3], [-4, -5], [-6, -7], [-8, -9]], dtypes.int64)
3850
3851    with self.assertRaisesRegex(
3852        (ValueError, errors_impl.InvalidArgumentError),
3853        "Expected shape \[2\] or \[2,2,2\] for default value, got \[4,2]"):
3854      self.evaluate(table.lookup(input_string, invalid_default_val))
3855
3856    invalid_default_val = constant_op.constant([[[-2, -3], [-4, -5]]],
3857                                               dtypes.int64)
3858    with self.assertRaisesRegex(
3859        (ValueError, errors_impl.InvalidArgumentError),
3860        "Expected shape \[2\] or \[2,2,2\] for default value, got \[1,2,2\]"):
3861      self.evaluate(table.lookup(input_string, invalid_default_val))
3862
3863  def testMutableHashTableFindHighRankScalarWithDynamicDefaultValue(
3864      self, is_anonymous):
3865    if is_anonymous and not tf2.enabled():
3866      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3867    default_val = -1
3868    keys = constant_op.constant(["brain", "salad", "surgery"])
3869    values = constant_op.constant([0, 1, 2], dtypes.int64)
3870    table = lookup_ops.MutableHashTable(
3871        dtypes.string,
3872        dtypes.int64,
3873        default_val,
3874        experimental_is_anonymous=is_anonymous)
3875
3876    self.evaluate(table.insert(keys, values))
3877    self.assertAllEqual(3, self.evaluate(table.size()))
3878
3879    input_string = constant_op.constant([["brain", "salad"], ["tank",
3880                                                              "tarkus"]])
3881
3882    dynamic_default_val = constant_op.constant([[-2, -3], [-4, -5]],
3883                                               dtypes.int64)
3884    output = table.lookup(input_string, dynamic_default_val)
3885    self.assertAllEqual([2, 2], output.get_shape())
3886
3887    result = self.evaluate(output)
3888    self.assertAllEqual([[0, 1], [-4, -5]], result)
3889
3890  def testMutableHashTableFindHighRankVectorWithDynamicDefaultValue(
3891      self, is_anonymous):
3892    if is_anonymous and not tf2.enabled():
3893      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3894    default_val = [-1, -1]
3895    keys = constant_op.constant(["brain", "salad", "surgery"])
3896    values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
3897    table = lookup_ops.MutableHashTable(
3898        dtypes.string,
3899        dtypes.int64,
3900        default_val,
3901        experimental_is_anonymous=is_anonymous)
3902
3903    self.evaluate(table.insert(keys, values))
3904    self.assertAllEqual(3, self.evaluate(table.size()))
3905
3906    input_string = constant_op.constant([["brain", "salad"], ["tank",
3907                                                              "tarkus"]])
3908
3909    dynamic_default_val = constant_op.constant(
3910        [[[-2, -3], [-4, -5]], [[-6, -7], [-8, -9]]], dtypes.int64)
3911    output = table.lookup(input_string, dynamic_default_val)
3912    self.assertAllEqual([2, 2, 2], output.get_shape())
3913
3914    result = self.evaluate(output)
3915    self.assertAllEqual([[[0, 1], [2, 3]], [[-6, -7], [-8, -9]]], result)
3916
3917  def testMutableHashTableInsertHighRank(self, is_anonymous):
3918    if is_anonymous and not tf2.enabled():
3919      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3920    default_val = -1
3921    keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
3922    values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
3923    table = lookup_ops.MutableHashTable(
3924        dtypes.string,
3925        dtypes.int64,
3926        default_val,
3927        experimental_is_anonymous=is_anonymous)
3928
3929    self.evaluate(table.insert(keys, values))
3930    self.assertAllEqual(4, self.evaluate(table.size()))
3931
3932    input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
3933    output = table.lookup(input_string)
3934
3935    result = self.evaluate(output)
3936    self.assertAllEqual([0, 1, 3, -1], result)
3937
3938  def testMutableHashTableRemoveHighRank(self, is_anonymous):
3939    if is_anonymous and not tf2.enabled():
3940      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3941    default_val = -1
3942    keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
3943    values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
3944    table = lookup_ops.MutableHashTable(
3945        dtypes.string,
3946        dtypes.int64,
3947        default_val,
3948        experimental_is_anonymous=is_anonymous)
3949
3950    self.evaluate(table.insert(keys, values))
3951    self.assertAllEqual(4, self.evaluate(table.size()))
3952
3953    remove_string = constant_op.constant(["salad", "tarkus"])
3954    self.evaluate(table.remove(remove_string))
3955    self.assertAllEqual(3, self.evaluate(table.size()))
3956
3957    input_string = constant_op.constant(["brain", "salad", "tank", "tarkus"])
3958    output = table.lookup(input_string)
3959
3960    result = self.evaluate(output)
3961    self.assertAllEqual([0, -1, 3, -1], result)
3962
3963  def testMutableHashTableOfTensorsFindHighRank(self, is_anonymous):
3964    if is_anonymous and not tf2.enabled():
3965      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3966    default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
3967    keys = constant_op.constant(["brain", "salad", "surgery"])
3968    values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
3969                                  dtypes.int64)
3970    table = lookup_ops.MutableHashTable(
3971        dtypes.string,
3972        dtypes.int64,
3973        default_val,
3974        experimental_is_anonymous=is_anonymous)
3975
3976    self.evaluate(table.insert(keys, values))
3977    self.assertAllEqual(3, self.evaluate(table.size()))
3978
3979    input_string = constant_op.constant([["brain", "salad"],
3980                                         ["tank", "tarkus"]])
3981    output = table.lookup(input_string)
3982    self.assertAllEqual([2, 2, 3], output.get_shape())
3983
3984    result = self.evaluate(output)
3985    self.assertAllEqual(
3986        [[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
3987
3988  def testMutableHashTableOfTensorsRemoveHighRank(self, is_anonymous):
3989    if is_anonymous and not tf2.enabled():
3990      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
3991    default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
3992    keys = constant_op.constant(["brain", "salad", "surgery"])
3993    values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
3994                                  dtypes.int64)
3995    table = lookup_ops.MutableHashTable(
3996        dtypes.string,
3997        dtypes.int64,
3998        default_val,
3999        experimental_is_anonymous=is_anonymous)
4000
4001    self.evaluate(table.insert(keys, values))
4002    self.assertAllEqual(3, self.evaluate(table.size()))
4003
4004    remove_string = constant_op.constant([["brain", "tank"]])
4005    self.evaluate(table.remove(remove_string))
4006    self.assertAllEqual(2, self.evaluate(table.size()))
4007
4008    input_string = constant_op.constant([["brain", "salad"],
4009                                         ["surgery", "tank"]])
4010    output = table.lookup(input_string)
4011    self.assertAllEqual([2, 2, 3], output.get_shape())
4012
4013    result = self.evaluate(output)
4014    self.assertAllEqual(
4015        [[[-1, -1, -1], [2, 3, 4]], [[4, 5, 6], [-1, -1, -1]]], result)
4016
4017  def testMultipleMutableHashTables(self, is_anonymous):
4018    if is_anonymous and not tf2.enabled():
4019      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
4020    default_val = -1
4021    keys = constant_op.constant(["brain", "salad", "surgery"])
4022    values = constant_op.constant([0, 1, 2], dtypes.int64)
4023
4024    table1 = lookup_ops.MutableHashTable(
4025        dtypes.string,
4026        dtypes.int64,
4027        default_val,
4028        experimental_is_anonymous=is_anonymous)
4029    table2 = lookup_ops.MutableHashTable(
4030        dtypes.string,
4031        dtypes.int64,
4032        default_val,
4033        experimental_is_anonymous=is_anonymous)
4034    table3 = lookup_ops.MutableHashTable(
4035        dtypes.string,
4036        dtypes.int64,
4037        default_val,
4038        experimental_is_anonymous=is_anonymous)
4039    self.evaluate(table1.insert(keys, values))
4040    self.evaluate(table2.insert(keys, values))
4041    self.evaluate(table3.insert(keys, values))
4042
4043    self.assertAllEqual(3, self.evaluate(table1.size()))
4044    self.assertAllEqual(3, self.evaluate(table2.size()))
4045    self.assertAllEqual(3, self.evaluate(table3.size()))
4046
4047    input_string = constant_op.constant(["brain", "salad", "tank"])
4048    output1 = table1.lookup(input_string)
4049    output2 = table2.lookup(input_string)
4050    output3 = table3.lookup(input_string)
4051
4052    out1, out2, out3 = self.evaluate([output1, output2, output3])
4053    self.assertAllEqual([0, 1, -1], out1)
4054    self.assertAllEqual([0, 1, -1], out2)
4055    self.assertAllEqual([0, 1, -1], out3)
4056
4057  def testMutableHashTableWithTensorDefault(self, is_anonymous):
4058    if is_anonymous and not tf2.enabled():
4059      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
4060    default_val = constant_op.constant(-1, dtypes.int64)
4061    keys = constant_op.constant(["brain", "salad", "surgery"])
4062    values = constant_op.constant([0, 1, 2], dtypes.int64)
4063    table = lookup_ops.MutableHashTable(
4064        dtypes.string,
4065        dtypes.int64,
4066        default_val,
4067        experimental_is_anonymous=is_anonymous)
4068
4069    self.evaluate(table.insert(keys, values))
4070    self.assertAllEqual(3, self.evaluate(table.size()))
4071
4072    input_string = constant_op.constant(["brain", "salad", "tank"])
4073    output = table.lookup(input_string)
4074
4075    result = self.evaluate(output)
4076    self.assertAllEqual([0, 1, -1], result)
4077
4078  def testSignatureMismatch(self, is_anonymous):
4079    if is_anonymous and not tf2.enabled():
4080      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
4081    default_val = -1
4082    keys = constant_op.constant(["brain", "salad", "surgery"])
4083    values = constant_op.constant([0, 1, 2], dtypes.int64)
4084    table = lookup_ops.MutableHashTable(
4085        dtypes.string,
4086        dtypes.int64,
4087        default_val,
4088        experimental_is_anonymous=is_anonymous)
4089
4090    # insert with keys of the wrong type
4091    with self.assertRaises(ValueError):
4092      self.evaluate(table.insert(constant_op.constant([4, 5, 6]), values))
4093
4094    # insert with values of the wrong type
4095    with self.assertRaises(ValueError):
4096      self.evaluate(table.insert(keys, constant_op.constant(["a", "b", "c"])))
4097
4098    self.assertAllEqual(0, self.evaluate(table.size()))
4099
4100    self.evaluate(table.insert(keys, values))
4101    self.assertAllEqual(3, self.evaluate(table.size()))
4102
4103    input_string_ref = variables.Variable("brain")
4104    input_int64_ref = variables.Variable(-1, dtype=dtypes.int64)
4105    self.evaluate(variables.global_variables_initializer())
4106
4107    # Ref types do not produce an insert signature mismatch.
4108    self.evaluate(table.insert(input_string_ref, input_int64_ref))
4109    self.assertAllEqual(3, self.evaluate(table.size()))
4110
4111    # Ref types do not produce a lookup signature mismatch.
4112    self.assertEqual(-1, self.evaluate(table.lookup(input_string_ref)))
4113
4114    # lookup with keys of the wrong type
4115    input_string = constant_op.constant([1, 2, 3], dtypes.int64)
4116    with self.assertRaises(ValueError):
4117      self.evaluate(table.lookup(input_string))
4118
4119    # default value of the wrong type
4120    with self.assertRaises(TypeError):
4121      lookup_ops.MutableHashTable(
4122          dtypes.string,
4123          dtypes.int64,
4124          "UNK",
4125          experimental_is_anonymous=is_anonymous)
4126
4127  def testMutableHashTableStringFloat(self, is_anonymous):
4128    if is_anonymous and not tf2.enabled():
4129      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
4130    default_val = -1.5
4131    keys = constant_op.constant(["brain", "salad", "surgery"])
4132    values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
4133    table = lookup_ops.MutableHashTable(
4134        dtypes.string,
4135        dtypes.float32,
4136        default_val,
4137        experimental_is_anonymous=is_anonymous)
4138    self.assertAllEqual(0, self.evaluate(table.size()))
4139
4140    self.evaluate(table.insert(keys, values))
4141    self.assertAllEqual(3, self.evaluate(table.size()))
4142
4143    input_string = constant_op.constant(["brain", "salad", "tank"])
4144    output = table.lookup(input_string)
4145
4146    result = self.evaluate(output)
4147    self.assertAllClose([0, 1.1, default_val], result)
4148
4149  def testMutableHashTableIntFloat(self, is_anonymous):
4150    if is_anonymous and not tf2.enabled():
4151      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
4152    default_val = -1.0
4153    keys = constant_op.constant([3, 7, 0], dtypes.int64)
4154    values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
4155    table = lookup_ops.MutableHashTable(
4156        dtypes.int64,
4157        dtypes.float32,
4158        default_val,
4159        experimental_is_anonymous=is_anonymous)
4160    self.assertAllEqual(0, self.evaluate(table.size()))
4161
4162    self.evaluate(table.insert(keys, values))
4163    self.assertAllEqual(3, self.evaluate(table.size()))
4164
4165    input_string = constant_op.constant([7, 0, 11], dtypes.int64)
4166    output = table.lookup(input_string)
4167
4168    result = self.evaluate(output)
4169    self.assertAllClose([-1.2, 9.9, default_val], result)
4170
4171  def testMutableHashTableInt64String(self, is_anonymous):
4172    if is_anonymous and not tf2.enabled():
4173      self.skipTest(SKIP_ANONYMOUS_IN_TF1_REASON)
4174    default_val = "n/a"
4175    keys = constant_op.constant([0, 1, 2], dtypes.int64)
4176    values = constant_op.constant(["brain", "salad", "surgery"])
4177    table = lookup_ops.MutableHashTable(
4178        dtypes.int64,
4179        dtypes.string,
4180        default_val,
4181        experimental_is_anonymous=is_anonymous)
4182    self.assertAllEqual(0, self.evaluate(table.size()))
4183
4184    self.evaluate(table.insert(keys, values))
4185    self.assertAllEqual(3, self.evaluate(table.size()))
4186
4187    input_string = constant_op.constant([0, 1, 3], dtypes.int64)
4188    output = table.lookup(input_string)
4189
4190    result = self.evaluate(output)
4191    self.assertAllEqual((b"brain", b"salad", b"n/a"), result)
4192
4193  def testExportShapeInference(self, is_anonymous):
4194    default_value = -1
4195    table = lookup_ops.MutableHashTable(
4196        dtypes.int64,
4197        dtypes.int64,
4198        default_value=default_value,
4199        experimental_is_anonymous=is_anonymous)
4200    actual_shapes = [t.shape for t in table.export()]
4201    inferred_shapes = []
4202
4203    @def_function.function
4204    def f():
4205      for t in table.export():
4206        inferred_shapes.append(t.shape)
4207
4208    f()
4209    self.assertLen(actual_shapes, 2)
4210    self.assertLen(inferred_shapes, 2)
4211    self.assertTrue(inferred_shapes[0].is_compatible_with(actual_shapes[0]))
4212    self.assertTrue(inferred_shapes[1].is_compatible_with(actual_shapes[1]))
4213
4214
4215class MutableHashTableBenchmark(test.Benchmark):
4216
4217  def _create_table(self):
4218    return lookup_ops.MutableHashTable(dtypes.int64, dtypes.float32, 0.0)
4219
4220  def benchmark_single_repeated_scalar_insert_scalar(self):
4221    table = self._create_table()
4222    value = variables.Variable(1.0)
4223    insert = table.insert(0, value)
4224    size = table.size()
4225    with session.Session() as sess:
4226      sess.run(value.initializer)
4227      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
4228      assert sess.run(size) == 1
4229
4230  def benchmark_many_repeated_scalar_insert_scalar(self):
4231    table = self._create_table()
4232    c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next()
4233    value = variables.Variable(1.0)
4234    insert = table.insert(c, value)
4235    size = table.size()
4236    with session.Session() as sess:
4237      sess.run(value.initializer)
4238      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=10000)
4239      assert sess.run(size) >= 10000
4240
4241  def benchmark_single_repeated_batch_32_insert_scalar(self):
4242    table = self._create_table()
4243    value = variables.Variable([1.0] * 32)
4244    insert = table.insert(list(range(32)), value)
4245    size = table.size()
4246    with session.Session() as sess:
4247      sess.run(value.initializer)
4248      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
4249      assert sess.run(size) == 32
4250
4251  def benchmark_many_repeated_batch_32_insert_scalar(self):
4252    table = self._create_table()
4253    c = dataset_ops.make_one_shot_iterator(counter.Counter()).get_next()
4254    value = variables.Variable([1.0] * 32)
4255    insert = table.insert(32 * c + list(range(32)), value)
4256    size = table.size()
4257    with session.Session() as sess:
4258      sess.run(value.initializer)
4259      self.run_op_benchmark(sess, insert, burn_iters=10, min_iters=1000)
4260      assert sess.run(size) >= 1000 * 32
4261
4262
4263class DenseHashTableBenchmark(MutableHashTableBenchmark):
4264
4265  def _create_table(self):
4266    return lookup_ops.DenseHashTable(
4267        dtypes.int64,
4268        dtypes.float32,
4269        default_value=0.0,
4270        empty_key=-1,
4271        deleted_key=-2)
4272
4273
4274if __name__ == "__main__":
4275  test.main()
4276