# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import argparse

import torch
from executorch.examples.llm_pte_finetuning.model_loading_lib import (
    export_model_lora_training,
    load_checkpoint,
    setup_model,
)

from executorch.examples.llm_pte_finetuning.training_lib import (
    get_dataloader,
    TrainingModule,
)

from omegaconf import OmegaConf
from torch.nn import functional as F
from torchtune import config

from torchtune.training import MODEL_KEY

parser = argparse.ArgumentParser(
    prog="ModelExporter",
    description="Export a LoRA model to ExecuTorch.",
    epilog="Model exported to be used for fine-tuning.",
)

parser.add_argument("--cfg", type=str, help="Path to the config file.")
parser.add_argument("--output_file", type=str, help="Path to the output ET model.")


def main() -> None:
    args = parser.parse_args()
    config_file = args.cfg
    output_file = args.output_file
    cfg = OmegaConf.load(config_file)
    tokenizer = config.instantiate(
        cfg.tokenizer,
    )

    loss_fn = config.instantiate(cfg.loss)

    ds = config.instantiate(cfg.dataset, tokenizer)
    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
    train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)

    max_seq_len = cfg.tokenizer.max_seq_len

    # Example inputs, needed for ET export.
    batch = next(iter(train_dataloader))
    tokens, labels = batch["tokens"], batch["labels"]
    token_size = tokens.shape[1]
    labels_size = labels.shape[1]

    if token_size > max_seq_len:
        tokens = tokens[:, :max_seq_len]
    else:
        tokens = F.pad(tokens, (0, max_seq_len - token_size), value=0)

    if labels_size > max_seq_len:
        labels = labels[:, :max_seq_len]
    else:
        labels = F.pad(labels, (0, max_seq_len - labels_size), value=0)

    # Load pre-trained checkpoint.
    checkpoint_dict = load_checkpoint(cfg=cfg)
    model = setup_model(
        # pyre-ignore
        cfg=cfg,
        base_model_state_dict=checkpoint_dict[MODEL_KEY],
    )

    training_module = TrainingModule(model, loss_fn)

    # Export the model to ExecuTorch for training.
    export_model_lora_training(training_module, (tokens, labels), output_file)


if __name__ == "__main__":
    main()
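
# Example invocation. A minimal sketch, assuming this module is run directly as a
# script and that the config is a torchtune-style YAML with tokenizer, loss, dataset,
# model, and checkpointer sections; the file names below are illustrative placeholders,
# not files shipped with this example:
#
#   python export_model.py \
#       --cfg path/to/lora_finetune_config.yaml \
#       --output_file lora_finetune.pte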