# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
    export_model_lora_training,
    load_checkpoint,
    setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
    get_dataloader,
    TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
    prog="ModelExporter",
    description="Export a LoRA model to ExecuTorch.",
    epilog="The exported model can then be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    output_file = args.output_file
    cfg = OmegaConf.load(config_file)
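    # Instantiate the tokenizer described in the config; cfg.tokenizer is
    # assumed to follow torchtune's `_component_` convention used by
    # config.instantiate.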
    tokenizer = config.instantiate(
        cfg.tokenizer,
    )

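    # Instantiate the loss function declared in the config (typically a
    # cross-entropy-style loss for next-token prediction).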
    loss_fn = config.instantiate(cfg.loss)

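    # Build the dataset and split it 80/20 into train/validation subsets.
    # Only the training split is needed here: a single batch from it provides
    # the example inputs for export, so val_set is not used by this script.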
    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

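    # Maximum sequence length comes from the tokenizer config; the example
    # inputs below are truncated or padded to this length.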
    max_seq_len = cfg.tokenizer.max_seq_len

    # Example inputs, needed for ET export.
    batch = next(iter(train_dataloader))
    tokens, labels = batch["tokens"], batch["labels"]
    token_size = tokens.shape[1]
    labels_size = labels.shape[1]

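    # Truncate or right-pad the token ids to exactly max_seq_len so the
    # example input has a consistent shape for export.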
    if token_size > max_seq_len:
        tokens = tokens[:, :max_seq_len]
    else:
        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

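    # Apply the same truncation/padding to the labels so they line up with
    # the tokens at max_seq_len.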
    if labels_size > max_seq_len:
        labels = labels[:, :max_seq_len]
    else:
        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

    # Load pre-trained checkpoint.
    checkpoint_dict = load_checkpoint(cfg=cfg)
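    # Construct the model from the config and load the base weights
    # (setup_model is assumed to also attach the LoRA adapters configured
    # in cfg).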
    model = setup_model(
        # pyre-ignore
        cfg=cfg,
        base_model_state_dict=checkpoint_dict[MODEL_KEY],
    )

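    # Wrap the model and loss function in a TrainingModule; this wrapper is
    # the module passed to export_model_lora_training below.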
    training_module = TrainingModule(model, loss_fn)

    # Export the model to ExecuTorch for training.
    export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
    main()