xref: /aosp_15_r20/external/executorch/extension/llm/tokenizer/tokenizer.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
11*523fa7a6SAndroid Build Coastguard Worker #include <cinttypes>
12*523fa7a6SAndroid Build Coastguard Worker #include <string>
13*523fa7a6SAndroid Build Coastguard Worker #include <vector>
14*523fa7a6SAndroid Build Coastguard Worker 
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/error.h>
16*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/result.h>
17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/platform/compiler.h>
18*523fa7a6SAndroid Build Coastguard Worker 
19*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
20*523fa7a6SAndroid Build Coastguard Worker namespace extension {
21*523fa7a6SAndroid Build Coastguard Worker namespace llm {
22*523fa7a6SAndroid Build Coastguard Worker 
23*523fa7a6SAndroid Build Coastguard Worker // A tokenizer interface.
24*523fa7a6SAndroid Build Coastguard Worker class ET_EXPERIMENTAL Tokenizer {
25*523fa7a6SAndroid Build Coastguard Worker  public:
Tokenizer()26*523fa7a6SAndroid Build Coastguard Worker   explicit Tokenizer() {}
~Tokenizer()27*523fa7a6SAndroid Build Coastguard Worker   virtual ~Tokenizer() {}
28*523fa7a6SAndroid Build Coastguard Worker 
29*523fa7a6SAndroid Build Coastguard Worker   virtual ::executorch::runtime::Error load(
30*523fa7a6SAndroid Build Coastguard Worker       const std::string& tokenizer_path) = 0;
31*523fa7a6SAndroid Build Coastguard Worker 
32*523fa7a6SAndroid Build Coastguard Worker   virtual ::executorch::runtime::Result<std::vector<uint64_t>>
33*523fa7a6SAndroid Build Coastguard Worker   encode(const std::string& input, int8_t bos, int8_t eos) const = 0;
34*523fa7a6SAndroid Build Coastguard Worker 
decode_verify(uint64_t token)35*523fa7a6SAndroid Build Coastguard Worker   ::executorch::runtime::Error decode_verify(uint64_t token) const {
36*523fa7a6SAndroid Build Coastguard Worker     if (!initialized_) {
37*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(Error, "Tokenizer not initialized");
38*523fa7a6SAndroid Build Coastguard Worker       return ::executorch::runtime::Error::NotSupported;
39*523fa7a6SAndroid Build Coastguard Worker     }
40*523fa7a6SAndroid Build Coastguard Worker     if (token >= vocab_size_) {
41*523fa7a6SAndroid Build Coastguard Worker       ET_LOG(
42*523fa7a6SAndroid Build Coastguard Worker           Error,
43*523fa7a6SAndroid Build Coastguard Worker           "token  %" PRIu64 " is out side of vacab range %d",
44*523fa7a6SAndroid Build Coastguard Worker           token,
45*523fa7a6SAndroid Build Coastguard Worker           vocab_size_);
46*523fa7a6SAndroid Build Coastguard Worker       return ::executorch::runtime::Error::NotSupported;
47*523fa7a6SAndroid Build Coastguard Worker     }
48*523fa7a6SAndroid Build Coastguard Worker     return ::executorch::runtime::Error::Ok;
49*523fa7a6SAndroid Build Coastguard Worker   }
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker   virtual ::executorch::runtime::Result<std::string> decode(
52*523fa7a6SAndroid Build Coastguard Worker       uint64_t prev_token,
53*523fa7a6SAndroid Build Coastguard Worker       uint64_t token) const = 0;
54*523fa7a6SAndroid Build Coastguard Worker 
55*523fa7a6SAndroid Build Coastguard Worker   // getters
vocab_size()56*523fa7a6SAndroid Build Coastguard Worker   int32_t vocab_size() const {
57*523fa7a6SAndroid Build Coastguard Worker     return vocab_size_;
58*523fa7a6SAndroid Build Coastguard Worker   }
59*523fa7a6SAndroid Build Coastguard Worker 
bos_tok()60*523fa7a6SAndroid Build Coastguard Worker   uint64_t bos_tok() const {
61*523fa7a6SAndroid Build Coastguard Worker     return bos_tok_;
62*523fa7a6SAndroid Build Coastguard Worker   }
63*523fa7a6SAndroid Build Coastguard Worker 
eos_tok()64*523fa7a6SAndroid Build Coastguard Worker   uint64_t eos_tok() const {
65*523fa7a6SAndroid Build Coastguard Worker     return eos_tok_;
66*523fa7a6SAndroid Build Coastguard Worker   }
67*523fa7a6SAndroid Build Coastguard Worker 
68*523fa7a6SAndroid Build Coastguard Worker  protected:
69*523fa7a6SAndroid Build Coastguard Worker   bool initialized_ = false;
70*523fa7a6SAndroid Build Coastguard Worker   int32_t vocab_size_ = 0;
71*523fa7a6SAndroid Build Coastguard Worker   uint64_t bos_tok_ = 0;
72*523fa7a6SAndroid Build Coastguard Worker   uint64_t eos_tok_ = 0;
73*523fa7a6SAndroid Build Coastguard Worker };
74*523fa7a6SAndroid Build Coastguard Worker 
75*523fa7a6SAndroid Build Coastguard Worker } // namespace llm
76*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
77*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
78*523fa7a6SAndroid Build Coastguard Worker 
namespace torch {
namespace executor {
// Backward-compatibility alias: keeps `Tokenizer` reachable through the legacy
// `torch::executor` namespace.
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::Tokenizer;
} // namespace executor
} // namespace torch
86