# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tokenization classes for fast tokenizers (provided by HuggingFace's tokenizers library). For slow (python) tokenizers
see tokenization_utils.py
"""
import copy
import json
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import tokenizers.pre_tokenizers as pre_tokenizers_fast
from tokenizers import Encoding as EncodingFast, Tokenizer as TokenizerFast
from tokenizers.decoders import Decoder as DecoderFast
from tokenizers.trainers import (
    BpeTrainer,
    UnigramTrainer,
    WordLevelTrainer,
    WordPieceTrainer,
)

from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import (
    AddedToken,
    BatchEncoding,
    INIT_TOKENIZER_DOCSTRING,
    PreTokenizedInput,
    PreTokenizedInputPair,
    PreTrainedTokenizerBase,
    SpecialTokensMixin,
    TextInput,
    TextInputPair,
    TruncationStrategy,
)
from transformers.utils import add_end_docstrings, logging, PaddingStrategy


# flake8: noqa: C901


logger = logging.get_logger(__name__)

# Fast tokenizers (provided by HuggingFace's tokenizers library) can be saved in a single file
TOKENIZER_FILE = "tokenizer.json"
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"

# Slow tokenizers have an additional added tokens file
ADDED_TOKENS_FILE = "added_tokens.json"

INIT_TOKENIZER_DOCSTRING += """
        tokenizer_object ([`tokenizers.Tokenizer`]):
            A [`tokenizers.Tokenizer`] object from 🤗 tokenizers to instantiate from. See [Using tokenizers from 🤗
            tokenizers](../fast_tokenizers) for more information.
        tokenizer_file ([`str`]):
            A path to a local JSON file representing a previously serialized [`tokenizers.Tokenizer`] object from 🤗
            tokenizers.
"""

MODEL_TO_TRAINER_MAPPING = {
    "BPE": BpeTrainer,
    "Unigram": UnigramTrainer,
    "WordLevel": WordLevelTrainer,
    "WordPiece": WordPieceTrainer,
}

VOCAB_FILES_NAMES = {"tokenizer_file": TOKENIZER_FILE}


@add_end_docstrings(INIT_TOKENIZER_DOCSTRING)
class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
    """
    Base class for all fast tokenizers (wrapping the HuggingFace tokenizers library).

    Inherits from [`~tokenization_utils_base.PreTrainedTokenizerBase`].

    Handles all the shared methods for tokenization and special tokens, as well as methods for
    downloading/caching/loading pretrained tokenizers and for adding tokens to the vocabulary.

    This class also contains the added tokens in a unified way on top of all tokenizers so we don't have to handle the
    specific vocabulary augmentation methods of the various underlying dictionary structures (BPE, sentencepiece...).
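
    Example (a minimal sketch; the checkpoint name is only illustrative and any fast tokenizer from the Hub works
    the same way):

    ```python
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
    print(tokenizer.is_fast)  # True whenever a fast (Rust-backed) tokenizer could be loaded or converted
    ```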
96    """
97
98    vocab_files_names = VOCAB_FILES_NAMES
99    slow_tokenizer_class: PreTrainedTokenizer = None
100
101    def __init__(self, *args, **kwargs):
102        tokenizer_object = kwargs.pop("tokenizer_object", None)
103        slow_tokenizer = kwargs.pop("__slow_tokenizer", None)
104        fast_tokenizer_file = kwargs.pop("tokenizer_file", None)
105        from_slow = kwargs.pop("from_slow", False)
106        added_tokens_decoder = kwargs.pop("added_tokens_decoder", {})
107
108        if from_slow and slow_tokenizer is None and self.slow_tokenizer_class is None:
109            raise ValueError(
110                "Cannot instantiate this tokenizer from a slow version. If it's based on sentencepiece, make sure you "
111                "have sentencepiece installed."
112            )
113
114        if tokenizer_object is not None:
115            fast_tokenizer = copy.deepcopy(tokenizer_object)
116        elif fast_tokenizer_file is not None and not from_slow:
117            # We have a serialization from tokenizers which let us directly build the backend
118            fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
119        elif slow_tokenizer is not None:
120            # We need to convert a slow tokenizer to build the backend
121            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
122        elif self.slow_tokenizer_class is not None:
123            # We need to create and convert a slow tokenizer to build the backend
124            slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
125            fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
126        else:
127            raise ValueError(
128                "Couldn't instantiate the backend tokenizer from one of: \n"
129                "(1) a `tokenizers` library serialization file, \n"
130                "(2) a slow tokenizer instance to convert or \n"
131                "(3) an equivalent slow tokenizer class to instantiate and convert. \n"
132                "You need to have sentencepiece installed to convert a slow tokenizer to a fast one."
133            )
134
135        self._tokenizer = fast_tokenizer
136
137        if slow_tokenizer is not None:
138            kwargs.update(slow_tokenizer.init_kwargs)
139
140        self._decode_use_source_tokenizer = False
141
142        _truncation = self._tokenizer.truncation
143
144        if _truncation is not None:
145            self._tokenizer.enable_truncation(**_truncation)
146            kwargs.setdefault("max_length", _truncation["max_length"])
147            kwargs.setdefault("truncation_side", _truncation["direction"])
148            kwargs.setdefault("stride", _truncation["stride"])
149            kwargs.setdefault("truncation_strategy", _truncation["strategy"])
150        else:
151            self._tokenizer.no_truncation()
152
153        _padding = self._tokenizer.padding
154        if _padding is not None:
155            self._tokenizer.enable_padding(**_padding)
156            kwargs.setdefault("pad_token", _padding["pad_token"])
157            kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
158            kwargs.setdefault("padding_side", _padding["direction"])
159            kwargs.setdefault("max_length", _padding["length"])
160            kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
161
162        # We call this after having initialized the backend tokenizer because we update it.
163        super().__init__(**kwargs)
164
        # The following logic will be replaced with a single add_tokens call once a fix is pushed to tokenizers.
        # It allows converting a slow tokenizer to a fast, non-legacy one: if the `tokenizer.json` does not contain
        # all the added tokens, we use the information stored in `added_tokens_decoder`.
        # This is costly for fast tokenizers as we re-compute the regex again, but not all tokens are added tokens.
        tokens_to_add = [
            token
            for index, token in sorted(added_tokens_decoder.items(), key=lambda x: x[0])
            if token not in self.added_tokens_decoder
        ]
        encoder = list(self.added_tokens_encoder.keys()) + [
            str(token) for token in tokens_to_add
        ]
        # if some of the special tokens are strings, we check if we don't already have a token
        tokens_to_add += [
            token
            for token in self.all_special_tokens_extended
            if token not in encoder and token not in tokens_to_add
        ]
        if len(tokens_to_add) > 0:
            # super hack: if a token.special is set, tokenizer ignores it for now so FIXME @ArthurZ
            # Accumulate added tokens into batches of special/non-special tokens, because calling add_tokens() for
            # individual tokens would repeatedly rebuild a trie, which can be slow.
            is_last_special = None
            tokens = []
            special_tokens = self.all_special_tokens
            for token in tokens_to_add:
                is_special = (
                    (token.special or str(token) in special_tokens)
                    if isinstance(token, AddedToken)
                    else str(token) in special_tokens
                )
                if is_last_special is None or is_last_special == is_special:
                    tokens.append(token)
                else:
                    self._add_tokens(tokens, special_tokens=is_last_special)
                    tokens = [token]
                is_last_special = is_special
            if tokens:
                self._add_tokens(tokens, special_tokens=is_last_special)

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def can_save_slow_tokenizer(self) -> bool:
        """
        `bool`: Whether or not the slow tokenizer can be saved. For sentencepiece-based slow tokenizers, this can
        usually only be `True` if the original `"sentencepiece.model"` file is still available.
        """
        return True

    @property
    def vocab_size(self) -> int:
        """
        `int`: Size of the base vocabulary (without the added tokens).
        """
        return self._tokenizer.get_vocab_size(with_added_tokens=False)

    def get_vocab(self) -> Dict[str, int]:
        return self._tokenizer.get_vocab(with_added_tokens=True)

    @property
    def vocab(self) -> Dict[str, int]:
        return self.get_vocab()

    @property
    def added_tokens_encoder(self) -> Dict[str, int]:
        """
        Returns the sorted mapping from string to index. The added tokens encoder is cached for performance
        optimisation in `self._added_tokens_encoder` for the slow tokenizers.
        """
        return {
            k.content: v
            for v, k in sorted(
                self.added_tokens_decoder.items(), key=lambda item: item[0]
            )
        }

    @property
    def added_tokens_decoder(self) -> Dict[int, AddedToken]:
        """
        Returns the added tokens in the vocabulary as a dictionary of index to AddedToken.

        Returns:
            `Dict[int, AddedToken]`: The added tokens.
        """
        return self._tokenizer.get_added_tokens_decoder()

    def get_added_vocab(self) -> Dict[str, int]:
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index.

        Returns:
            `Dict[str, int]`: The added tokens.
        """
        return {
            k.content: v
            for v, k in sorted(
                self.added_tokens_decoder.items(), key=lambda item: item[0]
            )
        }

    def __len__(self) -> int:
        """
        Size of the full vocabulary with the added tokens.
        """
        return self._tokenizer.get_vocab_size(with_added_tokens=True)

    @property
    def backend_tokenizer(self) -> TokenizerFast:
        """
        `tokenizers.implementations.BaseTokenizer`: The Rust tokenizer used as a backend.
        """
        return self._tokenizer

    @property
    def decoder(self) -> DecoderFast:
        """
        `tokenizers.decoders.Decoder`: The Rust decoder for this tokenizer.
        """
        return self._tokenizer.decoder

    def _convert_encoding(
        self,
        encoding: EncodingFast,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ) -> Tuple[Dict[str, Any], List[EncodingFast]]:
        """
        Convert the encoding representation (from the low-level HuggingFace tokenizer output) to a python Dict and a
        list of encodings, taking care of building a batch from overflowing tokens.

        Overflowing tokens are converted to additional examples (like batches) so the output values of the dict are
        lists (overflows) of lists (tokens).

        Output shape: (overflows, sequence length)
        """
        if return_token_type_ids is None:
            return_token_type_ids = "token_type_ids" in self.model_input_names
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_overflowing_tokens and encoding.overflowing is not None:
            encodings = [encoding] + encoding.overflowing
        else:
            encodings = [encoding]

        encoding_dict = defaultdict(list)
        for e in encodings:
            encoding_dict["input_ids"].append(e.ids)

            if return_token_type_ids:
                encoding_dict["token_type_ids"].append(e.type_ids)
            if return_attention_mask:
                encoding_dict["attention_mask"].append(e.attention_mask)
            if return_special_tokens_mask:
                encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
            if return_offsets_mapping:
                encoding_dict["offset_mapping"].append(e.offsets)
            if return_length:
                encoding_dict["length"].append(len(e.ids))

        return encoding_dict, encodings

    def convert_tokens_to_ids(
        self, tokens: Union[str, List[str]]
    ) -> Union[int, List[int]]:
        """
        Converts a token string (or a sequence of tokens) to a single integer id (or a sequence of ids), using the
        vocabulary.

        Args:
            tokens (`str` or `List[str]`): One or several token(s) to convert to token id(s).

        Returns:
            `int` or `List[int]`: The token id or list of token ids.
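
        Example (illustrative only; the actual ids depend entirely on the checkpoint's vocabulary):

        ```python
        tokenizer.convert_tokens_to_ids("hello")  # -> a single int id from the vocabulary
        tokenizer.convert_tokens_to_ids(["hello", "world"])  # -> a list of ints
        ```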
        """
        if tokens is None:
            return None

        if isinstance(tokens, str):
            return self._convert_token_to_id_with_added_voc(tokens)

        return [self._convert_token_to_id_with_added_voc(token) for token in tokens]

    def _convert_token_to_id_with_added_voc(self, token: str) -> int:
        index = self._tokenizer.token_to_id(token)
        if index is None:
            return self.unk_token_id
        return index

    def _convert_id_to_token(self, index: int) -> Optional[str]:
        return self._tokenizer.id_to_token(int(index))

    def _add_tokens(
        self, new_tokens: List[Union[str, AddedToken]], special_tokens=False
    ) -> int:
        if special_tokens:
            return self._tokenizer.add_special_tokens(new_tokens)

        return self._tokenizer.add_tokens(new_tokens)

    def num_special_tokens_to_add(self, pair: bool = False) -> int:
        """
        Returns the number of added tokens when encoding a sequence with special tokens.

        <Tip>

        This encodes a dummy input and checks the number of added tokens, and is therefore not efficient. Do not put
        this inside your training loop.

        </Tip>

        Args:
            pair (`bool`, *optional*, defaults to `False`):
                Whether the number of added tokens should be computed in the case of a sequence pair or a single
                sequence.

        Returns:
            `int`: Number of special tokens added to sequences.
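
        Example (illustrative; the exact counts depend on the model's post-processor, e.g. a BERT-style tokenizer
        typically adds 2 tokens for a single sequence and 3 for a pair):

        ```python
        tokenizer.num_special_tokens_to_add(pair=False)  # e.g. 2 ([CLS] ... [SEP])
        tokenizer.num_special_tokens_to_add(pair=True)  # e.g. 3 ([CLS] ... [SEP] ... [SEP])
        ```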
        """
        return self._tokenizer.num_special_tokens_to_add(pair)

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Converts a single index or a sequence of indices to a token or a sequence of tokens, using the vocabulary and
        added tokens.

        Args:
            ids (`int` or `List[int]`):
                The token id (or token ids) to convert to tokens.
            skip_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not to remove special tokens in the decoding.

        Returns:
            `str` or `List[str]`: The decoded token(s).
        """
        if isinstance(ids, int):
            return self._tokenizer.id_to_token(ids)
        tokens = []
        for index in ids:
            index = int(index)
            if skip_special_tokens and index in self.all_special_ids:
                continue
            tokens.append(self._tokenizer.id_to_token(index))
        return tokens

    def tokenize(
        self,
        text: str,
        pair: Optional[str] = None,
        add_special_tokens: bool = False,
        **kwargs,
    ) -> List[str]:
        return self.encode_plus(
            text=text, text_pair=pair, add_special_tokens=add_special_tokens, **kwargs
        ).tokens()

    def set_truncation_and_padding(
        self,
        padding_strategy: PaddingStrategy,
        truncation_strategy: TruncationStrategy,
        max_length: Optional[int],
        stride: int,
        pad_to_multiple_of: Optional[int],
    ):
        """
        Define the truncation and the padding strategies for fast tokenizers (provided by the HuggingFace tokenizers
        library) and restore the tokenizer settings afterwards.

        The provided tokenizer has no padding / truncation strategy before the managed section. If your tokenizer had
        a padding / truncation strategy set before, it will be reset to no padding / truncation when exiting the
        managed section.

        Args:
            padding_strategy ([`~utils.PaddingStrategy`]):
                The kind of padding that will be applied to the input.
            truncation_strategy ([`~tokenization_utils_base.TruncationStrategy`]):
                The kind of truncation that will be applied to the input.
            max_length (`int`):
                The maximum size of a sequence.
            stride (`int`):
                The stride to use when handling overflow.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
                the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
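
        Example (a minimal sketch of how `_batch_encode_plus` uses this method; the strategy values are only
        illustrative):

        ```python
        tokenizer.set_truncation_and_padding(
            padding_strategy=PaddingStrategy.LONGEST,
            truncation_strategy=TruncationStrategy.LONGEST_FIRST,
            max_length=128,
            stride=0,
            pad_to_multiple_of=None,
        )
        encodings = tokenizer.backend_tokenizer.encode_batch(["a first text", "a second, longer text"])
        ```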
        """
        _truncation = self._tokenizer.truncation
        _padding = self._tokenizer.padding
        # Set truncation and padding on the backend tokenizer
        if truncation_strategy == TruncationStrategy.DO_NOT_TRUNCATE:
            if _truncation is not None:
                self._tokenizer.no_truncation()
        else:
            target = {
                "max_length": max_length,
                "stride": stride,
                "strategy": truncation_strategy.value,
                "direction": self.truncation_side,
            }

            # _truncation might contain more keys than the ones this version of `transformers`
            # supports. Use only the target keys to decide whether to call `enable_truncation`.
            # This should enable this code to work with various `tokenizers` versions.
            if _truncation is None:
                current = None
            else:
                current = {k: _truncation.get(k, None) for k in target}

            if current != target:
                self._tokenizer.enable_truncation(**target)

        if padding_strategy == PaddingStrategy.DO_NOT_PAD:
            if _padding is not None:
                self._tokenizer.no_padding()
        else:
            length = (
                max_length if padding_strategy == PaddingStrategy.MAX_LENGTH else None
            )
            target = {
                "length": length,
                "direction": self.padding_side,
                "pad_id": self.pad_token_id,
                "pad_token": self.pad_token,
                "pad_type_id": self.pad_token_type_id,
                "pad_to_multiple_of": pad_to_multiple_of,
            }
            if _padding != target:
                self._tokenizer.enable_padding(**target)

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput],
            List[TextInputPair],
            List[PreTokenizedInput],
            List[PreTokenizedInputPair],
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ) -> BatchEncoding:
        if not isinstance(batch_text_or_text_pairs, (tuple, list)):
            raise TypeError(
                f"batch_text_or_text_pairs has to be a list or a tuple (got {type(batch_text_or_text_pairs)})"
            )
        # Set the truncation and padding strategy and restore the initial configuration
        self.set_truncation_and_padding(
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
        )

        encodings = self._tokenizer.encode_batch(
            batch_text_or_text_pairs,
            add_special_tokens=add_special_tokens,
            is_pretokenized=is_split_into_words,
        )

        # Convert encoding to dict
        # `Tokens` has type: Tuple[
        #                       List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]],
        #                       List[EncodingFast]
        #                    ]
        # with nested dimensions corresponding to batch, overflows, sequence length
        tokens_and_encodings = [
            self._convert_encoding(
                encoding=encoding,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask,
                return_offsets_mapping=return_offsets_mapping,
                return_length=return_length,
                verbose=verbose,
            )
            for encoding in encodings
        ]

        # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension
        # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length)
        # (we say ~ because the number of overflows varies with the example in the batch)
        #
        # To match each overflowing sample with the original sample in the batch
        # we add an overflow_to_sample_mapping array (see below)
        sanitized_tokens = {}
        for key in tokens_and_encodings[0][0].keys():
            stack = [e for item, _ in tokens_and_encodings for e in item[key]]
            sanitized_tokens[key] = stack
        sanitized_encodings = [e for _, item in tokens_and_encodings for e in item]

        # If returning overflowing tokens, we need to return a mapping
        # from the batch idx to the original sample
        if return_overflowing_tokens:
            overflow_to_sample_mapping = []
            for i, (toks, _) in enumerate(tokens_and_encodings):
                overflow_to_sample_mapping += [i] * len(toks["input_ids"])
            sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping

        for input_ids in sanitized_tokens["input_ids"]:
            self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose)
        return BatchEncoding(
            sanitized_tokens, sanitized_encodings, tensor_type=return_tensors
        )

    def _encode_plus(
        self,
        text: Union[TextInput, PreTokenizedInput],
        text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[str] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        batched_input = [(text, text_pair)] if text_pair else [text]
        batched_output = self._batch_encode_plus(
            batched_input,
            is_split_into_words=is_split_into_words,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        # If return_tensors is None, we can remove the leading batch axis.
        # Overflowing tokens are returned as a batch of outputs, so we keep the batch axis in this case.
        if return_tensors is None and not return_overflowing_tokens:
            batched_output = BatchEncoding(
                {
                    key: (
                        value[0]
                        if len(value) > 0 and isinstance(value[0], list)
                        else value
                    )
                    for key, value in batched_output.items()
                },
                batched_output.encodings,
            )

        self._eventual_warn_about_too_long_sequence(
            batched_output["input_ids"], max_length, verbose
        )

        return batched_output

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.backend_tokenizer.decoder.decode(tokens)

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs,
    ) -> str:
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        if isinstance(token_ids, int):
            token_ids = [token_ids]
        text = self._tokenizer.decode(
            token_ids, skip_special_tokens=skip_special_tokens
        )

        clean_up_tokenization_spaces = (
            clean_up_tokenization_spaces
            if clean_up_tokenization_spaces is not None
            else self.clean_up_tokenization_spaces
        )
        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text

    def _save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        file_names: Tuple[str],
        legacy_format: Optional[bool] = None,
        filename_prefix: Optional[str] = None,
    ) -> Tuple[str]:
        """
        Save the tokenizer in the slow-tokenizer/legacy format (vocabulary + added tokens files) and/or as a unique
        JSON file containing {config + vocab + added-tokens}, depending on `legacy_format`.
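
        Example (illustrative; when called through `save_pretrained`, a fast tokenizer typically writes
        `tokenizer.json`, `tokenizer_config.json` and `special_tokens_map.json` to the target directory):

        ```python
        tokenizer.save_pretrained("./my-tokenizer")
        ```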
        """
        save_directory = str(save_directory)

        if self.slow_tokenizer_class is None and legacy_format is True:
            raise ValueError(
                "Your tokenizer does not have a legacy version defined and therefore cannot register this version. You"
                " might consider leaving the legacy_format at `None` or setting it to `False`."
            )

        save_slow = (
            (legacy_format is None or legacy_format is True)
            and self.slow_tokenizer_class is not None
            and self.can_save_slow_tokenizer
        )
        save_fast = legacy_format is None or legacy_format is False

        if save_slow:
            added_tokens_file = os.path.join(
                save_directory,
                (filename_prefix + "-" if filename_prefix else "") + ADDED_TOKENS_FILE,
            )
            # make sure to be forward compatible
            added_vocab = {
                tok: index
                for tok, index in self.added_tokens_encoder.items()
                if index >= self.vocab_size
            }
            if added_vocab:
                with open(added_tokens_file, "w", encoding="utf-8") as f:
                    out_str = (
                        json.dumps(
                            added_vocab, indent=2, sort_keys=True, ensure_ascii=False
                        )
                        + "\n"
                    )
                    f.write(out_str)

            vocab_files = self.save_vocabulary(
                save_directory, filename_prefix=filename_prefix
            )
            file_names = file_names + vocab_files + (added_tokens_file,)

        if save_fast:
            tokenizer_file = os.path.join(
                save_directory,
                (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_FILE,
            )
            self.backend_tokenizer.save(tokenizer_file)
            file_names = file_names + (tokenizer_file,)

        return file_names

    def train_new_from_iterator(
        self,
        text_iterator,
        vocab_size,
        length=None,
        new_special_tokens=None,
        special_tokens_map=None,
        **kwargs,
    ):
        """
        Trains a tokenizer on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline)
        as the current one.

        Args:
            text_iterator (generator of `List[str]`):
                The training corpus. Should be a generator of batches of texts, for instance a list of lists of texts
                if you have everything in memory.
            vocab_size (`int`):
                The size of the vocabulary you want for your tokenizer.
            length (`int`, *optional*):
                The total number of sequences in the iterator. This is used to provide meaningful progress tracking.
            new_special_tokens (list of `str` or `AddedToken`, *optional*):
                A list of new special tokens to add to the tokenizer you are training.
            special_tokens_map (`Dict[str, str]`, *optional*):
                If you want to rename some of the special tokens this tokenizer uses, pass a mapping from old special
                token name to new special token name in this argument.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the trainer from the 🤗 Tokenizers library.

        Returns:
            [`PreTrainedTokenizerFast`]: A new tokenizer of the same type as the original one, trained on
            `text_iterator`.

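        Example (a minimal sketch; `tokenizer` stands for any existing fast tokenizer instance, the corpus is a tiny
        in-memory list of batches and the vocabulary size is arbitrary):

        ```python
        corpus = [["the quick brown fox", "jumps over the lazy dog"], ["another small batch of text"]]
        new_tokenizer = tokenizer.train_new_from_iterator(corpus, vocab_size=1000)
        new_tokenizer.save_pretrained("my-new-tokenizer")
        ```
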
        """
        tokenizer_json = json.loads(self._tokenizer.to_str())
        # Remove added tokens for now (uses IDs of tokens)
        added_tokens = tokenizer_json.pop("added_tokens")
        # Remove post processor for now (uses IDs of tokens)
        post_processor = tokenizer_json.pop("post_processor")

        unk_token = None
        # Remove vocab
        if tokenizer_json["model"]["type"] == "BPE":
            tokenizer_json["model"]["vocab"] = {}
            tokenizer_json["model"]["merges"] = []
        elif tokenizer_json["model"]["type"] == "Unigram":
            if tokenizer_json["model"]["unk_id"] is not None:
                unk_id = tokenizer_json["model"]["unk_id"]
                unk_token = tokenizer_json["model"]["vocab"][unk_id][0]
                if special_tokens_map is not None and unk_token in special_tokens_map:
                    unk_token = special_tokens_map[unk_token]
                tokenizer_json["model"]["unk_id"] = 0
                tokenizer_json["model"]["vocab"] = [[unk_token, 0.0]]
        elif tokenizer_json["model"]["type"] in ["WordLevel", "WordPiece"]:
            tokenizer_json["model"]["vocab"] = {}
        else:
            raise ValueError(
                f"This method does not support this type of tokenizer (found {tokenizer_json['model']['type']}); "
                "only BPE, Unigram, WordLevel and WordPiece are supported."
            )

        if (
            special_tokens_map is not None
            and "unk_token" in tokenizer_json["model"]
            and tokenizer_json["model"]["unk_token"] in special_tokens_map
        ):
            tokenizer_json["model"]["unk_token"] = special_tokens_map[
                tokenizer_json["model"]["unk_token"]
            ]

        tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))

        # Get the special tokens from the current tokenizer if none are specified.
        special_tokens = []
        for added_token in added_tokens:
            special = added_token.pop("special", None)
            _ = added_token.pop("id", None)
            if tokenizer_json["model"]["type"] != "Unigram" and not special:
                continue
            if (
                special_tokens_map is not None
                and added_token["content"] in special_tokens_map
            ):
                added_token["content"] = special_tokens_map[added_token["content"]]
            special_tokens.append(AddedToken(**added_token))

        if new_special_tokens is not None:
            special_tokens.extend(new_special_tokens)

        # The trainer needs to know the end-of-word suffix and continuing-subword prefix used by BPE
        if (
            tokenizer_json["model"]["type"] == "BPE"
            and "continuing_subword_prefix" not in kwargs
            and tokenizer_json["model"]["continuing_subword_prefix"] is not None
        ):
            kwargs["continuing_subword_prefix"] = tokenizer_json["model"][
                "continuing_subword_prefix"
            ]
        if (
            tokenizer_json["model"]["type"] == "BPE"
            and "end_of_word_suffix" not in kwargs
            and tokenizer_json["model"]["end_of_word_suffix"] is not None
        ):
            kwargs["end_of_word_suffix"] = tokenizer_json["model"]["end_of_word_suffix"]
        if tokenizer_json["model"]["type"] == "Unigram" and unk_token is not None:
            kwargs["unk_token"] = unk_token
        if (
            tokenizer_json["pre_tokenizer"] is not None
            and tokenizer_json["pre_tokenizer"]["type"] == "ByteLevel"
        ):
            kwargs["initial_alphabet"] = pre_tokenizers_fast.ByteLevel.alphabet()

        trainer_class = MODEL_TO_TRAINER_MAPPING[tokenizer_json["model"]["type"]]
        trainer = trainer_class(
            vocab_size=vocab_size, special_tokens=special_tokens, **kwargs
        )
        tokenizer.train_from_iterator(text_iterator, length=length, trainer=trainer)

        if post_processor is not None:
            trained_tokenizer_json = json.loads(tokenizer.to_str())
            # Almost done, we just have to adjust the token IDs in the post processor
            if "special_tokens" in post_processor:
                for key in post_processor["special_tokens"]:
                    tokens = post_processor["special_tokens"][key]["tokens"]
                    if special_tokens_map is not None:
                        tokens = [
                            special_tokens_map.get(token, token) for token in tokens
                        ]
                    post_processor["special_tokens"][key]["tokens"] = tokens
                    post_processor["special_tokens"][key]["ids"] = [
                        tokenizer.token_to_id(token) for token in tokens
                    ]

            for special_token in ["cls", "sep"]:
                if special_token in post_processor:
                    token, _ = post_processor[special_token]
                    if special_tokens_map is not None and token in special_tokens_map:
                        token = special_tokens_map[token]
                    token_id = tokenizer.token_to_id(token)
                    post_processor[special_token] = [token, token_id]

            trained_tokenizer_json["post_processor"] = post_processor
            tokenizer = TokenizerFast.from_str(json.dumps(trained_tokenizer_json))

        kwargs = self.init_kwargs.copy()
        # Map pad/cls/mask token at the Transformers level
        special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy()
        special_tokens_list.remove("additional_special_tokens")
        for token in special_tokens_list:
            # Get the private one to avoid unnecessary warnings.
            if getattr(self, f"_{token}") is not None:
                special_token = getattr(self, token)
                if (
                    special_tokens_map is not None
                    and special_token in special_tokens_map
                ):
                    special_token = special_tokens_map[special_token]

                special_token_full = getattr(self, f"_{token}")
                if isinstance(special_token_full, AddedToken):
                    # Create an added token with the same parameters except the content
                    kwargs[token] = AddedToken(
                        special_token,
                        single_word=special_token_full.single_word,
                        lstrip=special_token_full.lstrip,
                        rstrip=special_token_full.rstrip,
                        normalized=special_token_full.normalized,
                        special=True,
                    )
                else:
                    kwargs[token] = special_token

        additional_special_tokens = self.additional_special_tokens
        if new_special_tokens is not None:
            additional_special_tokens.extend(new_special_tokens)
        if len(additional_special_tokens) > 0:
            kwargs["additional_special_tokens"] = additional_special_tokens

        return self.__class__(tokenizer_object=tokenizer, **kwargs)