# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

# Runtime
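# Maximum number of tokens to generate per response.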
MAX_RESPONSE=200

# Model External
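# These must match the shapes the .pte files were exported with; the
# "128t512c" / "1t512c" suffixes in the model file names below appear to
# encode the token batch size and cache size.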
PROMPT_TOKEN_BATCH_SIZE=128
CACHE_SIZE=512

# Model Internals
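# Architecture parameters of Llama-3 8B; they must agree with the exported model.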
HIDDEN_SIZE=4096
NUM_HEAD=32
NUM_LAYER=32
MAX_TOKEN_LENGTH=8192
ROT_EMB_BASE=500000

# Model IO Types
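# Data types of the runner-side model inputs/outputs: activations, KV cache,
# attention mask, and rotary embeddings.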
INPUT_TYPE=fp32
OUTPUT_TYPE=fp32
CACHE_TYPE=fp32
MASK_TYPE=fp32
ROT_EMB_TYPE=fp32

# Tokenizer
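# 128000/128001 are the Llama-3 <|begin_of_text|> / <|end_of_text|> token IDs.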
VOCAB_SIZE=128000
BOS_TOKEN=128000
EOS_TOKEN=128001
TOKENIZER_TYPE=tiktoken  # Use "bpe" for LLAMA2, "tiktoken" for LLAMA3

# Paths
TOKENIZER_PATH="/data/local/tmp/llama3/tokenizer.model"
TOKEN_EMBEDDING_PATH="/data/local/tmp/llama3/embedding_llama3_8b_instruct_fp32.bin"

# Comma-separated paths to the prompt (prefill) model chunks (128-token input batches)
PROMPT_MODEL_PATHS="\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_0.pte,\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_1.pte,\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_2.pte,\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_128t512c_3.pte,"

# Comma-separated paths to the generation (decode) model chunks (single-token input)
GEN_MODEL_PATHS="\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_0.pte,\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_1.pte,\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_2.pte,\
/data/local/tmp/llama3/llama3_8b_SC_sym4W_sym16A_4_chunks_Overall_1t512c_3.pte,"

PROMPT_FILE=/data/local/tmp/llama3/sample_prompt.txt

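# Optional sanity check (a minimal sketch, not part of the original flow):
# fail fast if any file configured above is missing on the device, instead
# of letting the runner error out mid-load.
for f in "$TOKENIZER_PATH" "$TOKEN_EMBEDDING_PATH" "$PROMPT_FILE" \
         $(echo "$PROMPT_MODEL_PATHS$GEN_MODEL_PATHS" | tr ',' ' '); do
  if [ ! -f "$f" ]; then
    echo "Missing required file: $f" >&2
    exit 1
  fi
done
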
chmod +x mtk_llama_executor_runner

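# Let the dynamic linker find the runner's bundled shared libraries in the
# current directory.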
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD"

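# Launch the runner: the prompt models prefill the KV cache in 128-token
# batches, then the gen models decode one token at a time until EOS or
# MAX_RESPONSE tokens.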
./mtk_llama_executor_runner \
    --max_response="$MAX_RESPONSE" \
    --prompt_token_batch_size="$PROMPT_TOKEN_BATCH_SIZE" \
    --cache_size="$CACHE_SIZE" \
    --hidden_size="$HIDDEN_SIZE" \
    --num_head="$NUM_HEAD" \
    --num_layer="$NUM_LAYER" \
    --max_token_length="$MAX_TOKEN_LENGTH" \
    --rot_emb_base="$ROT_EMB_BASE" \
    --input_type="$INPUT_TYPE" \
    --output_type="$OUTPUT_TYPE" \
    --cache_type="$CACHE_TYPE" \
    --mask_type="$MASK_TYPE" \
    --rot_emb_type="$ROT_EMB_TYPE" \
    --vocab_size="$VOCAB_SIZE" \
    --bos_token="$BOS_TOKEN" \
    --eos_token="$EOS_TOKEN" \
    --tokenizer_type="$TOKENIZER_TYPE" \
    --tokenizer_path="$TOKENIZER_PATH" \
    --token_embedding_path="$TOKEN_EMBEDDING_PATH" \
    --prompt_model_paths="$PROMPT_MODEL_PATHS" \
    --gen_model_paths="$GEN_MODEL_PATHS" \
    --prompt_file="$PROMPT_FILE"