xref: /aosp_15_r20/external/federated-compute/fcp/tensorflow/dictionary_ops.py (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker# Copyright 2022 Google LLC
2*14675a02SAndroid Build Coastguard Worker#
3*14675a02SAndroid Build Coastguard Worker# Licensed under the Apache License, Version 2.0 (the "License");
4*14675a02SAndroid Build Coastguard Worker# you may not use this file except in compliance with the License.
5*14675a02SAndroid Build Coastguard Worker# You may obtain a copy of the License at
6*14675a02SAndroid Build Coastguard Worker#
7*14675a02SAndroid Build Coastguard Worker#      http://www.apache.org/licenses/LICENSE-2.0
8*14675a02SAndroid Build Coastguard Worker#
9*14675a02SAndroid Build Coastguard Worker# Unless required by applicable law or agreed to in writing, software
10*14675a02SAndroid Build Coastguard Worker# distributed under the License is distributed on an "AS IS" BASIS,
11*14675a02SAndroid Build Coastguard Worker# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*14675a02SAndroid Build Coastguard Worker# See the License for the specific language governing permissions and
13*14675a02SAndroid Build Coastguard Worker# limitations under the License.
14*14675a02SAndroid Build Coastguard Worker"""Python and TensorFlow functions to work with dictionaries.
15*14675a02SAndroid Build Coastguard Worker
16*14675a02SAndroid Build Coastguard WorkerPlease see fcp/dictionary/dictionary.h for more on this type of
17*14675a02SAndroid Build Coastguard Workerdictionary.
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard WorkerPython Classes:
20*14675a02SAndroid Build Coastguard Worker
21*14675a02SAndroid Build Coastguard Worker* `Dictionary`: A Python analogue to fcp/dictionary/dictionary.h
22*14675a02SAndroid Build Coastguard Worker    that includes additional helpers for dictionary construction.
23*14675a02SAndroid Build Coastguard Worker
24*14675a02SAndroid Build Coastguard WorkerTensorFlow ops:
25*14675a02SAndroid Build Coastguard Worker
26*14675a02SAndroid Build Coastguard Worker*  dictionary_size
27*14675a02SAndroid Build Coastguard Worker   Queries the size of a dictionary.
28*14675a02SAndroid Build Coastguard Worker
29*14675a02SAndroid Build Coastguard Worker*  dictionary_lookup
30*14675a02SAndroid Build Coastguard Worker   Looks up ids for string tokens in the dictionary.
31*14675a02SAndroid Build Coastguard Worker
32*14675a02SAndroid Build Coastguard Worker*  dictionary_reverse_lookup
33*14675a02SAndroid Build Coastguard Worker   Looks up string tokens from ids in the dictionary.
34*14675a02SAndroid Build Coastguard Worker
35*14675a02SAndroid Build Coastguard WorkerCanonical use (note that the dictionary is known at graph construction time):
36*14675a02SAndroid Build Coastguard Worker    dictionary = Dictionary.from_tokens(
37*14675a02SAndroid Build Coastguard Worker       tokens=['some', 'token', 'list'], unk_id=0,
38*14675a02SAndroid Build Coastguard Worker       vocabulary_type=VocabularyType.TOKEN_INDEX)
39*14675a02SAndroid Build Coastguard Worker
40*14675a02SAndroid Build Coastguard Worker    with tf.Graph().as_default():
41*14675a02SAndroid Build Coastguard Worker      tokens = tf.compat.v1.placeholder(tf.string, ...)  # Tokens to look up.
42*14675a02SAndroid Build Coastguard Worker      ids = dictionary_lookup(
43*14675a02SAndroid Build Coastguard Worker          tokens, dictionary.dictionary_description_proto)
44*14675a02SAndroid Build Coastguard Worker"""
45*14675a02SAndroid Build Coastguard Worker
46*14675a02SAndroid Build Coastguard Workerimport collections
47*14675a02SAndroid Build Coastguard Workerimport enum
48*14675a02SAndroid Build Coastguard Worker
49*14675a02SAndroid Build Coastguard Workerimport tensorflow as tf
50*14675a02SAndroid Build Coastguard Worker
51*14675a02SAndroid Build Coastguard Workerfrom fcp.dictionary.dictionary_pb2 import DictionaryDescription  # pylint: disable=g-importing-member
52*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow.gen_dictionary_ops import dictionary_lookup
53*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow.gen_dictionary_ops import dictionary_reverse_lookup
54*14675a02SAndroid Build Coastguard Workerfrom fcp.tensorflow.gen_dictionary_ops import dictionary_size
55*14675a02SAndroid Build Coastguard Worker
# Load the shared library that ships alongside this module. Presumably it
# registers the kernels for the Dictionary* ops wrapped by gen_dictionary_ops
# (TODO confirm). The handle is kept as the private `_dictionary_ops`.
_dictionary_ops = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile(
        './_dictionary_ops.so'))
58*14675a02SAndroid Build Coastguard Worker
59*14675a02SAndroid Build Coastguard Worker
def ignore_ids_mask(token_ids, ignore_ids, name=None):
  """Creates a bool mask with True everywhere token_ids is not in ignore_ids.

  Args:
    token_ids: An integer `Tensor` (or value convertible to one) of ids.
    ignore_ids: A scalar or 1-D collection of ids to mask out. It is converted
      to a `Tensor` with the same dtype as `token_ids`.
    name: Optional name scope for the created ops.

  Returns:
    A bool `Tensor` with the same shape as `token_ids` that is False exactly
    where the corresponding id is one of `ignore_ids`.
  """
  # `tf.op_scope` was removed from TensorFlow; `name_scope(name, default_name,
  # values)` is its direct replacement.
  with tf.compat.v1.name_scope(name, 'ignore_ids_mask',
                               [token_ids, ignore_ids]):
    token_ids = tf.convert_to_tensor(token_ids)
    # Match dtypes so that e.g. a Python tuple of ints can be compared against
    # an int64 id tensor.
    ignore_ids = tf.convert_to_tensor(ignore_ids, dtype=token_ids.dtype)
    # Broadcasting: [..., 1] vs [num_ignore] -> [..., num_ignore], comparing
    # every id against every ignored id.
    all_check = tf.not_equal(tf.expand_dims(token_ids, -1), ignore_ids)
    # An id is kept only if it differs from every ignored id. `axis=-1`
    # replaces the removed `reduction_indices` argument.
    check = tf.reduce_all(all_check, axis=-1)
    check.set_shape(token_ids.get_shape())
    return check
68*14675a02SAndroid Build Coastguard Worker
69*14675a02SAndroid Build Coastguard Worker
def mask_and_replace_padding(token_ids,
                             lengths,
                             eos_id=None,
                             special_tokens=(),
                             name=None):
  """Creates a mask of valid tokens and sets padded values in id space.

  This creates a mask the same shape as token_ids with a boolean indicating
  if the id was a valid token (i.e. not padding or a special token). If
  provided, this also remaps tokens after lengths to the eos_id. Since the
  dictionary doesn't map tokens to eos or bos ids, it would generally be the
  unknown token id, which is not correct if you need to predict the eos.

  Args:
    token_ids: A matrix `Tensor` of integer ids.
    lengths: A vector `Tensor` of lengths for each row in token_ids (int32 or
      int64).
    eos_id: The end of sequence id. If provided, all token ids after length
      in a row will be replaced with `eos_id`.
    special_tokens: An iterable of special token ids that are not considered
      valid. `None` is treated like an empty iterable.
    name: Name scope for these ops.

  Returns:
    token_ids: `token_ids` with all tokens after a row's length replaced with
      eos if provided.
    mask: A bool `Tensor` the same shape as `token_ids` indicating which tokens
      are valid.
  """
  # `tf.op_scope` was removed from TensorFlow; use its replacement. Only the
  # tensor-like inputs are passed as scope values.
  with tf.compat.v1.name_scope(name, 'mask_and_replace_padding',
                               [token_ids, lengths]):
    token_ids = tf.convert_to_tensor(token_ids)
    # Materialize once so generators work with both the truthiness test and
    # the tensor conversion inside `ignore_ids_mask`.
    special_tokens = () if special_tokens is None else tuple(special_tokens)

    # selected[i, j] is True iff j < lengths[i], i.e. position j of row i lies
    # within the row's length. Unlike `tf.range(...) < lengths`,
    # `sequence_mask` also accepts int64 lengths.
    selected = tf.sequence_mask(lengths, maxlen=tf.shape(token_ids)[1])

    if eos_id is not None:
      eos_fill = tf.fill(
          tf.shape(token_ids), tf.constant(eos_id, dtype=token_ids.dtype))
      token_ids = tf.where(selected, token_ids, eos_fill)
    if special_tokens:
      mask = tf.logical_and(
          ignore_ids_mask(token_ids, special_tokens), selected)
    else:
      mask = selected
    return token_ids, mask
116*14675a02SAndroid Build Coastguard Worker
# Register the custom dictionary ops as not differentiable. Their outputs are
# sizes, integer ids and string tokens.
for _op_type in ('DictionarySize', 'DictionaryLookup',
                 'DictionaryReverseLookup'):
  tf.no_gradient(_op_type)
del _op_type  # Keep the loop variable out of the module namespace.
120*14675a02SAndroid Build Coastguard Worker
121*14675a02SAndroid Build Coastguard Worker
class VocabularyType(enum.Enum):
  """Vocabulary encodings that `Dictionary` construction can produce.

  Attributes:
    TOKEN_INDEX: The dictionary description embeds the token strings in
      order. Ids are assigned consecutively, starting at the lowest id not
      taken by a special token. Order-preserving, but not compact.
  """
  # NOTE(review): the numeric value presumably has to agree with something
  # on the fcp/dictionary side (e.g. a proto field number). Confirm before
  # renumbering.
  TOKEN_INDEX = 3
130*14675a02SAndroid Build Coastguard Worker
131*14675a02SAndroid Build Coastguard Worker
class Dictionary(object):
  """Utility for working with fcp/dictionary/ via TensorFlow.

  Wraps a `DictionaryDescription` proto. The Python-side queries (`lookup`,
  `reverse_lookup`, `size`, ...) run the dictionary ops in a private graph
  that is built on first use.
  """

  def __init__(
      self,
      dictionary_description
  ):
    """Creates a dictionary from a dictionary_description.

    Use the `from_*` class methods to build dictionaries from common data
    types.

    Args:
      dictionary_description: A `dictionary_pb2.DictionaryDescription`
        describing the dictionary. It is stored by reference, not copied.

    Raises:
      ValueError: If `dictionary_description` is not a
        `DictionaryDescription`, or it has no vocabulary.
    """
    if not isinstance(dictionary_description, DictionaryDescription):
      raise ValueError('Expected a DictionaryDescription')
    if not dictionary_description.HasField('vocabulary'):
      raise ValueError('dictionary_description has no vocabulary')

    self._dictionary_description = dictionary_description

    # Lazily constructed by `_get_lookup_graph` on first use.
    self._lookup_graph = None
    self._lookup_placeholder = None
    self._lookup_result = None
    self._reverse_lookup_placeholder = None
    self._reverse_lookup_result = None
    self._size_result = None

  @classmethod
  def from_tokens(
      cls,
      tokens,
      bos_id=None,
      eos_id=None,
      unk_id=None,
      output_blocklist_tokens=None,
      output_size=None,
      vocabulary_type=VocabularyType.TOKEN_INDEX
  ):
    """Creates a dictionary from a provided list of tokens.

    The id mappings to token ids depend on the vocabulary_type requested. For
    TOKEN_INDEX, tokens receive consecutive ids in iteration order, starting
    right after the special ids.

    NB: the special tokens must be the first ids [0, num-specials).

    Args:
      tokens: An iterable of distinct `str`/`bytes` tokens, consumed once.
        Ids follow iteration order. A `list` is required when `output_size`
        is used.
      bos_id: Token id for start of sequence.
      eos_id: Token id for end of sequence.
      unk_id: Token id for unknown words.
      output_blocklist_tokens: A list of vocabulary tokens that should be
        filtered from predictions (e.g., punctuation, bad words etc.).
      output_size: If a positive integer, tokens with ids greater than or
        equal to this are automatically added to the output blocklist.
      vocabulary_type: `VocabularyType` to use, defaults to TOKEN_INDEX.

    Returns:
       A `Dictionary` instance.

    Raises:
      ValueError: If the special tokens don't have the lowest ids.
      ValueError: If a token is not `str`/`bytes`, or there are duplicates in
        tokens.
      AssertionError: If `vocabulary_type` is unsupported, or if
        `output_size` / `output_blocklist_tokens` are inconsistent with the
        special ids or `tokens`.
    """
    # `output_size` slices tokens by position, which is only meaningful for
    # a caller-ordered list. Record this before materializing the input.
    tokens_is_list = isinstance(tokens, list)
    tokens = list(tokens)

    dictionary_description = DictionaryDescription()

    # Special ids: together they must be exactly {0, ..., num_specials - 1}.
    special_ids = []
    if unk_id is not None:
      dictionary_description.special_ids.unk = unk_id
      special_ids.append(unk_id)
    if bos_id is not None:
      dictionary_description.special_ids.bos = bos_id
      special_ids.append(bos_id)
    if eos_id is not None:
      dictionary_description.special_ids.eos = eos_id
      special_ids.append(eos_id)
    if sorted(special_ids) != list(range(len(special_ids))):
      raise ValueError(
          'Special ids must be the first items of the dictionary starting at 0 '
          'or None. eos: %s; bos %s; unk: %s' % (eos_id, bos_id, unk_id))

    # Vocabulary. Types are checked before building a set, so unhashable
    # entries are reported as the documented ValueError.
    for token in tokens:
      if not isinstance(token, (str, bytes)):
        # `(token,)` keeps %-formatting safe for tuple-valued tokens.
        raise ValueError('Bad type in tokens %s' % (token,))
    token_set = set(tokens)
    if len(tokens) != len(token_set):
      raise ValueError('Duplicate tokens provided')
    if vocabulary_type == VocabularyType.TOKEN_INDEX:
      dictionary_description.vocabulary.index.token.extend(tokens)
    else:
      raise AssertionError(
          'Unsupported vocabulary_type: %s' % (vocabulary_type,))

    # Output blocklist. Checks raise explicitly (instead of `assert`) so they
    # still run under `python -O`.
    output_blocklist_tokens = list(output_blocklist_tokens or [])
    if output_size:
      if output_size < len(special_ids):
        raise AssertionError('Cannot blocklist special tokens via output_size.')
      if not tokens_is_list:
        # Make sure order preserving pre-slice.
        raise AssertionError('output_size requires `tokens` to be a list.')
      # The token at list index i has id len(special_ids) + i, so this slice
      # is exactly the tokens with ids >= output_size.
      output_blocklist_tokens.extend(tokens[output_size - len(special_ids):])
    for token in output_blocklist_tokens:
      if token not in token_set:
        raise AssertionError("Unexpected blocklist token: '%s'" % (token,))
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      output_blocklist_ids = sess.run(
          dictionary_lookup(output_blocklist_tokens,
                            dictionary_description.SerializeToString()))
    dictionary_description.output_blocklist_ids.id.extend(
        sorted(output_blocklist_ids))
    # Duplicate blocklist tokens, or tokens that resolve to a shared id,
    # collapse in this set and are reported.
    if (len(set(dictionary_description.output_blocklist_ids.id)) !=
        len(output_blocklist_tokens)):
      raise AssertionError('blocklist contains dups or unks?')

    # Return completed dictionary.
    return cls(
        dictionary_description=dictionary_description)

  @classmethod
  def from_dictionary_description(cls,
                                  dictionary_description):
    """Returns a Dictionary from a DictionaryDescription."""
    return cls(
        dictionary_description=dictionary_description)

  def _get_lookup_graph(self):
    """Returns a graph to use for lookup, reverse lookup, and size queries.

    The graph, its placeholders and its result tensors are built on the first
    call and cached on `self`. The description is serialized into the ops at
    that point, so later mutations of `self.dictionary_description` are not
    reflected in query results.
    """
    if self._lookup_graph is None:
      # Build into a local graph and publish it only once fully constructed,
      # so a failure part-way through doesn't leave a half-built cache.
      graph = tf.Graph()
      serialized_description_proto = (
          self._dictionary_description.SerializeToString())
      with graph.as_default():
        # shape=None: accept inputs of any shape.
        self._lookup_placeholder = tf.compat.v1.placeholder(
            tf.string, shape=None)
        self._reverse_lookup_placeholder = tf.compat.v1.placeholder(
            tf.int64, shape=None)

        self._lookup_result = dictionary_lookup(
            self._lookup_placeholder,
            dictionary_description_proto=serialized_description_proto)
        self._reverse_lookup_result = dictionary_reverse_lookup(
            self._reverse_lookup_placeholder,
            dictionary_description_proto=serialized_description_proto)
        self._size_result = dictionary_size(
            dictionary_description_proto=serialized_description_proto)
      self._lookup_graph = graph

    return self._lookup_graph

  def lookup(self, tokens):
    """Maps a list of tokens to a list of ids.

    Args:
      tokens: A list of tokens to lookup.

    Returns:
      A list of token ids of the same size.

    Raises:
      ValueError: If tokens is not a list.
    """
    if not isinstance(tokens, list):
      raise ValueError('lookup expected a list of tokens.')

    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return sess.run(self._lookup_result, {
          self._lookup_placeholder: tokens
      }).tolist()

  def reverse_lookup(self, ids):
    """Maps a list of ids to tokens.

    Args:
      ids: A list of ids to map back to tokens.

    Returns:
      A list of tokens corresponding to those ids, as returned by
      `Session.run`.

    Raises:
      ValueError: If ids is not a list.
    """
    if not isinstance(ids, list):
      raise ValueError('reverse_lookup expected a list of ids.')
    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return list(
          sess.run(self._reverse_lookup_result,
                   {self._reverse_lookup_placeholder: ids}))

  @property
  def special_ids(self):
    """Returns the set special token ids in [unk, bos, eos] order."""
    return [t for t in [self.unk_id, self.bos_id, self.eos_id] if t is not None]

  @property
  def eos_id(self):
    """End-of-sequence id, or None if the stored value is negative."""
    eos_id = self._dictionary_description.special_ids.eos
    return eos_id if eos_id >= 0 else None

  @property
  def bos_id(self):
    """Start-of-sequence id, or None if the stored value is negative."""
    bos_id = self._dictionary_description.special_ids.bos
    return bos_id if bos_id >= 0 else None

  @property
  def unk_id(self):
    """Unknown-token id, or None if the stored value is negative."""
    unk_id = self._dictionary_description.special_ids.unk
    return unk_id if unk_id >= 0 else None

  @property
  def size(self):
    """Size reported by the dictionary size op.

    The `tokens` property treats ids [num_special_ids, size) as the
    vocabulary, so this presumably counts special ids too (confirm against
    fcp/dictionary).
    """
    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return sess.run(self._size_result)

  @property
  def output_blocklist_ids(self):
    """Sorted ids that should be filtered from predictions."""
    return list(self._dictionary_description.output_blocklist_ids.id)

  @property
  def output_blocklist_tokens(self):
    """Tokens for `output_blocklist_ids`, in the same order."""
    return self.reverse_lookup(self.output_blocklist_ids)

  @property
  def tokens(self):
    """Non-special tokens in id order."""
    return self.reverse_lookup(list(range(len(self.special_ids), self.size)))

  @property
  def dictionary_description_proto(self):
    """Serialized proto containing self.dictionary_description."""
    return self.dictionary_description.SerializeToString()

  @property
  def dictionary_description(self):
    """Returns the `DictionaryDescription` proto describing this dictionary.

    This is the stored object itself, not a copy.
    """
    return self._dictionary_description

  def __len__(self):
    return self.size
373