/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#pragma once

#include <string>
#include <vector>

#include "llm_helper/include/llm_types.h"

namespace example {

using llm_helper::LLMType;

struct LlamaModelOptions {
  // Sizes
  size_t prompt_token_batch_size = 1; // Prompt tokens processed per prefill pass
  size_t cache_size = 1024;           // KV cache length in tokens
  size_t hidden_size = 4096;          // Hidden (embedding) dimension
  size_t num_head = 32;               // Number of attention heads
  size_t num_layer = 32;              // Number of transformer layers
  size_t max_token_length = 2048;     // Maximum total sequence length
  double rot_emb_base = 10000.0;      // Rotary position embedding base (theta)

  // Types
  LLMType model_input_type = LLMType::INT16;  // Model input data type
  LLMType model_output_type = LLMType::INT16; // Model output data type
  LLMType cache_type = LLMType::INT16;        // KV cache data type
  LLMType mask_type = LLMType::INT16;         // Attention mask data type
  LLMType rot_emb_type = LLMType::INT16;      // Rotary embedding data type
};

struct LlamaModelPaths {
  std::string tokenizer_path;                  // Tokenizer model file
  std::string token_embedding_path;            // Token embedding table file
  std::vector<std::string> prompt_model_paths; // Model file(s) used for prompt prefill
  std::vector<std::string> gen_model_paths;    // Model file(s) used for token generation
};

} // namespace example
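
// Usage sketch (illustrative only, not part of the original header): one way a
// runner might populate these structs before loading a model. The field values
// and file paths below are placeholder assumptions, not values required or
// implied by this header.
//
//   #include "LlamaConfig.h"
//
//   example::LlamaModelOptions options;
//   options.prompt_token_batch_size = 32; // prefill 32 prompt tokens per pass
//   options.cache_size = 1024;            // KV cache length in tokens
//   options.max_token_length = 2048;      // total sequence length limit
//
//   example::LlamaModelPaths paths;
//   paths.tokenizer_path = "/data/local/tmp/tokenizer.model";
//   paths.token_embedding_path = "/data/local/tmp/embedding.bin";
//   paths.prompt_model_paths = {"/data/local/tmp/llama_prompt.pte"};
//   paths.gen_model_paths = {"/data/local/tmp/llama_gen.pte"};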