# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import argparse

import torch

from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export_for_training

from transformers import Phi3ForCausalLM

from .phi_3_mini import Phi3Mini


def export(args) -> None:
    torch.manual_seed(0)

    if args.context_length == "4k":
        model_name = "microsoft/Phi-3-mini-4k-instruct"
    elif args.context_length == "128k":
        model_name = "microsoft/Phi-3-mini-128k-instruct"
    else:
        raise Exception(
            f"Invalid context length {args.context_length}. Should be either 4k or 128k"
        )

    with torch.no_grad():
        model = Phi3Mini(
            # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
            model=Phi3ForCausalLM.from_pretrained(model_name),
            max_batch_size=1,
            max_seq_len=args.seq_len,
        )
        # A short prompt of token ids used as the example input for export and calibration.
        example_inputs = (
            torch.tensor(
                [[1048, 263, 931, 746]], dtype=torch.long, requires_grad=False
            ),
        )
        # Let the sequence dimension vary between 1 and seq_len at runtime.
        dynamic_shapes = {
            "input_ids": {
                1: torch.export.Dim("sequence_length", min=1, max=args.seq_len)
            }
        }

        # Dynamic, per-channel symmetric quantization config for the XNNPACK backend.
        xnnpack_quant_config = get_symmetric_quantization_config(
            is_per_channel=True, is_dynamic=True
        )
        xnnpack_quantizer = XNNPACKQuantizer()
        xnnpack_quantizer.set_global(xnnpack_quant_config)

        # Capture the model with torch.export, then run the PT2E quantization flow:
        # prepare -> calibrate with the example inputs -> convert.
        model = export_for_training(
            model, example_inputs, dynamic_shapes=dynamic_shapes
        ).module()
        model = prepare_pt2e(model, xnnpack_quantizer)  # pyre-fixme[6]
        model(*example_inputs)
        model = convert_pt2e(model)
        # Duplicate the dynamic quant/dequant chains so each consumer has its own copy
        # for the XNNPACK partitioner to match.
        DuplicateDynamicQuantChainPass()(model)
        # TODO(lunwenh): update it to use export once
        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
        model = torch.export._trace._export(
            model,
            example_inputs,
            dynamic_shapes=dynamic_shapes,
            strict=False,
            pre_dispatch=False,
        )

    # Lower to the Edge dialect, delegate supported subgraphs to XNNPACK, and emit the
    # ExecuTorch program.
    edge_config = get_xnnpack_edge_compile_config()
    edge_manager = to_edge(model, compile_config=edge_config)
    edge_manager = edge_manager.to_backend(XnnpackPartitioner())
    et_program = edge_manager.to_executorch()

    # Serialize the program to a .pte file.
    with open(args.output_name, "wb") as file:
        file.write(et_program.buffer)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--context_length",
        type=str,
        default="4k",
        choices=["4k", "128k"],
        help="Phi-3-mini provides two context length variants: 4k and 128k",
    )
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate, including the prompt",
    )
    parser.add_argument(
        "-o",
        "--output_name",
        default="phi-3-mini.pte",
        help="Output filename for the saved .pte model file.",
    )
    export(parser.parse_args())


if __name__ == "__main__":
    main()
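
# Example invocation (a sketch, not taken from the source). Because this script imports
# `phi_3_mini` relatively, run it as a module from whatever package contains it;
# `<package>` and `<this_module>` below are placeholders, not real names:
#
#   python -m <package>.<this_module> --context_length 4k --seq_len 128 --output_name phi-3-mini.pte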