# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to rewrite a tokenizer model produced by sentencepiece into a flat
# binary format, with lightweight postprocessing logic.

import argparse
import logging
import os
import struct
from typing import List

from sentencepiece import SentencePieceProcessor as SentencePieceProcessor


class Tokenizer:
    """Thin wrapper around a SentencePiece model that can export it to a binary file."""

    def __init__(self, model_path: str):
        """
        Load the SentencePiece model stored at ``model_path``.

        :param model_path: path to the sentencepiece tokenizer model file.
        :raises AssertionError: (when assertions are enabled) if ``model_path``
            is not an existing file.
        """
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # Vocabulary size and BOS / EOS token IDs as reported by the model.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        # Sanity check: the two size accessors must agree, since export() walks
        # ids 0..n_words-1.
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """
        Encode ``s`` into token ids, optionally adding BOS / EOS markers.

        :param s: text to encode; must be exactly of type ``str``.
        :param bos: if true, prepend ``self.bos_id``.
        :param eos: if true, append ``self.eos_id``.

        :return: the list of token ids.
        """
        assert type(s) is str
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode a list of token ids back into text."""
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
        """Decode a single token id into text."""
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. Here we do some
        lightweight processing, such as optionally prepending a padding token,
        recording the max token length and replacing sentencepiece's '▁' marker
        back with a plain space.

        The binary format is (host-native byte order):
            1. vocab size: int32
            2. bos token id: int32
            3. eos token id: int32
            4. max token length: int32
            5. for each token: score: float32, len of bytes: uint32,
               token bytes: [byte]

        :param output_path: output path of the new binary.
        :param prepend_padding: a boolean to control if we want to prepend a padding token.

        :return: None
        """

        # Collect all the (postprocessed) token bytes and their float scores.
        tokens: List[bytes] = []
        scores: List[float] = []

        if prepend_padding:
            # Here we use the default padding token and its score.
            tokens.append("<pad>".encode("utf-8"))
            scores.append(-1.0)

        for i in range(self.n_words):
            # Decode the token and apply light postprocessing.
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
            piece = self.sp_model.id_to_piece(i)
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
            score = self.sp_model.get_score(i)
            # Write BOS / EOS under the literal spellings '<s>' and '</s>'.
            if i == self.bos_id:
                piece = "<s>"
            elif i == self.eos_id:
                piece = "</s>"
            # sentencepiece uses '▁' (U+2581) to mark whitespace.
            piece = piece.replace("▁", " ")
            tokens.append(piece.encode("utf-8"))
            scores.append(score)

        # Longest token in bytes (0 for an empty vocabulary).
        max_token_length = max((len(tok) for tok in tokens), default=0)

        with open(output_path, "wb") as f:
            # Signed "i" matches the documented int32 header. It also lets a
            # disabled BOS/EOS id of -1 be written instead of raising
            # struct.error, as the unsigned "I" format would. Non-negative
            # values produce the same bytes either way.
            # NOTE(review): with prepend_padding the file holds n_words + 1
            # token records, but the header still stores self.n_words and the
            # unshifted bos/eos ids. A reader that reads exactly "vocab size"
            # records would skip the last one. Confirm the consumer expects this.
            f.write(
                struct.pack(
                    "iiii", self.n_words, self.bos_id, self.eos_id, max_token_length
                )
            )
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)
        logging.info(f"Wrote tokenizer to {output_path}")
def default_output_path(model_path: str) -> str:
    """
    Return the default binary output path for a tokenizer model path.

    Only the final extension of ``model_path`` (if any) is replaced with
    ``.bin``, e.g. ``tokenizer.model`` -> ``tokenizer.bin`` and
    ``tok.spm`` -> ``tok.bin``. This avoids two problems with
    ``str.replace(".model", ".bin")``:

    - it rewrote ``.model`` anywhere in the path, such as in directory names;
    - it returned the input path unchanged when ``.model`` was absent, so the
      export overwrote the source model.

    The input path is only reused if it already ends in ``.bin``.

    :param model_path: path to the sentencepiece tokenizer model.
    :return: path for the exported binary.
    """
    root, _ext = os.path.splitext(model_path)
    return root + ".bin"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tokenizer-model",
        type=str,
        default="tokenizer.model",
        help="path to tokenizer model, given by sentencepiece",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default=None,
        help="output path of postprocessed tokenizer model",
    )
    parser.add_argument(
        "-p",
        "--prepend-padding",
        action="store_true",
        help="whether to prepend a padding token to the beginning of the tokenizer",
    )

    args = parser.parse_args()

    t = Tokenizer(args.tokenizer_model)

    # Fall back to the model path with its extension swapped for ".bin".
    output_path = (
        args.output_path
        if args.output_path
        else default_output_path(args.tokenizer_model)
    )
    t.export(output_path, prepend_padding=args.prepend_padding)