# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Python and TensorFlow functions to work with dictionaries.
15
16Please see fcp/dictionary/dictionary.h for more on this type of
17dictionary.
18
19Python Classes:
20
21* `Dictionary`: A Python analogue to fcp/dictionary/dictionary.h
22    that includes additional helpers for dictionary construction.
23
24TensorFlow ops:
25
26*  dictionary_size
27   Queries the size of a dictionary.
28
29*  dictionary_lookup
30   Looks up ids for string tokens in the dictionary.
31
32*  dictionary_reverse_lookup
33   Looks up string tokens from ids in the dictionary.
34
35Canonical use (note that the dictionary is known at graph construction time):
36    dictionary = Dictionary.from_tokens(
37       tokens=['some', 'token', 'list'], unk_id=0,
38       vocabulary_type=VocabularyType.TOKEN_INDEX)
39
40    with tf.Graph().as_default():
      tokens = tf.compat.v1.placeholder(tf.string, ...)  # Tokens to look up.
      ids = dictionary_lookup(
          tokens, dictionary.dictionary_description_proto)
"""

import collections
import enum

import tensorflow as tf

from fcp.dictionary.dictionary_pb2 import DictionaryDescription  # pylint: disable=g-importing-member
from fcp.tensorflow.gen_dictionary_ops import dictionary_lookup
from fcp.tensorflow.gen_dictionary_ops import dictionary_reverse_lookup
from fcp.tensorflow.gen_dictionary_ops import dictionary_size

# Load the shared library that registers the dictionary ops with TensorFlow.
_dictionary_ops = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile('./_dictionary_ops.so'))


def ignore_ids_mask(token_ids, ignore_ids, name=None):
  """Creates a bool mask with True everywhere token_ids is not in ignore_ids."""
  with tf.compat.v1.name_scope(name, 'ignore_ids_mask',
                               [token_ids, ignore_ids]):
    # Broadcast token_ids against ignore_ids, then reduce over the trailing
    # axis: a position is True only if its id matches none of ignore_ids.
    all_check = tf.not_equal(tf.expand_dims(token_ids, -1), ignore_ids)
    check = tf.reduce_all(all_check, axis=-1)
    check.set_shape(token_ids.get_shape())
    return check

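# A minimal usage sketch (illustrative only; the ids below are made up):
#
#   ids = tf.constant([[3, 0, 1], [2, 2, 0]], dtype=tf.int64)
#   mask = ignore_ids_mask(ids, tf.constant([0, 1], dtype=tf.int64))
#   # mask -> [[True, False, False], [True, True, False]]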

def mask_and_replace_padding(token_ids,
                             lengths,
                             eos_id=None,
                             special_tokens=(),
                             name=None):
  """Creates a mask of valid tokens and sets padded values in id space.

  This creates a mask the same shape as token_ids with a boolean indicating
  whether each id is a valid token (i.e., not padding or a special token). If
  eos_id is provided, this also remaps tokens at and after each row's length
  to eos_id. Since the dictionary doesn't map tokens to bos or eos ids,
  padding positions would otherwise hold the unknown token id, which is not
  correct if you need to predict the eos.

  Args:
    token_ids: A matrix `Tensor` of integer ids.
    lengths: A vector `Tensor` of lengths for each row in token_ids.
    eos_id: The end of sequence id. If provided, all token ids after a row's
      length are replaced with `eos_id`.
    special_tokens: An iterable of special token ids that are not considered
      valid.
    name: Name scope for these ops.

  Returns:
    token_ids: `token_ids` with all tokens after a row's length replaced with
      eos, if provided.
    mask: A bool `Tensor` the same shape as `token_ids` indicating which tokens
      are valid.
  """
  with tf.compat.v1.name_scope(name, 'mask_and_replace_padding',
                               [token_ids, lengths]):
    ranges = tf.range(0, tf.shape(token_ids)[1])

    # Broadcast the column indices against per-row lengths: True for positions
    # before each row's length.
    selected = tf.less(ranges, tf.expand_dims(lengths, -1))

    if eos_id is not None:
      token_ids = tf.where(
          selected, token_ids,
          tf.fill(
              tf.shape(token_ids), tf.constant(eos_id, dtype=token_ids.dtype)))
    if special_tokens:
      mask = tf.logical_and(
          ignore_ids_mask(token_ids, special_tokens), selected)
    else:
      mask = selected
    return token_ids, mask

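# A minimal usage sketch (illustrative; ids and lengths are made up). Row 0
# has length 2, so its final position is replaced with eos_id and masked out:
#
#   token_ids = tf.constant([[4, 5, 6], [7, 8, 9]], dtype=tf.int64)
#   lengths = tf.constant([2, 3])
#   ids, mask = mask_and_replace_padding(token_ids, lengths, eos_id=1)
#   # ids  -> [[4, 5, 1], [7, 8, 9]]
#   # mask -> [[True, True, False], [True, True, True]]
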
tf.no_gradient('DictionarySize')
tf.no_gradient('DictionaryLookup')
tf.no_gradient('DictionaryReverseLookup')


class VocabularyType(enum.Enum):
  """Valid vocabulary types for Dictionary construction.

  TOKEN_INDEX: dictionary.dictionary_description contains an embedded map of
    string names stored in order with ids assigned starting from the lowest
    non-special id. Preserves order but is not compact.
  """
  TOKEN_INDEX = 3


class Dictionary(object):
  """Utility for working with fcp/dictionary/ via TensorFlow."""

  def __init__(
      self,
      dictionary_description
  ):
    """Creates a dictionary from a dictionary_description.

    Use static from_* constructor methods for building dictionaries from
    common data types.

    Args:
      dictionary_description: A `dictionary_pb2.DictionaryDescription`
        describing the dictionary.

    Raises:
      ValueError: An invalid dictionary description.
    """
    if not isinstance(dictionary_description, DictionaryDescription):
      raise ValueError('Expected a DictionaryDescription')
    if not dictionary_description.HasField('vocabulary'):
      raise ValueError('dictionary_description has no vocabulary')

    self._dictionary_description = dictionary_description

    # Lazily constructed fields for lookup.
    self._lookup_graph = None
    self._lookup_placeholder = None
    self._lookup_result = None
    self._reverse_lookup_placeholder = None
    self._reverse_lookup_result = None
    self._size_result = None

  @classmethod
  def from_tokens(
      cls,
      tokens,
      bos_id=None,
      eos_id=None,
      unk_id=None,
      output_blocklist_tokens=None,
      output_size=None,
      vocabulary_type=VocabularyType.TOKEN_INDEX
  ):
    """Creates a dictionary from a provided list of tokens.

    The mapping from tokens to ids depends on the vocabulary_type requested.

    NB: the special tokens must occupy the lowest ids, [0, num_specials).

    Args:
      tokens: An ordered iterable of tokens for the dictionary.
      bos_id: Token id for start of sequence.
      eos_id: Token id for end of sequence.
      unk_id: Token id for unknown words.
      output_blocklist_tokens: A list of vocabulary tokens that should be
        filtered from predictions (e.g., punctuation, bad words, etc.).
      output_size: If a positive integer, tokens with ids at or above this
        value are automatically added to the output blocklist.
      vocabulary_type: `VocabularyType` to use; defaults to TOKEN_INDEX.

    Returns:
      A `Dictionary` instance.

    Raises:
      ValueError: If the special tokens don't have the lowest ids.
      ValueError: If there are duplicates in tokens.
    """
    dictionary_description = DictionaryDescription()

    # Special ids.
    special_ids = []
    if unk_id is not None:
      dictionary_description.special_ids.unk = unk_id
      special_ids.append(unk_id)
    if bos_id is not None:
      dictionary_description.special_ids.bos = bos_id
      special_ids.append(bos_id)
    if eos_id is not None:
      dictionary_description.special_ids.eos = eos_id
      special_ids.append(eos_id)
    if sorted(special_ids) != list(range(len(special_ids))):
      raise ValueError(
          'Special ids must be the first items of the dictionary starting at '
          '0 or None. eos: %s; bos: %s; unk: %s' % (eos_id, bos_id, unk_id))

    # Vocabulary.
    if len(tokens) != len(set(tokens)):
      raise ValueError('Duplicate tokens provided')
    for token in tokens:
      if not isinstance(token, (str, bytes)):
        raise ValueError('Bad type in tokens: %s' % token)
    if vocabulary_type == VocabularyType.TOKEN_INDEX:
      for token in tokens:
        dictionary_description.vocabulary.index.token.append(token)
    else:
      raise AssertionError('Unsupported vocabulary_type: %s' % vocabulary_type)

    # Output blocklist.
    output_blocklist_tokens = list(output_blocklist_tokens or [])
    if output_size:
      assert output_size >= len(special_ids), (
          'Cannot blocklist special tokens via output_size.')
      assert isinstance(tokens, list)  # Ensure order is preserved pre-slice.
      output_blocklist_tokens.extend(tokens[output_size - len(special_ids):])
    for token in output_blocklist_tokens:
      assert token in tokens, "Unexpected blocklist token: '%s'" % token
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      output_blocklist_ids = sess.run(
          dictionary_lookup(output_blocklist_tokens,
                            dictionary_description.SerializeToString()))
    dictionary_description.output_blocklist_ids.id.extend(
        sorted(output_blocklist_ids))
    assert (len(set(dictionary_description.output_blocklist_ids.id)) == len(
        output_blocklist_tokens)), 'Output blocklist contains dups or unks.'

    # Return completed dictionary.
    return cls(
        dictionary_description=dictionary_description)

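  # A minimal usage sketch (illustrative; the token values are made up). With
  # unk_id=0, the tokens receive ids 1, 2, 3 in order, and out-of-vocabulary
  # tokens look up as unk_id:
  #
  #   d = Dictionary.from_tokens(['a', 'b', 'c'], unk_id=0)
  #   d.lookup(['a', 'c', 'oov'])  # -> [1, 3, 0]
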
  @classmethod
  def from_dictionary_description(cls,
                                  dictionary_description):
    """Returns a Dictionary from a DictionaryDescription."""
    return cls(
        dictionary_description=dictionary_description)

  def _get_lookup_graph(self):
    """Returns a graph to use for lookup, reverse lookup, and size queries."""
    if self._lookup_graph is None:
      self._lookup_graph = tf.Graph()
      serialized_description_proto = (
          self._dictionary_description.SerializeToString())
      with self._lookup_graph.as_default():
        self._lookup_placeholder = tf.compat.v1.placeholder(
            tf.string, shape=None)
        self._reverse_lookup_placeholder = tf.compat.v1.placeholder(
            tf.int64, shape=None)

        # Use Dictionary(Op) (without blob) variants.
        self._lookup_result = dictionary_lookup(
            self._lookup_placeholder,
            dictionary_description_proto=serialized_description_proto)
        self._reverse_lookup_result = dictionary_reverse_lookup(
            self._reverse_lookup_placeholder,
            dictionary_description_proto=serialized_description_proto)
        self._size_result = dictionary_size(
            dictionary_description_proto=serialized_description_proto)

    return self._lookup_graph

  def lookup(self, tokens):
    """Maps a list of tokens to a list of ids.

    Args:
      tokens: A list of tokens to look up.

    Returns:
      A list of token ids of the same size.

    Raises:
      ValueError: If tokens is not a list.
    """
    if not isinstance(tokens, list):
      raise ValueError('lookup expected a list of tokens.')

    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return sess.run(self._lookup_result, {
          self._lookup_placeholder: tokens
      }).tolist()

  def reverse_lookup(self, ids):
    """Maps a list of ids to tokens.

    Args:
      ids: A list of ids to map back to tokens.

    Returns:
      A list of tokens corresponding to those ids.

    Raises:
      ValueError: If ids is not a list.
    """
    if not isinstance(ids, list):
      raise ValueError('reverse_lookup expected a list of ids.')
    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return list(
          sess.run(self._reverse_lookup_result,
                   {self._reverse_lookup_placeholder: ids}))

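  # Illustrative round trip, continuing the from_tokens sketch above (note
  # that string tensors come back from sess.run as bytes):
  #
  #   d.reverse_lookup([1, 3])  # -> [b'a', b'c']
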
  @property
  def special_ids(self):
    """Returns a list of special token ids."""
    return [t for t in [self.unk_id, self.bos_id, self.eos_id] if t is not None]

  @property
  def eos_id(self):
    eos_id = self._dictionary_description.special_ids.eos
    return eos_id if eos_id >= 0 else None

  @property
  def bos_id(self):
    bos_id = self._dictionary_description.special_ids.bos
    return bos_id if bos_id >= 0 else None

  @property
  def unk_id(self):
    unk_id = self._dictionary_description.special_ids.unk
    return unk_id if unk_id >= 0 else None

  @property
  def size(self):
    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return sess.run(self._size_result)

  @property
  def output_blocklist_ids(self):
    return list(self._dictionary_description.output_blocklist_ids.id)

  @property
  def output_blocklist_tokens(self):
    return self.reverse_lookup(self.output_blocklist_ids)

  @property
  def tokens(self):
    return self.reverse_lookup(list(range(len(self.special_ids), self.size)))

  @property
  def dictionary_description_proto(self):
    """Serialized proto containing self.dictionary_description."""
    return self.dictionary_description.SerializeToString()

  @property
  def dictionary_description(self):
    """Returns the `DictionaryDescription` proto describing this dictionary."""
    return self._dictionary_description

  def __len__(self):
    return self.size
373