# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to rewrite tokenizer model given by sentencepiece, with lightweight
# postprocessing logic.

import argparse
import logging
import os
import struct
from typing import List

from sentencepiece import SentencePieceProcessor as SentencePieceProcessor

class Tokenizer:
    """Thin wrapper around a sentencepiece model, plus a lightweight exporter
    that rewrites the model into a simple flat binary format (see `export`)."""

    def __init__(self, model_path: str):
        """Load the sentencepiece model at ``model_path``.

        :param model_path: path to a sentencepiece ``.model`` file.
        """
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        # Sanity check: the export below assumes every id in [0, vocab_size)
        # maps to a piece.
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        """Encode string ``s`` into token ids.

        :param s: input text.
        :param bos: if True, prepend the BOS token id.
        :param eos: if True, append the EOS token id.
        :return: list of token ids.
        """
        assert type(s) is str
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        """Decode a list of token ids back into a string."""
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
        """Decode a single token id into its string form."""
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. Here we did some lightweight
        processing such as supporting prepend padding token, prepend max token length and
        replace '_' back to empty space.

        The binary format is:
        1. vocab size: int32
        2. bos token id: int32
        3. eos token id: int32
        4. max token length: int32
        5. score: float32, len of bytes: int32, token bytes: [byte] for each token

        NOTE: ids are packed as unsigned ("I"); if the model defines no BOS/EOS,
        sentencepiece reports the id as -1 and struct.pack will raise struct.error.

        :param output_path: output path of the new binary.
        :param prepend_padding: a boolean to control if we want to prepend a padding token.

        :return: None
        """

        # get all the tokens (postprocessed) and their scores as floats
        tokens, scores = [], []

        if prepend_padding:
            # Here we use the default padding token and its score.
            tokens.append("<pad>".encode("utf-8"))
            scores.append(-1)

        for i in range(self.n_words):
            # decode the token and light postprocessing
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
            t = self.sp_model.id_to_piece(i)
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
            s = self.sp_model.get_score(i)
            # sentencepiece use '<s>' as BOS and '</s>' for EOS
            if i == self.bos_id:
                t = "<s>"
            elif i == self.eos_id:
                t = "</s>"
            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # record the max token length
        max_token_length = 0 if not tokens else max(len(t) for t in tokens)

        # write to a binary file
        with open(output_path, "wb") as f:
            # write the vocab size, bos/eos ids and max token length
            f.write(
                struct.pack(
                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
                )
            )
            # note: renamed from `bytes` to avoid shadowing the builtin
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)
        logging.info(f"Wrote tokenizer to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tokenizer-model",
        type=str,
        default="tokenizer.model",
        help="path to tokenizer model, given by sentencepiece",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default=None,
        help="output path of postprocessed tokenizer model",
    )
    parser.add_argument(
        "-p",
        "--prepend-padding",
        action="store_true",
        help="whether to prepend a padding token to the beginning of the tokenizer",
    )

    args = parser.parse_args()

    tokenizer = Tokenizer(args.tokenizer_model)

    if args.output_path:
        output_path = args.output_path
    else:
        # Derive the default output name with os.path.splitext rather than
        # str.replace(".model", ".bin"): replace() is a no-op when the input
        # path does not contain ".model", which would make output_path equal
        # to the input path and silently overwrite the tokenizer model.
        output_path = os.path.splitext(args.tokenizer_model)[0] + ".bin"
    tokenizer.export(output_path, prepend_padding=args.prepend_padding)