# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# An ExecuTorch-friendly implementation of Llava-1.5.

import re

from typing import Any, Dict, Optional, Tuple

import requests
import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import prepare_image
from executorch.examples.models.model_base import EagerModelBase
from PIL import Image

from torch.export import Dim
from torchvision.transforms.v2 import functional as F

from transformers import (
    AutoProcessor,
    CLIPImageProcessor,
    LlamaForCausalLM,
    LlavaForConditionalGeneration,
)


class Llava(torch.nn.Module):
    def __init__(
        self,
        llava_model: LlavaForConditionalGeneration,
        image_processor: CLIPImageProcessor,
        use_sdpa_with_kv_cache_op: bool = True,
        max_seq_len: int = 768,
    ):
        super().__init__()
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.model_ = llava_model
        self.image_processor = image_processor
        self.vision_feature_layer = self.model_.config.vision_feature_layer
        self.vision_feature_select_strategy = (
            self.model_.config.vision_feature_select_strategy
        )
        self.text_model_args = ModelArgs(
            use_kv_cache=True,
            vocab_size=self.model_.config.text_config.vocab_size,
            hidden_dim=self.model_.config.text_config.intermediate_size,
            max_batch_size=1,  # doesn't work with the default batch size of 32
            ffn_dim_multiplier=1,  # TODO: a hack to make rotary embedding happy
            enable_dynamic_shape=True,  # allow parallel prefill
            use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op,  # use the sdpa_with_kv_cache op
            use_hf_rope=True,
            max_seq_len=max_seq_len,
        )
        self.text_model = Transformer(self.text_model_args)
        # Use the custom op for SDPA.
        if use_sdpa_with_kv_cache_op:
            self.text_model = replace_sdpa_with_custom_op(self.text_model)
        # load state dict
        self.text_model.load_state_dict(
            state_dict=self._translate_state_dict_for_text_model(),
            strict=False,
            assign=True,
        )

    def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
        state_dict = self.model_.language_model.state_dict()
        key_map = {
            # fmt: off
            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
            r"model.norm.": r"norm.",
            # r"model.embed_tokens.": r"tok_embeddings.",  # load separately
            r"lm_head.": r"output.",
            # fmt: on
        }

        new_state_dict = {}

        def get_new_key(old_key: str) -> str:
            for old_pattern, replacement in key_map.items():
                if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                    return new_key

            return old_key

        # Convert module keys from hf transformer to Llama transformer.
        for old_key in state_dict.keys():
            new_key = get_new_key(old_key)

            new_state_dict[new_key] = state_dict[old_key]

        return new_state_dict

    def _feature_select(self, image_outputs):
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )
        return selected_image_feature

    def get_model(self):
        return self.model_.get_model()

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.model_.language_model.model.embed_tokens(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = images.to(dtype=self.model_.dtype)
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.model_.vision_tower(
                    image.to(
                        device=self.model_.device, dtype=self.model_.dtype
                    ).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self._feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.model_.vision_tower(
                images.to(device=self.model_.device, dtype=self.model_.dtype),
                output_hidden_states=True,
            )
            image_features = self._feature_select(image_forward_outs).to(images.dtype)
        image_features = self.model_.multi_modal_projector(image_features)
        return image_features

    def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
        target_h = self.image_processor.crop_size["height"]
        target_w = self.image_processor.crop_size["width"]
        # pad the image with median rgb value, to make a square
        l_pad = (target_w - img.shape[2]) // 2
        t_pad = (target_h - img.shape[1]) // 2
        # ceil division
        r_pad = -((target_w - img.shape[2]) // -2)
        b_pad = -((target_h - img.shape[1]) // -2)

        torch._check(l_pad >= 0)
        torch._check(t_pad >= 0)
        torch._check(r_pad >= 0)
        torch._check(b_pad >= 0)

        # This is different from the original implementation, due to export limitations.
        resized = torch.nn.functional.pad(
            img,
            (l_pad, r_pad, t_pad, b_pad),
        )
        # originally:
        # resized = F.pad(
        #     img,
        #     padding=(l_pad, t_pad, r_pad, b_pad),
        #     fill=tuple(int(x * 255) for x in self.image_mean),
        # )

        # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
        # here padded shape should be max(h, w) x max(h, w)
        # skipping resize for now due to missing _upsample_bicubic_aa kernel in portable
        # resized = resize(
        #     padded,
        #     size=[
        #         self.image_processor.crop_size["height"],
        #         self.image_processor.crop_size["width"],
        #     ],
        #     interpolation="bicubic",
        # )
        # torch._check(resized.size(1) == self.config.crop_size["height"])
        # torch._check(resized.size(2) == self.config.crop_size["width"])
        # print(resized.shape)
        # cropped = F.center_crop(img, output_size=[w, w])
        # print(cropped.shape)
        scaled = resized * self.image_processor.rescale_factor
        # print(scaled)
        normed = F.normalize(
            scaled, self.image_processor.image_mean, self.image_processor.image_std
        )
        # print(normed)
        return normed.unsqueeze(0)

    def step(
        self, token: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Input is one token. Return the logits for the next token."""
        token_embeds = self.embed_tokens(token).unsqueeze(0)
        return self.text_model.forward(None, input_pos, token_embeds)

    def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
        preprocessed_img = self.image_preprocess(images)
        return self.encode_images(preprocessed_img)

    def prefill_embedding(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        image_embeds = self.image_embedding(images)
        embeds_before_img = self.embed_tokens(prompt_before_image)
        embeds_after_img = self.embed_tokens(prompt_after_image)
        result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
        return result

    # Prefill using the in-house text_model (Llama transformer).
    def prefill(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> Tuple[int, torch.Tensor]:
        """Avoid the torch.where() call to find the <image> placeholder and insert the image embedding; take 3 inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        # Return the prefilled token length too, because the text model generates one set of logits per forward call.
        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

    # Reference prefill using the text model in HF.
    def prefill_ref(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Avoid the torch.where() call to find the <image> placeholder and insert the image embedding; take 3 inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        return LlamaForCausalLM.forward(
            self.model_.language_model,
            inputs_embeds=embeds,
            return_dict=False,
            use_cache=False,
            output_hidden_states=False,
        )

    def forward(
        self,
        images: torch.Tensor,
    ) -> torch.Tensor:
        return self.image_embedding(images)


class LlavaModel(EagerModelBase):
    def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.max_seq_len = max_seq_len
        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.model = LlavaForConditionalGeneration.from_pretrained(
            "llava-hf/llava-1.5-7b-hf",
            device_map="cpu",
        )
        self.image = Image.open(
            requests.get(
                "https://llava-vl.github.io/static/images/view.jpg", stream=True
            ).raw
        )
        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
What are the things I should be cautious about when I visit here? ASSISTANT:"""
        self.model_name = "llava-1.5-7b-hf"
        # Set the inputs to None and initialize them lazily.
        self.input = None
        self.resized_image = None

    def get_eager_model(self):
        model = Llava(
            self.model,
            self.image_processor,
            self.use_sdpa_with_kv_cache_op,
            self.max_seq_len,
        )
        model.to(dtype=torch.float32)
        return model

    def get_example_inputs(self):
        """Returns a resized image as input to model.forward()."""
        if self.resized_image:
            return self.resized_image
        resized = prepare_image(
            self.image,
            self.image_processor.crop_size["height"],
            self.image_processor.crop_size["width"],
        )
        self.resized_image = (resized,)
        return self.resized_image

    def get_inputs_for_prefill(self):
        """Returns the prompt split around the <image> token, as well as the image."""
        if self.input:
            return self.input
        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
        index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
        self.prompt_before_image = self.input_ids[:, :index]
        # print(prompt_before_image.shape)
        self.prompt_after_image = self.input_ids[:, index + 1 :]
        # print(prompt_after_image.shape)
        self.input = (
            self.prompt_before_image,
            *self.get_example_inputs(),
            self.prompt_after_image,
        )
        return self.input

    def get_dynamic_shapes(self):
        return self._get_image_dynamic_shapes()

    def _get_image_dynamic_shapes(self):
        # Only even heights and widths are supported for now.
        _height = Dim(
            "_height", min=1, max=self.image_processor.crop_size["height"] // 2
        )
        _width = Dim("_width", min=1, max=self.image_processor.crop_size["width"] // 2)
        height = 2 * _height
        width = 2 * _width
        dynamic_shapes = [{1: height, 2: width}]
        return dynamic_shapes

    def _get_prompt_dynamic_shapes(self):
        dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
        text_model_dynamic_shapes = ({0: 1}, {1: dim})
        return text_model_dynamic_shapes
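

# Minimal usage sketch, not part of the exported graph: build the eager model,
# prefill the prompt/image pair, then run a single decode step. It assumes
# network access to download the llava-hf/llava-1.5-7b-hf checkpoint and the
# example image, and that the text model returns logits for every prefilled
# position.
if __name__ == "__main__":
    llava_model = LlavaModel()
    llava = llava_model.get_eager_model()

    prompt_before_image, resized_image, prompt_after_image = (
        llava_model.get_inputs_for_prefill()
    )
    with torch.no_grad():
        # Prefill populates the KV cache and returns the number of prefilled tokens.
        prefill_len, logits = llava.prefill(
            prompt_before_image, resized_image, prompt_after_image
        )
        # Greedy-pick the next token from the last prefilled position
        # (assumes logits have shape [batch, seq_len, vocab_size]).
        next_token = logits[:, -1, :].argmax(dim=-1)
        # One decode step; input_pos continues right after the prefilled tokens.
        logits = llava.step(next_token, torch.tensor([prefill_len]))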