# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to rewrite a tokenizer model produced by sentencepiece, with
# lightweight postprocessing logic.
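#
# Example invocation (paths are illustrative):
#
#   python tokenizer.py -t /path/to/tokenizer.model -o /path/to/tokenizer.bin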

import argparse
import logging
import os
import struct
from typing import List

from sentencepiece import SentencePieceProcessor as SentencePieceProcessor


class Tokenizer:
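    """Thin wrapper around a sentencepiece tokenizer model.

    Loads a sentencepiece tokenizer model file, exposes simple encode/decode
    helpers, and can re-serialize the vocabulary and scores into a compact
    binary format via export().
    """
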
    def __init__(self, model_path: str):
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        # BOS / EOS token IDs
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
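        """Encode a string into token ids, optionally adding BOS/EOS ids."""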
        assert type(s) is str
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
        t = self.sp_model.encode(s)
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
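        """Decode a list of token ids back into a string."""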
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
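        """Decode a single token id into text."""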
        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
        return self.sp_model.decode(t)

    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
        """
        Export tokenizer.model to another serialization format. This applies some
        lightweight postprocessing: optionally prepending a padding token,
        prepending the max token length, and replacing '▁' with a plain space.

        The binary format is:
        1. vocab size: int32
        2. bos token id: int32
        3. eos token id: int32
        4. max token length: int32
        5. score: float32, len of bytes: int32, token bytes: [byte] for each token

        :param output_path: output path of the new binary.
        :param prepend_padding: a boolean to control if we want to prepend a padding token.

        :return: None
        """
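        # Illustrative sketch (not executed here) of how a consumer could read the
        # file produced below, assuming native byte order to match struct.pack:
        #
        #   with open(output_path, "rb") as f:
        #       vocab_size, bos_id, eos_id, max_len = struct.unpack("IIII", f.read(16))
        #       entries = []
        #       while header := f.read(8):
        #           score, length = struct.unpack("fI", header)
        #           entries.append((score, f.read(length)))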

        # get all the tokens (postprocessed) and their scores as floats
        tokens, scores = [], []

        if prepend_padding:
            # Here we use the default padding token and its score.
            tokens.append("<pad>".encode("utf-8"))
            scores.append(-1)

        for i in range(self.n_words):
            # decode the token and apply light postprocessing
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
            t = self.sp_model.id_to_piece(i)
            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
            s = self.sp_model.get_score(i)
            # sentencepiece uses '<s>' for BOS and '</s>' for EOS
            if i == self.bos_id:
                t = "<s>"
            elif i == self.eos_id:
                t = "</s>"
            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded

            tokens.append(b)
            scores.append(s)

        # record the max token length
        max_token_length = 0 if not tokens else max(len(t) for t in tokens)

        # write to a binary file
        with open(output_path, "wb") as f:
            # write the vocab size, bos/eos ids and max token length
            f.write(
                struct.pack(
                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
                )
            )
            for token_bytes, score in zip(tokens, scores):
                f.write(struct.pack("fI", score, len(token_bytes)))
                f.write(token_bytes)
        logging.info(f"Wrote tokenizer to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tokenizer-model",
        type=str,
        default="tokenizer.model",
        help="path to the tokenizer model produced by sentencepiece",
    )
    parser.add_argument(
        "-o",
        "--output-path",
        type=str,
        default=None,
        help="output path of the postprocessed tokenizer model",
    )
    parser.add_argument(
        "-p",
        "--prepend-padding",
        action="store_true",
        help="whether to prepend a padding token to the exported vocabulary",
    )

    args = parser.parse_args()

    t = Tokenizer(args.tokenizer_model)

    output_path = (
        args.output_path
        if args.output_path
        else args.tokenizer_model.replace(".model", ".bin")
    )
    t.export(output_path, prepend_padding=args.prepend_padding)