1# coding=utf-8
2# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3#
4# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5# and OPT implementations in this library. It has been modified from its
6# original forms to accommodate minor architectural differences compared
7# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8#
9# Licensed under the Apache License, Version 2.0 (the "License");
10# you may not use this file except in compliance with the License.
11# You may obtain a copy of the License at
12#
13#     http://www.apache.org/licenses/LICENSE-2.0
14#
15# Unless required by applicable law or agreed to in writing, software
16# distributed under the License is distributed on an "AS IS" BASIS,
17# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18# See the License for the specific language governing permissions and
19# limitations under the License.
20
21"""Tokenization classes for LLaMA."""
22import importlib.util
23import os
24from shutil import copyfile
25from typing import Any, Dict, List, Optional, Tuple
26
27import sentencepiece as spm
28
29from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
30from transformers.utils import logging
31
32
33def is_protobuf_available():
34    if importlib.util.find_spec("google") is None:
35        return False
36    return importlib.util.find_spec("google.protobuf") is not None
37
38
39def import_protobuf(error_message=""):
40    PROTOBUF_IMPORT_ERROR = """
41        {0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the
42        installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones
43        that match your environment. Please note that you may need to restart your runtime after installation.
44        """
45    if is_protobuf_available():
46        import google.protobuf
47
48        if int(google.protobuf.__version__.split(".")[0]) < 4:
49            from transformers.utils import sentencepiece_model_pb2
50        else:
51            from transformers.utils import (
52                sentencepiece_model_pb2_new as sentencepiece_model_pb2,
53            )
54        return sentencepiece_model_pb2
55    else:
56        raise ImportError(PROTOBUF_IMPORT_ERROR.format(error_message))
57
58
59logger = logging.get_logger(__name__)
60
61VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
62
63PRETRAINED_VOCAB_FILES_MAP = {
64    "vocab_file": {
65        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
66    },
67    "tokenizer_file": {
68        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
69    },
70}
71PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
72    "hf-internal-testing/llama-tokenizer": 2048,
73}
74SPIECE_UNDERLINE = "▁"
75
76B_INST, E_INST = "[INST]", "[/INST]"
77B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
78
79# fmt: off
80DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
81answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
82 that your responses are socially unbiased and positive in nature.
83
84If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
85correct. If you don't know the answer to a question, please don't share false information."""
86# fmt: on
87
88
89class LlamaTokenizer(PreTrainedTokenizer):
90    """
91    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
92    no padding token in the original model.
93
94    Args:
95        vocab_file (`str`):
96            Path to the vocabulary file.
97        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
98            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
99            token instead.
100        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<s>"`):
101            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
102        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"</s>"`):
103            The end of sequence token.
104        pad_token (`str` or `tokenizers.AddedToken`, *optional*):
105            A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by
106            attention mechanisms or loss computation.
107        sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*):
108            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
109            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
110            to set:
111
112            - `enable_sampling`: Enable subword regularization.
113            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
114
115              - `nbest_size = {0,1}`: No sampling is performed.
116              - `nbest_size > 1`: samples from the nbest_size results.
117              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
118                using forward-filtering-and-backward-sampling algorithm.
119
120            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
121              BPE-dropout.
122
123        add_bos_token (`bool`, *optional*, defaults to `True`):
124            Whether or not to add an `bos_token` at the start of sequences.
125        add_eos_token (`bool`, *optional*, defaults to `False`):
126            Whether or not to add an `eos_token` at the end of sequences.
127        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
128            Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like
129            extra spaces.
130        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
131            Whether or not the default system prompt for Llama should be used.
132        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
133            Whether or not to add spaces between special tokens.
134        legacy (`bool`, *optional*):
135            Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622
136            and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple
137            example:
138
139            - `legacy=True`:
140            ```python
141            >>> from transformers import T5Tokenizer
142
143            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
144            >>> tokenizer.encode("Hello <extra_id_0>.")
145            [8774, 32099, 3, 5, 1]
146            ```
147            - `legacy=False`:
148            ```python
149            >>> from transformers import T5Tokenizer
150
151            >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
152            >>> tokenizer.encode("Hello <extra_id_0>.")  # the extra space `[3]` is no longer here
153            [8774, 32099, 5, 1]
154            ```
155            Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details.
156
157    """
158
159    vocab_files_names = VOCAB_FILES_NAMES
160    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
161    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
162    model_input_names = ["input_ids", "attention_mask"]
163
164    def __init__(
165        self,
166        vocab_file,
167        unk_token="<unk>",
168        bos_token="<s>",
169        eos_token="</s>",
170        pad_token=None,
171        sp_model_kwargs: Optional[Dict[str, Any]] = None,
172        add_bos_token=True,
173        add_eos_token=False,
174        clean_up_tokenization_spaces=False,
175        use_default_system_prompt=False,
176        spaces_between_special_tokens=False,
177        legacy=None,
178        **kwargs,
179    ):
180        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
181        bos_token = (
182            AddedToken(bos_token, normalized=False, special=True)
183            if isinstance(bos_token, str)
184            else bos_token
185        )
186        eos_token = (
187            AddedToken(eos_token, normalized=False, special=True)
188            if isinstance(eos_token, str)
189            else eos_token
190        )
191        unk_token = (
192            AddedToken(unk_token, normalized=False, special=True)
193            if isinstance(unk_token, str)
194            else unk_token
195        )
196        pad_token = (
197            AddedToken(pad_token, normalized=False, special=True)
198            if isinstance(pad_token, str)
199            else pad_token
200        )
201
202        if legacy is None:
203            legacy = True
204
205        self.legacy = legacy
206        self.vocab_file = vocab_file
207        self.add_bos_token = add_bos_token
208        self.add_eos_token = add_eos_token
209        self.use_default_system_prompt = use_default_system_prompt
210        self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False))
211
212        super().__init__(
213            bos_token=bos_token,
214            eos_token=eos_token,
215            unk_token=unk_token,
216            pad_token=pad_token,
217            add_bos_token=add_bos_token,
218            add_eos_token=add_eos_token,
219            sp_model_kwargs=self.sp_model_kwargs,
220            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
221            use_default_system_prompt=use_default_system_prompt,
222            spaces_between_special_tokens=spaces_between_special_tokens,
223            legacy=legacy,
224            **kwargs,
225        )
226
227    @property
228    def unk_token_length(self):
229        return len(self.sp_model.encode(str(self.unk_token)))
230
231    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor
232    def get_spm_processor(self, from_slow=False):
233        tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
234        if self.legacy or from_slow:  # no dependency on protobuf
235            tokenizer.Load(self.vocab_file)
236            return tokenizer
237
238        with open(self.vocab_file, "rb") as f:
239            sp_model = f.read()
240            model_pb2 = import_protobuf(
241                f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)"
242            )
243            model = model_pb2.ModelProto.FromString(sp_model)
244            normalizer_spec = model_pb2.NormalizerSpec()
245            normalizer_spec.add_dummy_prefix = False
246            model.normalizer_spec.MergeFrom(normalizer_spec)
247            sp_model = model.SerializeToString()
248            tokenizer.LoadFromSerializedProto(sp_model)
249        return tokenizer
250
251    def __getstate__(self):
252        state = self.__dict__.copy()
253        state["sp_model"] = None
254        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
255        return state
256
257    def __setstate__(self, d):
258        self.__dict__ = d
259        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
260        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
261
262    @property
263    def vocab_size(self):
264        """Returns vocab size"""
265        return self.sp_model.get_piece_size()
266
267    def get_vocab(self):
268        """Returns vocab as a dict"""
269        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
270        vocab.update(self.added_tokens_encoder)
271        return vocab
272
273    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
274    def tokenize(self, text, add_special_tokens=False, **kwargs) -> List[str]:
275        """
276        Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the
277        first token is special.
278        """
279        if self.legacy or len(text) == 0:
280            return super().tokenize(text, **kwargs)
281
282        tokens = super().tokenize(
283            SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs
284        )
285
286        if (
287            len(tokens) > 1
288            and tokens[0] == SPIECE_UNDERLINE
289            and tokens[1] in self.all_special_tokens
290        ):
291            tokens = tokens[1:]
292        return tokens
293
294    # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
295    def _tokenize(self, text, **kwargs):
296        """
297        Returns a tokenized string.
298
299        We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
300        SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give
301        `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the
302        `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
303        `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
304        """
305        tokens = self.sp_model.encode(text, out_type=str)
306        if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")):
307            return tokens
308
309        # 1. Encode string + prefix ex: "<unk> Hey"
310        tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
311        # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
312        return (
313            tokens[self.unk_token_length :]
314            if len(tokens) >= self.unk_token_length
315            else tokens
316        )
317
318    def _convert_token_to_id(self, token):
319        """Converts a token (str) in an id using the vocab."""
320        return self.sp_model.piece_to_id(token)
321
322    def _convert_id_to_token(self, index):
323        """Converts an index (integer) in a token (str) using the vocab."""
324        token = self.sp_model.IdToPiece(index)
325        return token
326
327    def convert_tokens_to_string(self, tokens):
328        """Converts a sequence of tokens (string) in a single string."""
329        # since we manually add the prefix space, we have to remove it when decoding
330        if tokens[0].startswith(SPIECE_UNDERLINE):
331            tokens[0] = tokens[0][1:]
332
333        current_sub_tokens = []
334        out_string = ""
335        prev_is_special = False
336        for i, token in enumerate(tokens):
337            # make sure that special tokens are not decoded using sentencepiece model
338            if token in self.all_special_tokens:
339                if not prev_is_special and i != 0 and self.legacy:
340                    out_string += " "
341                out_string += self.sp_model.decode(current_sub_tokens) + token
342                prev_is_special = True
343                current_sub_tokens = []
344            else:
345                current_sub_tokens.append(token)
346                prev_is_special = False
347        out_string += self.sp_model.decode(current_sub_tokens)
348        return out_string
349
350    def save_vocabulary(
351        self, save_directory, filename_prefix: Optional[str] = None
352    ) -> Tuple[str]:
353        """
354        Save the vocabulary and special tokens file to a directory.
355
356        Args:
357            save_directory (`str`):
358                The directory in which to save the vocabulary.
359
360        Returns:
361            `Tuple(str)`: Paths to the files saved.
362        """
363        if not os.path.isdir(save_directory):
364            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
365            return
366        out_vocab_file = os.path.join(
367            save_directory,
368            (filename_prefix + "-" if filename_prefix else "")
369            + VOCAB_FILES_NAMES["vocab_file"],
370        )
371
372        if os.path.abspath(self.vocab_file) != os.path.abspath(
373            out_vocab_file
374        ) and os.path.isfile(self.vocab_file):
375            copyfile(self.vocab_file, out_vocab_file)
376        elif not os.path.isfile(self.vocab_file):
377            with open(out_vocab_file, "wb") as fi:
378                content_spiece_model = self.sp_model.serialized_model_proto()
379                fi.write(content_spiece_model)
380
381        return (out_vocab_file,)
382
383    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
384        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
385        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
386
387        output = bos_token_id + token_ids_0 + eos_token_id
388
389        if token_ids_1 is not None:
390            output = output + bos_token_id + token_ids_1 + eos_token_id
391
392        return output
393
394    def get_special_tokens_mask(
395        self,
396        token_ids_0: List[int],
397        token_ids_1: Optional[List[int]] = None,
398        already_has_special_tokens: bool = False,
399    ) -> List[int]:
400        """
401        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
402        special tokens using the tokenizer `prepare_for_model` method.
403
404        Args:
405            token_ids_0 (`List[int]`):
406                List of IDs.
407            token_ids_1 (`List[int]`, *optional*):
408                Optional second list of IDs for sequence pairs.
409            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
410                Whether or not the token list is already formatted with special tokens for the model.
411
412        Returns:
413            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
414        """
415        if already_has_special_tokens:
416            return super().get_special_tokens_mask(
417                token_ids_0=token_ids_0,
418                token_ids_1=token_ids_1,
419                already_has_special_tokens=True,
420            )
421
422        bos_token_id = [1] if self.add_bos_token else []
423        eos_token_id = [1] if self.add_eos_token else []
424
425        if token_ids_1 is None:
426            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
427        return (
428            bos_token_id
429            + ([0] * len(token_ids_0))
430            + eos_token_id
431            + bos_token_id
432            + ([0] * len(token_ids_1))
433            + eos_token_id
434        )
435
436    def create_token_type_ids_from_sequences(
437        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
438    ) -> List[int]:
439        """
440        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
441        sequence pair mask has the following format:
442
443        ```
444        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
445        | first sequence    | second sequence |
446        ```
447
448        if token_ids_1 is None, only returns the first portion of the mask (0s).
449
450        Args:
451            token_ids_0 (`List[int]`):
452                List of ids.
453            token_ids_1 (`List[int]`, *optional*):
454                Optional second list of IDs for sequence pairs.
455
456        Returns:
457            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
458        """
459        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
460        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
461
462        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
463
464        if token_ids_1 is not None:
465            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
466
467        return output
468
469    @property
470    def default_chat_template(self):
471        """
472        LLaMA uses [INST] and [/INST] to indicate user messages, and <<SYS>> and <</SYS>> to indicate system messages.
473        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
474        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
475        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
476        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
477        to fine-tune a model with more flexible role ordering!
478
479        The output should look something like:
480
481        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos><bos>[INST] Prompt [/INST] Answer <eos>
482        <bos>[INST] Prompt [/INST]
483
484        The reference for this chat template is [this code
485        snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
486        in the original repository.
487        """
488        logger.warning_once(
489            "\nNo chat template is defined for this tokenizer - using the default template "
490            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
491            "your model, please set `tokenizer.chat_template` to an appropriate template. "
492            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
493        )
494        template = (
495            "{% if messages[0]['role'] == 'system' %}"
496            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
497            "{% set system_message = messages[0]['content'] %}"
498            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
499            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
500            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
501            "{% else %}"
502            "{% set loop_messages = messages %}"
503            "{% set system_message = false %}"
504            "{% endif %}"
505            "{% for message in loop_messages %}"  # Loop over all non-system messages
506            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
507            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
508            "{% endif %}"
509            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
510            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
511            "{% else %}"
512            "{% set content = message['content'] %}"
513            "{% endif %}"
514            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
515            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
516            "{% elif message['role'] == 'system' %}"
517            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
518            "{% elif message['role'] == 'assistant' %}"
519            "{{ ' '  + content.strip() + ' ' + eos_token }}"
520            "{% endif %}"
521            "{% endfor %}"
522        )
523        template = template.replace(
524            "USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false"
525        )
526        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
527        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
528
529        return template
530