xref: /aosp_15_r20/external/executorch/examples/models/phi-3-mini/export_phi-3-mini.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8import argparse
9
10import torch
11
12from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
13    DuplicateDynamicQuantChainPass,
14)
15from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
16from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
17from executorch.exir import to_edge
18from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
19
20from torch.ao.quantization.quantizer.xnnpack_quantizer import (
21    get_symmetric_quantization_config,
22    XNNPACKQuantizer,
23)
24from torch.export import export_for_training
25
26from transformers import Phi3ForCausalLM
27
28from .phi_3_mini import Phi3Mini
29
30
def export(args) -> None:
    """Quantize Phi-3-mini and write it out as an ExecuTorch program file.

    The checkpoint selected by ``args.context_length`` is loaded with
    ``Phi3ForCausalLM.from_pretrained``, wrapped in ``Phi3Mini``, run through
    the PT2E prepare/convert flow with an ``XNNPACKQuantizer`` (symmetric,
    per-channel, dynamic config), re-exported, partitioned with
    ``XnnpackPartitioner`` and finally serialized to ``args.output_name``.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``context_length`` ("4k" or "128k"), ``seq_len`` (used both as
            ``max_seq_len`` of the wrapper and as the upper bound of the
            dynamic sequence dimension) and ``output_name`` (path of the file
            that is written).

    Raises:
        ValueError: If ``args.context_length`` is neither "4k" nor "128k".
    """
    torch.manual_seed(0)

    if args.context_length == "4k":
        model_name = "microsoft/Phi-3-mini-4k-instruct"
    elif args.context_length == "128k":
        model_name = "microsoft/Phi-3-mini-128k-instruct"
    else:
        # ValueError (a subclass of Exception) pinpoints a bad argument value
        # while staying catchable by callers that handle Exception.
        raise ValueError(
            f"Invalid context length {args.context_length}. Should be either 4k or 128k"
        )

    with torch.no_grad():
        model = Phi3Mini(
            # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
            model=Phi3ForCausalLM.from_pretrained(model_name),
            max_batch_size=1,
            max_seq_len=args.seq_len,
        )
        # One batch of four token ids; reused below for tracing, for the
        # forward pass of the prepared model and for the final export.
        example_inputs = (
            torch.tensor(
                [[1048, 263, 931, 746]], dtype=torch.long, requires_grad=False
            ),
        )
        # Dimension 1 of `input_ids` may range from 1 to args.seq_len.
        dynamic_shapes = {
            "input_ids": {
                1: torch.export.Dim("sequence_length", min=1, max=args.seq_len)
            }
        }

        xnnpack_quant_config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        xnnpack_quantizer = XNNPACKQuantizer()
        xnnpack_quantizer.set_global(xnnpack_quant_config)

        # PT2E flow: capture the graph, prepare it with the quantizer, run it
        # once on the example input, then convert it.
        model = export_for_training(
            model, example_inputs, dynamic_shapes=dynamic_shapes
        ).module()
        model = prepare_pt2e(model, xnnpack_quantizer)  # pyre-fixme[6]
        model(*example_inputs)
        model = convert_pt2e(model)
        # The return value is ignored, so this relies on the pass updating
        # `model` in place -- NOTE(review): confirm against the pass.
        DuplicateDynamicQuantChainPass()(model)
        # TODO(lunwenh): update it to use export once
        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
        model = torch.export._trace._export(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            strict=False,
            pre_dispatch=False,
        )

    # Lower to the edge dialect, delegate to XNNPACK and build the final
    # ExecuTorch program.
    edge_config = get_xnnpack_edge_compile_config()
    edge_manager = to_edge(model, compile_config=edge_config)
    edge_manager = edge_manager.to_backend(XnnpackPartitioner())
    et_program = edge_manager.to_executorch()

    with open(args.output_name, "wb") as output_file:
        output_file.write(et_program.buffer)
91
92
def main():
    """Parse the command-line options and hand them to ``export``."""
    # (flags, add_argument keyword arguments), registered in this order.
    cli_options = (
        (
            ("-c", "--context_length"),
            {
                "type": str,
                "default": "4k",
                "choices": ["4k", "128k"],
                "help": "Phi-3-mini provides two context length variants: 4k and 128k",
            },
        ),
        (
            ("-s", "--seq_len"),
            {
                "type": int,
                "default": 128,
                "help": "Maximum number of tokens including prompt to generate",
            },
        ),
        (
            ("-o", "--output_name"),
            {
                "default": "phi-3-mini.pte",
                "help": "Override the output filename of the saved pte model file.",
            },
        ),
    )

    arg_parser = argparse.ArgumentParser()
    for flags, settings in cli_options:
        arg_parser.add_argument(*flags, **settings)

    parsed_args = arg_parser.parse_args()
    export(parsed_args)
117
118
# Call main() only when this module is run as the main program, not when it
# is imported by another module.
if __name__ == "__main__":
    main()
121