1# Copyright 2022 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Python and TensorFlow functions to work with dictionaries. 15 16Please see fcp/dictionary/dictionary.h for more on this type of 17dictionary. 18 19Python Classes: 20 21* `Dictionary`: A Python analogue to fcp/dictionary/dictionary.h 22 that includes additional helpers for dictionary construction. 23 24TensorFlow ops: 25 26* dictionary_size 27 Queries the size of a dictionary. 28 29* dictionary_lookup 30 Looks up ids for string tokens in the dictionary. 31 32* dictionary_reverse_lookup 33 Looks up string tokens from ids in the dictionary. 34 35Canonical use (note that the dictionary is known at graph construction time): 36 dictionary = Dictionary.from_tokens( 37 tokens=['some', 'token', 'list'], unk_id=0, 38 vocabulary_type=VocabularyType.TOKEN_INDEX) 39 40 with tf.Graph().as_default(): 41 tokens = tf.compat.v1.placeholder(tf.String, ...) # Tokens to look up. 42 ids = dictionary_lookup( 43 tokens, dictionary.dictionary_description_proto) 44""" 45 46import collections 47import enum 48 49import tensorflow as tf 50 51from fcp.dictionary.dictionary_pb2 import DictionaryDescription # pylint: disable=g-importing-member 52from fcp.tensorflow.gen_dictionary_ops import dictionary_lookup 53from fcp.tensorflow.gen_dictionary_ops import dictionary_reverse_lookup 54from fcp.tensorflow.gen_dictionary_ops import dictionary_size 55 56_dictionary_ops = tf.load_op_library( 57 tf.compat.v1.resource_loader.get_path_to_datafile('./_dictionary_ops.so')) 58 59 60def ignore_ids_mask(token_ids, ignore_ids, name=None): 61 """Creates a bool mask with True everywhere token_ids is not in ignore_ids.""" 62 with tf.op_scope([token_ids, ignore_ids], name, 'ignore_ids_mask'): 63 # Yay broadcasting 64 all_check = tf.not_equal(tf.expand_dims(token_ids, -1), ignore_ids) 65 check = tf.reduce_all(all_check, reduction_indices=tf.rank(all_check) - 1) 66 check.set_shape(token_ids.get_shape()) 67 return check 68 69 70def mask_and_replace_padding(token_ids, 71 lengths, 72 eos_id=None, 73 special_tokens=(), 74 name=None): 75 """Creates a mask of valid tokens and sets padded values in id space. 76 77 This creates a mask the same shape as token_ids with a boolean indicating 78 if the id was a valid token (i.e not padding or a special token). If 79 provided, this also remaps tokens after lengths to the eos_id. Since the 80 dictionary doesn't map tokens to eos or bos ids, it would generally be the 81 unknown token id which is not correct if you need to predict the eos. 82 83 Args: 84 token_ids: A matrix `Tensor` of integer ids. 85 lengths: A vector `Tensor` of lengths for each row in token_ids. 86 eos_id: The end of sequence id, if provided then all token ids after length 87 in a row will be replaced with `eos_id`. 88 special_tokens: An iterable of special tokens for ids that are not 89 considered valid. 90 name: Name scope for these ops. 91 92 Returns: 93 token_ids: `token_ids` with all tokens after a row's length replaced with 94 eos if provided. 95 mask: A bool `Tensor` the same shape as `token_ids` indicating which tokens 96 are valid. 97 """ 98 with tf.op_scope([token_ids, lengths, eos_id, special_tokens], name, 99 'mask_and_replace_padding'): 100 ranges = tf.range(0, tf.gather(tf.shape(token_ids), 1)) 101 102 # Yay! Broadcasting. 103 selected = tf.less(ranges, tf.expand_dims(lengths, -1)) 104 105 if eos_id is not None: 106 token_ids = tf.where( 107 selected, token_ids, 108 tf.fill( 109 tf.shape(token_ids), tf.constant(eos_id, dtype=token_ids.dtype))) 110 if special_tokens: 111 mask = tf.logical_and( 112 ignore_ids_mask(token_ids, special_tokens), selected) 113 else: 114 mask = selected 115 return token_ids, mask 116 117tf.no_gradient('DictionarySize') 118tf.no_gradient('DictionaryLookup') 119tf.no_gradient('DictionaryReverseLookup') 120 121 122class VocabularyType(enum.Enum): 123 """Valid vocabulary types for Dictionary construction. 124 125 TOKEN_INDEX: dictionary.dictionary_description contains an embedded map of 126 string names stored in order with ids assigned starting from the lowest 127 non-special id. Preserves order but is not compact. 128 """ 129 TOKEN_INDEX = 3 130 131 132class Dictionary(object): 133 """Utility for working with fcp/dictionary/ via TensorFlow.""" 134 135 def __init__( 136 self, 137 dictionary_description 138 ): 139 """Creates a dictionary from a dictionary_description. 140 141 Use static from_* constructor methods for building dictionaries from 142 common data types. 143 144 Args: 145 dictionary_description: A `dictionary_pb2.DictionaryDescription` 146 describing the dictionary. 147 148 Raises: 149 ValueError: An invalid dictionary description. 150 """ 151 if not isinstance(dictionary_description, DictionaryDescription): 152 raise ValueError('Expected a DictionaryDescription') 153 if not dictionary_description.HasField('vocabulary'): 154 raise ValueError('dictionary_description has no vocabulary') 155 156 self._dictionary_description = dictionary_description 157 158 # Lazily constructed fields for lookup. 159 self._lookup_graph = None 160 self._lookup_placeholder = None 161 self._lookup_result = None 162 self._reverse_lookup_placeholder = None 163 self._reverse_lookup_result = None 164 165 @classmethod 166 def from_tokens( 167 cls, 168 tokens, 169 bos_id=None, 170 eos_id=None, 171 unk_id=None, 172 output_blocklist_tokens=None, 173 output_size=None, 174 vocabulary_type=VocabularyType.TOKEN_INDEX 175 ): 176 """Creates a dictionary from a provided list of tokens. 177 178 The id mappings to token ids depend on the vocabulary_type requested. 179 180 NB: the special tokens must be the first ids [0, num-specials) 181 182 Args: 183 tokens: An unordered iterable of tokens for the dictionary. 184 bos_id: Token id for start of sequence. 185 eos_id: Token id for end of sequence. 186 unk_id: Token id for unknown words. 187 output_blocklist_tokens: A list of vocabulary tokens that should be 188 filtered from predictions (e.g., punctuation, bad words etc.). 189 output_size: If a positive integer, tokens with ids greater than this are 190 automatically added to the output blocklist. 191 vocabulary_type: `VocabularyType` to use, defaults to TOKEN_INDEX. 192 193 Returns: 194 A `Dictionary` instance. 195 196 Raises: 197 ValueError: If the special tokens don't have the lowest ids. 198 ValueError: If there are duplicates in tokens. 199 """ 200 dictionary_description = DictionaryDescription() 201 202 # Special ids. 203 special_ids = [] 204 if unk_id is not None: 205 dictionary_description.special_ids.unk = unk_id 206 special_ids.append(unk_id) 207 if bos_id is not None: 208 dictionary_description.special_ids.bos = bos_id 209 special_ids.append(bos_id) 210 if eos_id is not None: 211 dictionary_description.special_ids.eos = eos_id 212 special_ids.append(eos_id) 213 if sorted(special_ids) != list(range(len(special_ids))): 214 raise ValueError( 215 'Special ids must be the first items of the dictionary starting at 0' 216 'or None. eos: %s; bos %s; unk: %s' % (eos_id, bos_id, unk_id)) 217 218 # Vocabulary. 219 if len(tokens) != len(set(tokens)): 220 raise ValueError('Duplicate tokens provided') 221 for token in tokens: 222 if not isinstance(token, (str, bytes)): 223 raise ValueError('Bad type in tokens %s' % token) 224 if vocabulary_type == VocabularyType.TOKEN_INDEX: 225 for token in tokens: 226 dictionary_description.vocabulary.index.token.append(token) 227 else: 228 raise AssertionError('Unsupported vocabulary_type: %s' % vocabulary_type) 229 230 # Output blocklist. 231 output_blocklist_tokens = list(output_blocklist_tokens or []) 232 if output_size: 233 assert output_size >= len(special_ids), ( 234 'Cannot blocklist special tokens via output_size.') 235 assert isinstance(tokens, list) # Make sure order preserving pre-slice. 236 output_blocklist_tokens.extend(tokens[output_size - len(special_ids):]) 237 for token in output_blocklist_tokens: 238 assert token in tokens, "Unexpected blocklist token: '%s'" % token 239 with tf.compat.v1.Session(graph=tf.Graph()) as sess: 240 output_blocklist_ids = sess.run( 241 dictionary_lookup(output_blocklist_tokens, 242 dictionary_description.SerializeToString())) 243 dictionary_description.output_blocklist_ids.id.extend( 244 sorted(output_blocklist_ids)) 245 assert (len(set(dictionary_description.output_blocklist_ids.id)) == len( 246 output_blocklist_tokens)), 'blocklist contains dups or unks?' 247 248 # Return completed dictionary. 249 return cls( 250 dictionary_description=dictionary_description) 251 252 @classmethod 253 def from_dictionary_description(cls, 254 dictionary_description): 255 """Returns a Dictionary from a DictionaryDescription.""" 256 return cls( 257 dictionary_description=dictionary_description) 258 259 def _get_lookup_graph(self): 260 """Returns a graph to use for lookup, reverse lookup, and size queries.""" 261 if self._lookup_graph is None: 262 self._lookup_graph = tf.Graph() 263 serialized_description_proto = ( 264 self._dictionary_description.SerializeToString()) 265 with self._lookup_graph.as_default(): 266 self._lookup_placeholder = tf.compat.v1.placeholder( 267 tf.string, shape=None) 268 self._reverse_lookup_placeholder = tf.compat.v1.placeholder( 269 tf.int64, shape=None) 270 271 # Use Dictionary(Op) (without blob) variants. 272 self._lookup_result = dictionary_lookup( 273 self._lookup_placeholder, 274 dictionary_description_proto=serialized_description_proto) 275 self._reverse_lookup_result = dictionary_reverse_lookup( 276 self._reverse_lookup_placeholder, 277 dictionary_description_proto=serialized_description_proto) 278 self._size_result = dictionary_size( 279 dictionary_description_proto=serialized_description_proto) 280 281 return self._lookup_graph 282 283 def lookup(self, tokens): 284 """Maps a list of tokens to a list of ids. 285 286 Args: 287 tokens: A list of tokens to lookup. 288 289 Returns: 290 A list of token ids of the same size. 291 292 Raises: 293 ValueError: If tokens is not a list. 294 """ 295 if not isinstance(tokens, list): 296 raise ValueError('lookup expected a list of tokens.') 297 298 with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess: 299 return sess.run(self._lookup_result, { 300 self._lookup_placeholder: tokens 301 }).tolist() 302 303 def reverse_lookup(self, ids): 304 """Maps a list of ids to tokens. 305 306 Args: 307 ids: A list of ids to map back to tokens. 308 309 Returns: 310 A list of tokens corresponding to those ids. 311 312 Raises: 313 ValueError: If ids is not a list. 314 """ 315 if not isinstance(ids, list): 316 raise ValueError('reverse_lookup expected a list of ids.') 317 with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess: 318 return list( 319 sess.run(self._reverse_lookup_result, 320 {self._reverse_lookup_placeholder: ids})) 321 322 @property 323 def special_ids(self): 324 """Returns a list of special token ids.""" 325 return [t for t in [self.unk_id, self.bos_id, self.eos_id] if t is not None] 326 327 @property 328 def eos_id(self): 329 eos_id = self._dictionary_description.special_ids.eos 330 return eos_id if eos_id >= 0 else None 331 332 @property 333 def bos_id(self): 334 bos_id = self._dictionary_description.special_ids.bos 335 return bos_id if bos_id >= 0 else None 336 337 @property 338 def unk_id(self): 339 unk_id = self._dictionary_description.special_ids.unk 340 return unk_id if unk_id >= 0 else None 341 342 @property 343 def size(self): 344 with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess: 345 return sess.run(self._size_result) 346 347 @property 348 def output_blocklist_ids(self): 349 return list(self._dictionary_description.output_blocklist_ids.id) 350 351 @property 352 def output_blocklist_tokens(self): 353 return self.reverse_lookup(self.output_blocklist_ids) 354 355 @property 356 def tokens(self): 357 return self.reverse_lookup(list(range(len(self.special_ids), self.size))) 358 359 @property 360 def dictionary_description_proto(self): 361 """Serialized proto containing self.dictionary_description.""" 362 return self.dictionary_description.SerializeToString() 363 364 @property 365 def dictionary_description(self): 366 """Returns the `DictionaryDescription` proto describing this dictionary. 367 """ 368 desc = self._dictionary_description 369 return desc 370 371 def __len__(self): 372 return self.size 373