# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import ArgumentParser, BooleanOptionalAction

import torch
from executorch.backends.xnnpack.partition.config.xnnpack_config import (
    ConfigPrecisionType,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_quantizer_and_quant_params,
)
from executorch.examples.models.llama.source_transformation.quantize import (
    EmbeddingQuantHandler,
    get_quant_weight_transform,
)
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import serialize_image
from executorch.examples.models.llava.model import LlavaModel
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)

from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import (
    ConstraintBasedSymShapeEvalPass,
    HintBasedSymShapeEvalPass,
)

from executorch.extension.llm.export.builder import DType, LLMEdgeManager
from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
from executorch.util.activation_memory_profiler import generate_memory_trace
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import Dim
from torch.nn.attention import SDPBackend

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


class LlavaEdgeManager(LLMEdgeManager):
    def export(self) -> "LlavaEdgeManager":
        dynamic_shape = self._get_dynamic_shape()
        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
            self.export_program = torch.export.export(
                self.model,
                self.example_inputs,
                dynamic_shapes=dynamic_shape,
                strict=False,
            )
            self.pre_autograd_graph_module = self.export_program.module()
        return self


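# The helpers below export the three Llava methods (text decoder, image
# encoder, token embedding) as separate ExportedPrograms; export_all() then
# lowers them together into a single multi-method ExecuTorch program (.pte).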
def export_text_model(llava, embeddings, dynamic_shapes):
    class LlavaTextModel(torch.nn.Module):
        """Takes precomputed prompt/image embeddings and the current input
        position and runs the Llava text model (decoder) over them."""

        def __init__(self, llava):
            super().__init__()
            self.text_model = llava.text_model

        def forward(self, input_pos, embeddings):
            return self.text_model(None, input_pos, embeddings)

    llava_text_model = LlavaTextModel(llava)

    text_model_em = LLMEdgeManager(
        model=llava_text_model,
        modelname="llava_text_model",
        max_seq_len=llava.text_model_args.max_seq_len,
        dtype=DType.fp32,
        use_kv_cache=True,
        example_inputs=(torch.tensor([0], dtype=torch.int64), embeddings),
        dynamic_shapes=dynamic_shapes,
        args=llava.text_model_args,
    )

    dtype_override = DType.fp32
    parser = build_args_parser()
    args = parser.parse_args(
        ["-X", "-qmode", "8da4w", "--group_size", "128", "--embedding-quantize", "4,32"]
    )
    quant_transform = get_quant_weight_transform(args, dtype_override, False)
    _, quantizers, _ = get_quantizer_and_quant_params(args)
    source_transforms = []
    if llava.use_sdpa_with_kv_cache_op:
        source_transforms.append(replace_sdpa_with_custom_op)
    source_transforms.append(quant_transform)
    manager = (
        text_model_em.set_output_dir("./")
        .to_dtype(dtype_override)
        .source_transform(source_transforms)
        .export()
        .pt2e_quantize(quantizers)
    )

    with torch.no_grad():
        text_model_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager._get_dynamic_shape(),
        )
    return text_model_ep


def export_image_encoder(llava, resized, dynamic_shapes):
    class LlavaImageEncoder(torch.nn.Module):
        """Takes images and encodes them into embeddings. The result is fed to
        the text model LlavaTextModel."""

        def __init__(self, llava):
            super().__init__()
            self.llava = llava

        def forward(self, images):
            return self.llava.image_embedding(images)

    llava_image_encode = LlavaImageEncoder(llava)

    # PT2E quantizer for the image encoder (default symmetric config).
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config())

    manager = (
        LlavaEdgeManager(
            model=llava_image_encode,
            modelname="llava_image_encoder",
            max_seq_len=llava.text_model_args.max_seq_len,  # This may not be right
            dtype=DType.fp32,
            use_kv_cache=True,
            example_inputs=(resized,),
            dynamic_shapes=dynamic_shapes,
            args=None,
        )
        .export()
        .pt2e_quantize([quantizer])
    )

    # Re-export the quantized graph; lowering to ExecuTorch happens in export_all().
    with torch.no_grad():
        image_encoder_ep = torch.export.export(
            manager.pre_autograd_graph_module,
            manager.example_inputs,
            dynamic_shapes=manager.dynamic_shapes,
        )
    return image_encoder_ep


def export_token_embedding(llava, prompt):
    def quant_embedding(model):
        return EmbeddingQuantHandler(
            model,
            bitwidth=8,
            group_size=32,
            packed=False,
        ).quantized_model()

    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
    token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
        token_embedding_ep = torch.export.export(
            quantized_token_embed.embed_tokens, (prompt,), dynamic_shapes=dynamic_shapes
        )
    return token_embedding_ep


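# Illustrative only (not part of the export flow): an ExportedProgram returned
# by the helpers above can be sanity-checked by running its module() eagerly.
# The (1, seq_len) int64 prompt layout below is an assumption.
#   ep = export_token_embedding(llava, prompt)
#   out = ep.module()(torch.tensor([[1, 2, 3]], dtype=torch.int64))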
def export_all(llava_model: LlavaModel):
    llava = llava_model.get_eager_model()

    (
        prompt_before_image,
        resized,
        prompt_after_image,
    ) = llava_model.get_inputs_for_prefill()

    image_encoder_ep = export_image_encoder(
        llava, resized, llava_model._get_image_dynamic_shapes()
    )

    embeddings = llava.prefill_embedding(
        prompt_before_image, resized, prompt_after_image
    )

    text_model_ep = export_text_model(
        llava, embeddings, llava_model._get_prompt_dynamic_shapes()
    )

    token_embedding_ep = export_token_embedding(llava, prompt_before_image)

    lowered_and_edge = to_edge_transform_and_lower(
        {
            "image_encoder": image_encoder_ep,
            "token_embedding": token_embedding_ep,
            "text_model": text_model_ep,
        },
        partitioner={
            "image_encoder": [XnnpackPartitioner()],
            "text_model": [
                # Partition the DQLinear nodes one per partition first, then
                # partition the remaining nodes. Keeping DQLinear nodes in
                # separate partitions avoids holding multiple packed and
                # unpacked weight buffers in memory at once, which reduces the
                # peak memory footprint.
                XnnpackPartitioner(
                    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
                    per_op_mode=True,
                ),
                XnnpackPartitioner(),
            ],
        },
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )

    executorch_program = lowered_and_edge.to_executorch(
        ExecutorchBackendConfig(
            extract_delegate_segments=True,
            passes=[
                QuantFusionPass(),
            ],
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
            sym_shape_eval_pass={
                "image_encoder": ConstraintBasedSymShapeEvalPass(),
                "text_model": ConstraintBasedSymShapeEvalPass(),
                "token_embedding": HintBasedSymShapeEvalPass(),
            },
        )
    )
    for execution_plan in executorch_program._emitter_output.program.execution_plan:
        logging.info(
            f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}"
        )
    return executorch_program


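# At runtime a llava runner is expected to drive the three exported methods in
# roughly the order mirrored by llava.prefill_embedding above: token_embedding
# on the prompt tokens before the image, image_encoder on the preprocessed
# image, token_embedding on the tokens after the image, and text_model over the
# resulting embeddings with an increasing input_pos.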
def get_image_tensor_for_llava_runner(llava_model):
    # The llava runner doesn't have an image reader, so serialize an image tensor for it.
    (resized,) = llava_model.get_example_inputs()

    serialize_image(resized, "image.pt")


def get_tokenizer_for_llava_runner(llava_model):
    # Serialize the tokenizer into tokenizer.bin.
    llava_model.tokenizer.save_vocabulary("./")
    t = Tokenizer("tokenizer.model")
    t.export("tokenizer.bin")


def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--use-sdpa-with-kv-cache",
        default=True,
        action=BooleanOptionalAction,
        help="Use the sdpa_with_kv_cache custom op in the Llava text model.",
    )
    parser.add_argument(
        "--max-seq-len",
        default=768,
        type=int,
        help="Maximum sequence length for the text model.",
    )
    parser.add_argument(
        "--pte-name",
        default="llava_combined_xnnpack.pte",
        help="Name of the exported ExecuTorch program.",
    )
    parser.add_argument(
        "--with-artifacts",
        default=False,
        action=BooleanOptionalAction,
        help="Generate artifacts for the llava runner.",
    )
    parser.add_argument(
        "--profile_memory",
        required=False,
        action="store_true",
        help="Generate a Chrome trace of activation memory for intermediate tensors.",
    )
    args = parser.parse_args()
    logging.info(
        f"Exporting Llava model to ExecuTorch with sdpa_with_kv_cache: {args.use_sdpa_with_kv_cache}, max_seq_len: {args.max_seq_len}"
    )
    llava_model = LlavaModel(
        use_sdpa_with_kv_cache_op=args.use_sdpa_with_kv_cache,
        max_seq_len=args.max_seq_len,
    )

    executorch_program = export_all(llava_model)

    # memory profiling
    if args.profile_memory:
        for method_name in executorch_program.methods:
            generate_memory_trace(
                executorch_program,
                f"{args.pte_name}_{method_name}.json",
                method_name=method_name,
            )

    with open(args.pte_name, "wb") as f:
        executorch_program.write_to_file(f)
    logging.info(f"Exported ExecuTorch program to {args.pte_name}")

    # artifacts for the llava runner
    if args.with_artifacts:
        get_image_tensor_for_llava_runner(llava_model)
        get_tokenizer_for_llava_runner(llava_model)


if __name__ == "__main__":
    main()
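
# Example invocation (the script name and location are assumptions; adjust to
# wherever this file lives in the repo):
#   python export_llava.py --max-seq-len 768 --pte-name llava_combined_xnnpack.pte --with-artifacts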