# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

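"""Export the LLaVA model as a single multi-method ExecuTorch program
(image_encoder, token_embedding, text_model), quantized and lowered to the
XNNPACK backend. Run as a script; see main() for the available flags."""
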
import logging
from argparse import ArgumentParser, BooleanOptionalAction

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.source_transformation.quantize import (
    EmbeddingQuantHandler,
    get_quant_weight_transform,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import serialize_image
from executorch.examples.models.llava.model import LlavaModel
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import (
    ConstraintBasedSymShapeEvalPass,
    HintBasedSymShapeEvalPass,
)

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
from executorch.util.activation_memory_profiler import generate_memory_trace
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import Dim
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class LlavaEdgeManager(LLMEdgeManager):
    def export(self) -> "LlavaEdgeManager":
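        """Trace self.model with torch.export (non-strict) using the configured
        example inputs and dynamic shapes, and keep the traced graph module so
        the caller can pt2e-quantize and lower it."""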
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            self.export_program = torch.export.export(
                self.model,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                strict=False,
            )
            self.pre_autograd_graph_module = self.export_program.module()
        return self


def export_text_model(llava, embeddings, dynamic_shapes):
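    """Export the Llava text decoder to an ExportedProgram with 8da4w weight
    quantization and the sdpa_with_kv_cache custom op (when enabled).

    `embeddings` is an example prefill embedding tensor and `dynamic_shapes`
    describes its variable sequence length; both are produced by export_all().
    """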
    class LlavaTextModel(torch.nn.Module):
        """Takes an input position and precomputed embeddings and runs the Llava text decoder (with KV cache) over them."""

        def __init__(self, llava):
            super().__init__()
            self.text_model = llava.text_model

        def forward(self, input_pos, embeddings):
            return self.text_model(None, input_pos, embeddings)

    llava_text_model = LlavaTextModel(llava)

    text_model_em = LLMEdgeManager(
        model=llava_text_model,
        modelname="llava_text_model",
        max_seq_len=llava.text_model_args.max_seq_len,
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
        dynamic_shapes=dynamic_shapes,
        args=llava.text_model_args,
    )

    dtype_override = DType.fp32
    parser = build_args_parser()
    args = parser.parse_args(
        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
    )
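    # These flags reuse the llama export CLI: "-X" targets the XNNPACK backend,
    # "-qmode 8da4w" selects 8-bit dynamic-activation / 4-bit groupwise weight
    # quantization (group size 128), and "--embedding-quantize 4,32" quantizes the
    # embedding table to 4 bits with a group size of 32.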
    quant_transform = get_quant_weight_transform(args, dtype_override, False)
    _, quantizers, _ = get_quantizer_and_quant_params(args)
    source_transforms = []
    if llava.use_sdpa_with_kv_cache_op:
        source_transforms.append(replace_sdpa_with_custom_op)
    source_transforms.append(quant_transform)
    manager = (
        text_model_em.set_output_dir("./")
        .to_dtype(dtype_override)
        .source_transform(source_transforms)
        .export()
        .pt2e_quantize(quantizers)
    )

    with torch.no_grad():
        text_model_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager._get_dynamic_shape(),
        )
    return text_model_ep


def export_image_encoder(llava, resized, dynamic_shapes):
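    """Export the Llava image encoder to a pt2e-quantized ExportedProgram.

    `resized` is an example image tensor and `dynamic_shapes` comes from
    LlavaModel._get_image_dynamic_shapes() in export_all(), which marks the
    image dimensions that may vary at runtime.
    """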
    class LlavaImageEncoder(torch.nn.Module):
        """Takes images and encodes them into embeddings. The result is fed to the text model LlavaTextModel."""

        def __init__(self, llava):
            super().__init__()
            self.llava = llava

        def forward(self, images):
            return self.llava.image_embedding(images)

    llava_image_encode = LlavaImageEncoder(llava)

    # quantizer
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())
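    # The default symmetric (8-bit) config is applied graph-wide; annotation and
    # conversion happen in pt2e_quantize() below.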

    manager = (
        LlavaEdgeManager(
            model=llava_image_encode,
            modelname="llava_image_encoder",
            max_seq_len=llava.text_model_args.max_seq_len,  # This may not be right
            dtype=DType.fp32,
            use_kv_cache=True,
            example_inputs=(resized,),
            dynamic_shapes=dynamic_shapes,
            args=None,
        )
        .export()
        .pt2e_quantize([quantizer])
    )

    # re-export the quantized graph module; lowering to ExecuTorch happens in export_all()
    with torch.no_grad():
        image_encoder_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager.dynamic_shapes,
        )
    return image_encoder_ep


def export_token_embedding(llava, prompt):
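    """Export the token embedding lookup with its weights quantized to 8 bits
    (group size 32). The sequence dimension (dim 1) is dynamic, bounded by the
    text model's max_seq_len."""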
    def quant_embedding(model):
        return EmbeddingQuantHandler(
            model,
            bitwidth=8,
            group_size=32,
            packed=False,
        ).quantized_model()

    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
    token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
        token_embedding_ep = torch.export.export(
            quantized_token_embed.embed_tokens, (prompt,), dynamic_shapes=dynamic_shapes
        )
    return token_embedding_ep


def export_all(llava_model: LlavaModel):
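    """Export the image encoder, token embedding, and text model, then lower them
    into one multi-method ExecuTorch program targeting the XNNPACK backend.

    Sketch of how this is driven from main() below:

        llava_model = LlavaModel(use_sdpa_with_kv_cache_op=True, max_seq_len=768)
        executorch_program = export_all(llava_model)
        with open("llava_combined_xnnpack.pte", "wb") as f:
            executorch_program.write_to_file(f)
    """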
    llava = llava_model.get_eager_model()

    (
        prompt_before_image,
        resized,
        prompt_after_image,
    ) = llava_model.get_inputs_for_prefill()

    image_encoder_ep = export_image_encoder(
        llava, resized, llava_model._get_image_dynamic_shapes()
    )

    embeddings = llava.prefill_embedding(
        prompt_before_image, resized, prompt_after_image
    )

    text_model_ep = export_text_model(
        llava, embeddings, llava_model._get_prompt_dynamic_shapes()
    )

    token_embedding_ep = export_token_embedding(llava, prompt_before_image)

    lowered_and_edge = to_edge_transform_and_lower(
        {
            "image_encoder": image_encoder_ep,
            "token_embedding": token_embedding_ep,
            "text_model": text_model_ep,
        },
        partitioner={
            "image_encoder": [XnnpackPartitioner()],
            "text_model": [
                # Partition the DQLinear nodes first, then the rest of the graph.
                # Keeping each DQLinear in its own partition avoids holding multiple
                # unpacked and packed weight buffers in memory at once, which reduces
                # the peak memory footprint.
                XnnpackPartitioner(
                    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
                    per_op_mode=True,
                ),
                XnnpackPartitioner(),
            ],
        },
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

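    # ExecutorchBackendConfig notes: delegate blobs are stored as separate segments
    # of the .pte file, QuantFusionPass fuses leftover quant/dequant patterns, graph
    # inputs are left out of memory planning (the caller provides them), and the
    # per-method sym_shape_eval passes control how upper bounds for dynamic shapes
    # are derived.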
    executorch_program = lowered_and_edge.to_executorch(
        ExecutorchBackendConfig(
            extract_delegate_segments=True,
            passes=[
                QuantFusionPass(),
            ],
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            sym_shape_eval_pass={
                "image_encoder": ConstraintBasedSymShapeEvalPass(),
                "text_model": ConstraintBasedSymShapeEvalPass(),
                "token_embedding": HintBasedSymShapeEvalPass(),
            },
        )
    )
    for execution_plan in executorch_program._emitter_output.program.execution_plan:
        logging.info(
            f"Required memory for activations in bytes: {execution_plan.non_const_buffer_sizes}"
        )
    return executorch_program


def get_image_tensor_for_llava_runner(llava_model):
    # The llava runner has no image reader, so serialize an example image tensor for it.
    (resized,) = llava_model.get_example_inputs()

    serialize_image(resized, "image.pt")


def get_tokenizer_for_llava_runner(llava_model):
    # Serialize the tokenizer into tokenizer.bin for the runner.
    llava_model.tokenizer.save_vocabulary("./")
    t = Tokenizer("tokenizer.model")
    t.export("tokenizer.bin")


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--use-sdpa-with-kv-cache",
        default=True,
        action=BooleanOptionalAction,
        help="Use the sdpa_with_kv_cache custom op in the Llava text model.",
    )
    parser.add_argument(
        "--max-seq-len",
        default=768,
        type=int,
        help="Maximum sequence length for the text model.",
    )
    parser.add_argument(
        "--pte-name",
        default="llava_combined_xnnpack.pte",
        help="Name of the exported ExecuTorch program.",
    )
    parser.add_argument(
        "--with-artifacts",
        default=False,
        action=BooleanOptionalAction,
        help="Generate artifacts for llava runner.",
    )
    parser.add_argument(
        "--profile_memory",
        required=False,
        action="store_true",
        help="Generate chrome trace of activation memory for intermediate tensors.",
    )
    args = parser.parse_args()
    logging.info(
        f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}, max_seq_len: {args.max_seq_len}"
    )
    llava_model = LlavaModel(
        use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache,
        max_seq_len=args.max_seq_len,
    )

    executorch_program = export_all(llava_model)

    # memory profiling
    if args.profile_memory:
        for method_name in executorch_program.methods:
            generate_memory_trace(
                executorch_program,
                f"{args.pte_name}_{method_name}.json",
                method_name=method_name,
            )

    with open(args.pte_name, "wb") as f:
        executorch_program.write_to_file(f)
    logging.info(f"Exported ExecuTorch program to {args.pte_name}")

    # artifacts
    if args.with_artifacts:
        get_image_tensor_for_llava_runner(llava_model)
        get_tokenizer_for_llava_runner(llava_model)


if __name__ == "__main__":
    main()