# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# An ExecuTorch-friendly implementation of Llava-1.5.

import re

from typing import Any, Dict, Optional, Tuple

import requests
import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer

from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)
from executorch.examples.models.llava.image_util import prepare_image
from executorch.examples.models.model_base import EagerModelBase
from PIL import Image

from torch.export import Dim
from torchvision.transforms.v2 import functional as F

from transformers import (
    AutoProcessor,
    CLIPImageProcessor,
    LlamaForCausalLM,
    LlavaForConditionalGeneration,
)


class Llava(torch.nn.Module):
    def __init__(
        self,
        llava_model: LlavaForConditionalGeneration,
        image_processor: CLIPImageProcessor,
        use_sdpa_with_kv_cache_op: bool = True,
        max_seq_len: int = 768,
    ):
        super().__init__()
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.model_ = llava_model
        self.image_processor = image_processor
        self.vision_feature_layer = self.model_.config.vision_feature_layer
        self.vision_feature_select_strategy = (
            self.model_.config.vision_feature_select_strategy
        )
        self.text_model_args = ModelArgs(
            use_kv_cache=True,
            vocab_size=self.model_.config.text_config.vocab_size,
            hidden_dim=self.model_.config.text_config.intermediate_size,
            max_batch_size=1,  # doesn't work with default batch size 32
            ffn_dim_multiplier=1,  # TODO: a hack to make rotary embedding happy
            enable_dynamic_shape=True,  # allow parallel prefill
            use_sdpa_with_kv_cache_op=use_sdpa_with_kv_cache_op,  # use sdpa_with_kv_cache op
            use_hf_rope=True,
            max_seq_len=max_seq_len,
        )
        self.text_model = Transformer(self.text_model_args)
        # Replace SDPA with the custom sdpa_with_kv_cache op.
        if use_sdpa_with_kv_cache_op:
            self.text_model = replace_sdpa_with_custom_op(self.text_model)
        # Load the HF language model weights, translated to the Llama transformer's
        # parameter names.
        self.text_model.load_state_dict(
            state_dict=self._translate_state_dict_for_text_model(),
            strict=False,
            assign=True,
        )

    def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
        state_dict = self.model_.language_model.state_dict()
        key_map = {
            # fmt: off
            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
            r"model.norm.": r"norm.",
            # r"model.embed_tokens.": r"tok_embeddings.", # load separately
            r"lm_head.": r"output.",
            # fmt: on
        }

        new_state_dict = {}

        def get_new_key(old_key: str) -> str:
            for old_pattern, replacement in key_map.items():
                if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                    return new_key

            return old_key

        # Convert module keys from the HF transformer to the Llama transformer.
        for old_key in state_dict.keys():
            new_key = get_new_key(old_key)

            new_state_dict[new_key] = state_dict[old_key]

        return new_state_dict

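    # Illustrative example of the key translation above (not executed here; the
    # left-hand names follow the HF LlamaForCausalLM checkpoint layout):
    #   "model.layers.0.self_attn.q_proj.weight" -> "layers.0.attention.wq.weight"
    #   "model.norm.weight"                      -> "norm.weight"
    #   "lm_head.weight"                         -> "output.weight"
    #   "model.embed_tokens.weight"              -> unchanged (loaded separately)
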
    def _feature_select(self, image_outputs):
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]

        if self.vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            selected_image_feature = selected_image_feature
        else:
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )
        return selected_image_feature

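    # Note on the selection above: hidden_states[vision_feature_layer] has shape
    # [batch, num_patches + 1, hidden], where index 0 is the CLIP CLS token. The
    # "default" strategy drops that CLS token ([:, 1:]) while "full" keeps it; for
    # the 336x336 CLIP ViT-L/14 tower used by llava-1.5 this leaves 576 patch
    # tokens (shape given for illustration only).
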
    def get_model(self):
        return self.model_.get_model()

    def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.model_.language_model.model.embed_tokens(tokens)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        images = images.to(dtype=self.model_.dtype)
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.model_.vision_tower(
                    image.to(
                        device=self.model_.device, dtype=self.model_.dtype
                    ).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self._feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.model_.vision_tower(
                images.to(device=self.model_.device, dtype=self.model_.dtype),
                output_hidden_states=True,
            )
            image_features = self._feature_select(image_forward_outs).to(images.dtype)
        image_features = self.model_.multi_modal_projector(image_features)
        return image_features

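    # Rough shape flow through encode_images for one preprocessed image (the sizes
    # below assume the default llava-1.5 CLIP tower and 7B text model, and are
    # illustrative only):
    #   images:                            [1, 3, 336, 336]
    #   vision_tower hidden state:         [1, 577, 1024]
    #   after _feature_select("default"):  [1, 576, 1024]
    #   after multi_modal_projector:       [1, 576, 4096]
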
    def image_preprocess(self, img: torch.Tensor) -> torch.Tensor:
        target_h = self.image_processor.crop_size["height"]
        target_w = self.image_processor.crop_size["width"]
        # Pad the image to the target square size. The original implementation fills
        # the padding with the per-channel mean RGB value (see the commented-out
        # F.pad call below).
        l_pad = (target_w - img.shape[2]) // 2
        t_pad = (target_h - img.shape[1]) // 2
        # Ceil division for the right/bottom padding.
        r_pad = -((target_w - img.shape[2]) // -2)
        b_pad = -((target_h - img.shape[1]) // -2)

        torch._check(l_pad >= 0)
        torch._check(t_pad >= 0)
        torch._check(r_pad >= 0)
        torch._check(b_pad >= 0)

        # This is different from the original implementation, due to export limitations.
        resized = torch.nn.functional.pad(
            img,
            (l_pad, r_pad, t_pad, b_pad),
        )
        # originally:
        # resized = F.pad(
        #     img,
        #     padding=(l_pad, t_pad, r_pad, b_pad),
        #     fill=tuple(int(x * 255) for x in self.image_mean),
        # )

        # TODO: implement _upsample_bicubic_aa.out in portable kernel library.
        # The padded shape here should be max(h, w) x max(h, w); the bicubic resize is
        # skipped for now because the portable kernel library lacks _upsample_bicubic_aa.
        # resized = resize(
        #     padded,
        #     size=[
        #         self.image_processor.crop_size["height"],
        #         self.image_processor.crop_size["width"],
        #     ],
        #     interpolation="bicubic",
        # )
        # torch._check(resized.size(1) == self.config.crop_size["height"])
        # torch._check(resized.size(2) == self.config.crop_size["width"])
        # print(resized.shape)
        # cropped = F.center_crop(img, output_size=[w, w])
        # print(cropped.shape)
        scaled = resized * self.image_processor.rescale_factor
        # print(scaled)
        normed = F.normalize(
            scaled, self.image_processor.image_mean, self.image_processor.image_std
        )
        # print(normed)
        return normed.unsqueeze(0)

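    # Worked example for the padding arithmetic in image_preprocess above, assuming
    # the default 336x336 crop size and an img of shape [3, 239, 336]:
    #   l_pad = (336 - 336) // 2 = 0,   r_pad = -((336 - 336) // -2) = 0
    #   t_pad = (336 - 239) // 2 = 48,  b_pad = -((336 - 239) // -2) = 49
    # The floor/ceil pair always sums to the full difference (97 here), so the
    # padded output is exactly 336x336.
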
    def step(
        self, token: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Input is one token. Return logits for the next token."""
        token_embeds = self.embed_tokens(token).unsqueeze(0)
        return self.text_model.forward(None, input_pos, token_embeds)

    def image_embedding(self, images: torch.Tensor) -> torch.Tensor:
        preprocessed_img = self.image_preprocess(images)
        return self.encode_images(preprocessed_img)

    def prefill_embedding(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        image_embeds = self.image_embedding(images)
        embeds_before_img = self.embed_tokens(prompt_before_image)
        embeds_after_img = self.embed_tokens(prompt_after_image)
        result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
        return result

    # Prefill using the in-house text_model (Llama transformer).
    def prefill(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> Tuple[int, torch.Tensor]:
        """Avoid the torch.where() call needed to locate the <image> placeholder and
        splice in the image embedding; take three inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        # Return the prefilled token length too, because the text model generates one
        # set of logits per forward call.
        return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

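    # A minimal greedy decode sketch built on prefill() and step() above, where
    # `llava` is an instance of this class. This is an illustrative assumption about
    # how the two calls compose; the logits indexing, token selection, and loop
    # length are not prescribed by this module:
    #   prefill_len, logits = llava.prefill(pre_tokens, image, post_tokens)
    #   next_tok = torch.argmax(logits[..., -1, :]).item()
    #   for i in range(16):
    #       logits = llava.step(torch.tensor([next_tok]), torch.tensor([prefill_len + i]))
    #       next_tok = torch.argmax(logits[..., -1, :]).item()
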
    # Reference prefill using the HF text model.
    def prefill_ref(
        self,
        prompt_before_image: torch.Tensor,
        images: torch.Tensor,
        prompt_after_image: torch.Tensor,
    ) -> torch.Tensor:
        """Avoid the torch.where() call needed to locate the <image> placeholder and
        splice in the image embedding; take three inputs instead."""
        embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
        return LlamaForCausalLM.forward(
            self.model_.language_model,
            inputs_embeds=embeds,
            return_dict=False,
            use_cache=False,
            output_hidden_states=False,
        )

    def forward(
        self,
        images: torch.Tensor,
    ) -> torch.Tensor:
        return self.image_embedding(images)


class LlavaModel(EagerModelBase):
    def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
        self.max_seq_len = max_seq_len
        self.processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        self.model = LlavaForConditionalGeneration.from_pretrained(
            "llava-hf/llava-1.5-7b-hf",
            device_map="cpu",
        )
        self.image = Image.open(
            requests.get(
                "https://llava-vl.github.io/static/images/view.jpg", stream=True
            ).raw
        )
        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
What are the things I should be cautious about when I visit here? ASSISTANT:"""
        self.model_name = "llava-1.5-7b-hf"
        # Set the example inputs to None and initialize them lazily.
        self.input = None
        self.resized_image = None

    def get_eager_model(self):
        model = Llava(
            self.model,
            self.image_processor,
            self.use_sdpa_with_kv_cache_op,
            self.max_seq_len,
        )
        model.to(dtype=torch.float32)
        return model

    def get_example_inputs(self):
        """Returns a resized image as input to model.forward()."""
        if self.resized_image:
            return self.resized_image
        resized = prepare_image(
            self.image,
            self.image_processor.crop_size["height"],
            self.image_processor.crop_size["width"],
        )
        self.resized_image = (resized,)
        return self.resized_image

    def get_inputs_for_prefill(self):
        """Returns the prompt split around the <image> token, plus the example image."""
        if self.input:
            return self.input
        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
        index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
        self.prompt_before_image = self.input_ids[:, :index]
        # print(prompt_before_image.shape)
        self.prompt_after_image = self.input_ids[:, index + 1 :]
        # print(prompt_after_image.shape)
        self.input = (
            self.prompt_before_image,
            *self.get_example_inputs(),
            self.prompt_after_image,
        )
        return self.input

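    # Illustrative example of the split above: if the tokenized prompt is
    #   [bos, t1, ..., tk, <image>, u1, ..., um]
    # and <image> sits at position `index` (the model's image_token_index), then
    #   prompt_before_image = input_ids[:, :index]      -> [bos, t1, ..., tk]
    #   prompt_after_image  = input_ids[:, index + 1:]  -> [u1, ..., um]
    # so Llava.prefill() can concatenate text embeddings around the image embeddings
    # without a torch.where() lookup inside the exported graph.
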
    def get_dynamic_shapes(self):
        return self._get_image_dynamic_shapes()

    def _get_image_dynamic_shapes(self):
        # Only even heights and widths are supported for now.
        _height = Dim(
            "_height", min=1, max=self.image_processor.crop_size["height"] // 2
        )
        _width = Dim("_width", min=1, max=self.image_processor.crop_size["width"] // 2)
        height = 2 * _height
        width = 2 * _width
        dynamic_shapes = [{1: height, 2: width}]
        return dynamic_shapes

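    # How the derived Dims above behave: expressing height and width as 2 * _height
    # and 2 * _width tells torch.export that only even sizes are valid. With a crop
    # size of 336 (an assumption about the default CLIP processor config), _height
    # and _width range over [1, 168], so the exported image encoder accepts inputs
    # whose spatial dims are any even value from 2 to 336, e.g. [3, 240, 336].
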
    def _get_prompt_dynamic_shapes(self):
        dim = torch.export.Dim("token_dim", min=2, max=self.max_seq_len)
        text_model_dynamic_shapes = ({0: 1}, {1: dim})
        return text_model_dynamic_shapes

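
# A minimal usage sketch, assuming network access to download the HF checkpoint and
# the example image, plus enough memory for the fp32 7B model. It only calls the
# accessors defined above; the torch.export call is illustrative and not part of
# this module.
if __name__ == "__main__":
    llava_model = LlavaModel()
    llava = llava_model.get_eager_model()

    # Eager image embedding for the bundled example image.
    image_embeds = llava.image_embedding(*llava_model.get_example_inputs())
    print("image embedding shape:", image_embeds.shape)

    # Export just the image encoder (Llava.forward) with dynamic height/width.
    exported = torch.export.export(
        llava,
        llava_model.get_example_inputs(),
        dynamic_shapes=llava_model.get_dynamic_shapes(),
        strict=False,
    )
    print(exported.graph_signature)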