# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python and TensorFlow functions to work with dictionaries.

Please see fcp/dictionary/dictionary.h for more on this type of
dictionary.

Python Classes:

* `Dictionary`: A Python analogue to fcp/dictionary/dictionary.h
    that includes additional helpers for dictionary construction.

TensorFlow ops:

* dictionary_size
    Queries the size of a dictionary.

* dictionary_lookup
    Looks up ids for string tokens in the dictionary.

* dictionary_reverse_lookup
    Looks up string tokens from ids in the dictionary.

Canonical use (note that the dictionary is known at graph construction time):
    dictionary = Dictionary.from_tokens(
        tokens=['some', 'token', 'list'], unk_id=0,
        vocabulary_type=VocabularyType.TOKEN_INDEX)

    with tf.Graph().as_default():
      tokens = tf.compat.v1.placeholder(tf.string, ...)  # Tokens to look up.
      ids = dictionary_lookup(
          tokens, dictionary.dictionary_description_proto)
"""

import collections
import enum

import tensorflow as tf

from fcp.dictionary.dictionary_pb2 import DictionaryDescription  # pylint: disable=g-importing-member
from fcp.tensorflow.gen_dictionary_ops import dictionary_lookup
from fcp.tensorflow.gen_dictionary_ops import dictionary_reverse_lookup
from fcp.tensorflow.gen_dictionary_ops import dictionary_size

_dictionary_ops = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile('./_dictionary_ops.so'))


def ignore_ids_mask(token_ids, ignore_ids, name=None):
  """Creates a bool mask with True everywhere token_ids is not in ignore_ids."""
  with tf.compat.v1.name_scope(name, 'ignore_ids_mask',
                               [token_ids, ignore_ids]):
    # Broadcast each token id against every ignored id, then require that the
    # token differs from all of them.
    all_check = tf.not_equal(tf.expand_dims(token_ids, -1), ignore_ids)
    check = tf.reduce_all(all_check, axis=-1)
    check.set_shape(token_ids.get_shape())
    return check


def mask_and_replace_padding(token_ids,
                             lengths,
                             eos_id=None,
                             special_tokens=(),
                             name=None):
  """Creates a mask of valid tokens and sets padded values in id space.

  This creates a bool mask with the same shape as token_ids that indicates
  whether each id is a valid token (i.e., neither padding nor a special token).
  If eos_id is provided, this also replaces the ids at padded positions (those
  at or beyond a row's length) with eos_id. Since the dictionary doesn't map
  tokens to eos or bos ids, padding would otherwise generally carry the unknown
  token id, which is not correct if you need to predict the eos.

  Args:
    token_ids: A matrix `Tensor` of integer ids.
    lengths: A vector `Tensor` of lengths for each row in token_ids.
    eos_id: The end of sequence id. If provided, all token ids at or beyond a
      row's length are replaced with `eos_id`.
    special_tokens: An iterable of special token ids that are not considered
      valid.
    name: Name scope for these ops.

  Returns:
    token_ids: `token_ids` with all ids at or beyond a row's length replaced
      with `eos_id`, if provided.
    mask: A bool `Tensor` with the same shape as `token_ids` indicating which
      tokens are valid.
  """
  with tf.compat.v1.name_scope(
      name, 'mask_and_replace_padding',
      [token_ids, lengths, eos_id, special_tokens]):
    ranges = tf.range(0, tf.gather(tf.shape(token_ids), 1))

    # Broadcast the column indices against each row's length: a position is
    # selected when it lies before the end of its row.
    selected = tf.less(ranges, tf.expand_dims(lengths, -1))

    if eos_id is not None:
      token_ids = tf.where(
          selected, token_ids,
          tf.fill(
              tf.shape(token_ids), tf.constant(eos_id, dtype=token_ids.dtype)))
    if special_tokens:
      mask = tf.logical_and(
          ignore_ids_mask(token_ids, special_tokens), selected)
    else:
      mask = selected
    return token_ids, mask


tf.no_gradient('DictionarySize')
tf.no_gradient('DictionaryLookup')
tf.no_gradient('DictionaryReverseLookup')


class VocabularyType(enum.Enum):
  """Valid vocabulary types for Dictionary construction.

  TOKEN_INDEX: dictionary.dictionary_description contains an embedded map of
    string names stored in order with ids assigned starting from the lowest
    non-special id. Preserves order but is not compact.
  """
  TOKEN_INDEX = 3


class Dictionary(object):
  """Utility for working with fcp/dictionary/ via TensorFlow."""

  def __init__(
      self,
      dictionary_description
  ):
    """Creates a dictionary from a dictionary_description.

    Use the from_* class methods for building dictionaries from common data
    types.

    Args:
      dictionary_description: A `dictionary_pb2.DictionaryDescription`
        describing the dictionary.

    Raises:
      ValueError: If dictionary_description is not a `DictionaryDescription`
        or has no vocabulary.
    """
    if not isinstance(dictionary_description, DictionaryDescription):
      raise ValueError('Expected a DictionaryDescription')
    if not dictionary_description.HasField('vocabulary'):
      raise ValueError('dictionary_description has no vocabulary')

    self._dictionary_description = dictionary_description

    # Lazily constructed fields for lookup.
    self._lookup_graph = None
    self._lookup_placeholder = None
    self._lookup_result = None
    self._reverse_lookup_placeholder = None
    self._reverse_lookup_result = None
    self._size_result = None

  @classmethod
  def from_tokens(
      cls,
      tokens,
      bos_id=None,
      eos_id=None,
      unk_id=None,
      output_blocklist_tokens=None,
      output_size=None,
      vocabulary_type=VocabularyType.TOKEN_INDEX
  ):
    """Creates a dictionary from a provided list of tokens.

    The mapping from tokens to ids depends on the vocabulary_type requested.

    NB: the special tokens must occupy the first ids, [0, num-specials).

    Args:
      tokens: A collection of unique str or bytes tokens for the dictionary.
        Must be a list when output_size is set.
      bos_id: Token id for start of sequence.
      eos_id: Token id for end of sequence.
      unk_id: Token id for unknown words.
      output_blocklist_tokens: A list of vocabulary tokens that should be
        filtered from predictions (e.g., punctuation, bad words, etc.).
      output_size: If a positive integer, tokens with ids greater than or equal
        to this are automatically added to the output blocklist.
      vocabulary_type: `VocabularyType` to use, defaults to TOKEN_INDEX.

    Returns:
      A `Dictionary` instance.

    Raises:
      ValueError: If the special tokens don't have the lowest ids.
      ValueError: If there are duplicates in tokens.
      ValueError: If a token is neither str nor bytes.
    """
    dictionary_description = DictionaryDescription()

    # Special ids.
    special_ids = []
    if unk_id is not None:
      dictionary_description.special_ids.unk = unk_id
      special_ids.append(unk_id)
    if bos_id is not None:
      dictionary_description.special_ids.bos = bos_id
      special_ids.append(bos_id)
    if eos_id is not None:
      dictionary_description.special_ids.eos = eos_id
      special_ids.append(eos_id)
    if sorted(special_ids) != list(range(len(special_ids))):
      raise ValueError(
          'Special ids must be the first items of the dictionary starting at 0'
          ' or None. eos: %s; bos: %s; unk: %s' % (eos_id, bos_id, unk_id))

    # Vocabulary.
    if len(tokens) != len(set(tokens)):
      raise ValueError('Duplicate tokens provided')
    for token in tokens:
      if not isinstance(token, (str, bytes)):
        raise ValueError('Bad type in tokens %s' % token)
    if vocabulary_type == VocabularyType.TOKEN_INDEX:
      for token in tokens:
        dictionary_description.vocabulary.index.token.append(token)
    else:
      raise AssertionError('Unsupported vocabulary_type: %s' % vocabulary_type)

    # Output blocklist.
    output_blocklist_tokens = list(output_blocklist_tokens or [])
    if output_size:
      assert output_size >= len(special_ids), (
          'Cannot blocklist special tokens via output_size.')
      assert isinstance(tokens, list)  # Make sure order preserving pre-slice.
      output_blocklist_tokens.extend(tokens[output_size - len(special_ids):])
    for token in output_blocklist_tokens:
      assert token in tokens, "Unexpected blocklist token: '%s'" % token
    with tf.compat.v1.Session(graph=tf.Graph()) as sess:
      output_blocklist_ids = sess.run(
          dictionary_lookup(output_blocklist_tokens,
                            dictionary_description.SerializeToString()))
    dictionary_description.output_blocklist_ids.id.extend(
        sorted(output_blocklist_ids))
    assert (len(set(dictionary_description.output_blocklist_ids.id)) == len(
        output_blocklist_tokens)), 'blocklist contains dups or unks?'

    # Return completed dictionary.
    return cls(
        dictionary_description=dictionary_description)

  @classmethod
  def from_dictionary_description(cls,
                                  dictionary_description):
    """Returns a Dictionary from a DictionaryDescription."""
    return cls(
        dictionary_description=dictionary_description)

  def _get_lookup_graph(self):
    """Returns a graph to use for lookup, reverse lookup, and size queries."""
    if self._lookup_graph is None:
      self._lookup_graph = tf.Graph()
      serialized_description_proto = (
          self._dictionary_description.SerializeToString())
      with self._lookup_graph.as_default():
        self._lookup_placeholder = tf.compat.v1.placeholder(
            tf.string, shape=None)
        self._reverse_lookup_placeholder = tf.compat.v1.placeholder(
            tf.int64, shape=None)

        # Use Dictionary(Op) (without blob) variants.
        self._lookup_result = dictionary_lookup(
            self._lookup_placeholder,
            dictionary_description_proto=serialized_description_proto)
        self._reverse_lookup_result = dictionary_reverse_lookup(
            self._reverse_lookup_placeholder,
            dictionary_description_proto=serialized_description_proto)
        self._size_result = dictionary_size(
            dictionary_description_proto=serialized_description_proto)

    return self._lookup_graph

  def lookup(self, tokens):
    """Maps a list of tokens to a list of ids.

    Args:
      tokens: A list of tokens to look up.

    Returns:
      A list of token ids of the same size.

    Raises:
      ValueError: If tokens is not a list.
    """
    if not isinstance(tokens, list):
      raise ValueError('lookup expected a list of tokens.')

    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return sess.run(self._lookup_result, {
          self._lookup_placeholder: tokens
      }).tolist()

  def reverse_lookup(self, ids):
    """Maps a list of ids to tokens.

    Args:
      ids: A list of ids to map back to tokens.

    Returns:
      A list of tokens corresponding to those ids.

    Raises:
      ValueError: If ids is not a list.
    """
    if not isinstance(ids, list):
      raise ValueError('reverse_lookup expected a list of ids.')
    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return list(
          sess.run(self._reverse_lookup_result,
                   {self._reverse_lookup_placeholder: ids}))

  @property
  def special_ids(self):
    """Returns a list of special token ids."""
    return [t for t in [self.unk_id, self.bos_id, self.eos_id] if t is not None]

  @property
  def eos_id(self):
    """Returns the end of sequence id, or None if the stored id is negative."""
    eos_id = self._dictionary_description.special_ids.eos
    return eos_id if eos_id >= 0 else None

  @property
  def bos_id(self):
    """Returns the start of sequence id, or None if the stored id is negative."""
    bos_id = self._dictionary_description.special_ids.bos
    return bos_id if bos_id >= 0 else None

  @property
  def unk_id(self):
    """Returns the unknown token id, or None if the stored id is negative."""
    unk_id = self._dictionary_description.special_ids.unk
    return unk_id if unk_id >= 0 else None

  @property
  def size(self):
    """Returns the dictionary size as reported by the dictionary_size op."""
    with tf.compat.v1.Session(graph=self._get_lookup_graph()) as sess:
      return sess.run(self._size_result)

  @property
  def output_blocklist_ids(self):
    """Returns the list of ids on the output blocklist."""
    return list(self._dictionary_description.output_blocklist_ids.id)

  @property
  def output_blocklist_tokens(self):
    """Returns the tokens corresponding to output_blocklist_ids."""
    return self.reverse_lookup(self.output_blocklist_ids)

  @property
  def tokens(self):
    """Returns the tokens for ids in [len(special_ids), size)."""
    return self.reverse_lookup(list(range(len(self.special_ids), self.size)))

  @property
  def dictionary_description_proto(self):
    """Serialized proto containing self.dictionary_description."""
    return self.dictionary_description.SerializeToString()

  @property
  def dictionary_description(self):
    """Returns the `DictionaryDescription` proto describing this dictionary.
    """
    desc = self._dictionary_description
    return desc

  def __len__(self):
    return self.size
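

# Illustrative sketch only, not part of the fcp API. It restates in plain
# Python, on nested lists, the rule that mask_and_replace_padding applies with
# TensorFlow ops, which may help when reasoning about expected outputs in
# tests. It assumes a rectangular batch of integer ids. The helper name
# `_reference_mask_and_replace_padding` is hypothetical.
def _reference_mask_and_replace_padding(token_ids, lengths, eos_id=None,
                                        special_tokens=()):
  """Plain-Python reference for the masking rule (lists in, lists out)."""
  new_ids, mask = [], []
  for row, length in zip(token_ids, lengths):
    # A position is selected when it lies before the end of its row.
    selected = [col < length for col in range(len(row))]
    if eos_id is not None:
      # Padded positions are overwritten with eos_id.
      row = [t if s else eos_id for t, s in zip(row, selected)]
    # A token is valid when it is selected and is not a special id.
    mask.append([s and t not in special_tokens for t, s in zip(row, selected)])
    new_ids.append(row)
  return new_ids, mask


# Example, assuming ids 0 (unk) and 1 (eos) are the special ids and the two
# rows have lengths 2 and 3:
#
#   _reference_mask_and_replace_padding(
#       [[5, 0, 9], [7, 8, 6]], [2, 3], eos_id=1, special_tokens=(0, 1))
#
# returns
#
#   ([[5, 0, 1], [7, 8, 6]],
#    [[True, False, False], [True, True, True]])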