xref: /aosp_15_r20/external/executorch/examples/demo-apps/apple_ios/LLaMA/LLaMA/Application/ContentView.swift (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 import SwiftUI
10 import UniformTypeIdentifiers
11 
12 import LLaMARunner
13 
/// Reference-type box for the two model runners used by `ContentView`.
///
/// `ContentView` keeps one instance in a `@StateObject` and fills the slots on
/// first use from its runner queue; both slots are cleared when a new model or
/// tokenizer file is picked. Neither property is `@Published`, so assigning to
/// them does not trigger view updates.
class RunnerHolder: ObservableObject {
  /// Text-only runner; `nil` until the first text generation request.
  var runner: Runner? = nil
  /// Image + text runner; `nil` until the first multimodal generation request.
  var llavaRunner: LLaVARunner? = nil
}
18 
extension UIImage {
  /// Returns a copy of the image redrawn into a bitmap of `newSize` points at scale 1.
  ///
  /// The image is stretched to fill `newSize`; the aspect ratio is not preserved
  /// unless the caller picks a proportional size.
  func resized(to newSize: CGSize) -> UIImage {
    let format = UIGraphicsImageRendererFormat.default()
    // Scale 1 makes one point equal one pixel in the rendered bitmap.
    format.scale = 1
    let renderer = UIGraphicsImageRenderer(size: newSize, format: format)
    return renderer.image { _ in
      draw(in: CGRect(origin: .zero, size: newSize))
    }
  }

  /// Converts the image to planar 8-bit RGB.
  ///
  /// The result holds `width * height` red bytes, followed by the same number
  /// of green bytes, followed by the blue bytes (row-major inside each plane).
  /// Alpha is discarded after the image has been drawn premultiplied.
  ///
  /// - Returns: The planar byte array, or `nil` when the image has no backing
  ///   `CGImage` or a bitmap context cannot be created for it.
  func toRGBArray() -> [UInt8]? {
    guard let cgImage = self.cgImage else { return nil }

    let width = cgImage.width, height = cgImage.height
    let bytesPerPixel = 4
    let bytesPerRow = bytesPerPixel * width
    let totalPixels = width * height
    var pixelData = [UInt8](repeating: 0, count: totalPixels * bytesPerPixel)

    // The context keeps the data pointer and writes through it during `draw`,
    // so the pointer must stay valid for both calls. Passing `&pixelData`
    // straight to the initializer only guarantees validity for that one call.
    let didDraw = pixelData.withUnsafeMutableBytes { buffer -> Bool in
      guard let baseAddress = buffer.baseAddress,
            let context = CGContext(
              data: baseAddress, width: width, height: height, bitsPerComponent: 8,
              bytesPerRow: bytesPerRow, space: CGColorSpaceCreateDeviceRGB(),
              bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
            )
      else { return false }
      context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
      return true
    }
    guard didDraw else { return nil }

    // De-interleave RGBA RGBA ... into RRR... GGG... BBB...
    var rgbValues = [UInt8](repeating: 0, count: totalPixels * 3)
    for pixel in 0..<totalPixels {
      let source = pixel * bytesPerPixel
      rgbValues[pixel] = pixelData[source]
      rgbValues[pixel + totalPixels] = pixelData[source + 1]
      rgbValues[pixel + totalPixels * 2] = pixelData[source + 2]
    }
    return rgbValues
  }
}
56 
57 struct ContentView: View {
58   @State private var prompt = ""
59   @State private var messages: [Message] = []
60   @State private var showingLogs = false
61   @State private var pickerType: PickerType?
62   @State private var isGenerating = false
63   @State private var shouldStopGenerating = false
64   @State private var shouldStopShowingToken = false
65   private let runnerQueue = DispatchQueue(label: "org.pytorch.executorch.llama")
66   @StateObject private var runnerHolder = RunnerHolder()
67   @StateObject private var resourceManager = ResourceManager()
68   @StateObject private var resourceMonitor = ResourceMonitor()
69   @StateObject private var logManager = LogManager()
70 
71   @State private var isImagePickerPresented = false
72   @State private var selectedImage: UIImage?
73   @State private var imagePickerSourceType: UIImagePickerController.SourceType = .photoLibrary
74 
75   @State private var showingSettings = false
76 
77   enum PickerType {
78     case model
79     case tokenizer
80   }
81 
82   private var placeholder: String {
83     resourceManager.isModelValid ? resourceManager.isTokenizerValid ? "Prompt..." : "Select Tokenizer..." : "Select Model..."
84   }
85 
86   private var title: String {
87     resourceManager.isModelValid ? resourceManager.isTokenizerValid ? resourceManager.modelName : "Select Tokenizer..." : "Select Model..."
88   }
89 
90   private var modelTitle: String {
91     resourceManager.isModelValid ? resourceManager.modelName : "Select Model..."
92   }
93 
94   private var tokenizerTitle: String {
95     resourceManager.isTokenizerValid ? resourceManager.tokenizerName : "Select Tokenizer..."
96   }
97 
98   private var isInputEnabled: Bool { resourceManager.isModelValid && resourceManager.isTokenizerValid }
99 
100   var body: some View {
101     NavigationView {
102       VStack {
103         if showingSettings {
104           VStack(spacing: 20) {
105             Form {
106               Section(header: Text("Model and Tokenizer")
107                         .font(.headline)
108                         .foregroundColor(.primary)) {
109                 Button(action: { pickerType = .model }) {
110                   Label(resourceManager.modelName == "" ? modelTitle : resourceManager.modelName, systemImage: "doc")
111                 }
112                 Button(action: { pickerType = .tokenizer }) {
113                   Label(resourceManager.tokenizerName == "" ? tokenizerTitle : resourceManager.tokenizerName, systemImage: "doc")
114                 }
115               }
116             }
117           }
118         }
119 
120         MessageListView(messages: $messages)
121           .gesture(
122             DragGesture().onChanged { value in
123               if value.translation.height > 10 {
124                 UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
125               }
126             }
127           )
128         HStack {
129           Button(action: {
130             imagePickerSourceType = .photoLibrary
131             isImagePickerPresented = true
132           }) {
133             Image(systemName: "photo.on.rectangle")
134               .resizable()
135               .scaledToFit()
136               .frame(width: 24, height: 24)
137           }
138           .background(Color.clear)
139           .cornerRadius(8)
140 
141           Button(action: {
142             if UIImagePickerController.isSourceTypeAvailable(.camera) {
143               imagePickerSourceType = .camera
144               isImagePickerPresented = true
145             } else {
146               print("Camera not available")
147             }
148           }) {
149             Image(systemName: "camera")
150               .resizable()
151               .scaledToFit()
152               .frame(width: 24, height: 24)
153           }
154           .background(Color.clear)
155           .cornerRadius(8)
156 
157           TextField(placeholder, text: $prompt, axis: .vertical)
158             .padding(8)
159             .background(Color.gray.opacity(0.1))
160             .cornerRadius(20)
161             .lineLimit(1...10)
162             .overlay(
163               RoundedRectangle(cornerRadius: 20)
164                 .stroke(isInputEnabled ? Color.blue : Color.gray, lineWidth: 1)
165             )
166             .disabled(!isInputEnabled)
167 
168           Button(action: isGenerating ? stop : generate) {
169             Image(systemName: isGenerating ? "stop.circle" : "arrowshape.up.circle.fill")
170               .resizable()
171               .aspectRatio(contentMode: .fit)
172               .frame(height: 28)
173           }
174           .disabled(isGenerating ? shouldStopGenerating : (!isInputEnabled || prompt.isEmpty))
175         }
176         .padding([.leading, .trailing, .bottom], 10)
177         .sheet(isPresented: $isImagePickerPresented, onDismiss: addSelectedImageMessage) {
178           ImagePicker(selectedImage: $selectedImage, sourceType: imagePickerSourceType)
179             .id(imagePickerSourceType.rawValue)
180         }
181       }
182       .navigationBarTitle(title, displayMode: .inline)
183       .navigationBarItems(
184         leading:
185           Button(action: {
186             showingSettings.toggle()
187           }) {
188             Image(systemName: "gearshape")
189               .imageScale(.large)
190           },
191         trailing:
192           HStack {
193             Menu {
194               Section(header: Text("Memory")) {
195                 Text("Used: \(resourceMonitor.usedMemory) Mb")
196                 Text("Available: \(resourceMonitor.usedMemory) Mb")
197               }
198             } label: {
199               Text("\(resourceMonitor.usedMemory) Mb")
200             }
201             .onAppear {
202               resourceMonitor.start()
203             }
204             .onDisappear {
205               resourceMonitor.stop()
206             }
207             Button(action: { showingLogs = true }) {
208               Image(systemName: "list.bullet.rectangle")
209             }
210           }
211       )
212       .sheet(isPresented: $showingLogs) {
213         NavigationView {
214           LogView(logManager: logManager)
215         }
216       }
217       .fileImporter(
218         isPresented: Binding<Bool>(
219           get: { pickerType != nil },
220           set: { if !$0 { pickerType = nil } }
221         ),
222         allowedContentTypes: allowedContentTypes(),
223         allowsMultipleSelection: false
224       ) { [pickerType] result in
225         handleFileImportResult(pickerType, result)
226       }
227       .onAppear {
228         do {
229           try resourceManager.createDirectoriesIfNeeded()
230         } catch {
231           withAnimation {
232             messages.append(Message(type: .info, text: "Error creating content directories: \(error.localizedDescription)"))
233           }
234         }
235       }
236     }
237     .navigationViewStyle(StackNavigationViewStyle())
238   }
239 
addSelectedImageMessagenull240   private func addSelectedImageMessage() {
241     if let selectedImage {
242       messages.append(Message(image: selectedImage))
243     }
244   }
245 
generatenull246   private func generate() {
247     guard !prompt.isEmpty else { return }
248     isGenerating = true
249     shouldStopGenerating = false
250     shouldStopShowingToken = false
251     let text = prompt.trimmingCharacters(in: .whitespacesAndNewlines)
252     let seq_len = 768 // text: 256, vision: 768
253     let modelPath = resourceManager.modelPath
254     let tokenizerPath = resourceManager.tokenizerPath
255     let useLlama = modelPath.lowercased().contains("llama")
256 
257     prompt = ""
258     hideKeyboard()
259     showingSettings = false
260 
261     messages.append(Message(text: text))
262     messages.append(Message(type: useLlama ? .llamagenerated : .llavagenerated))
263 
264     runnerQueue.async {
265       defer {
266         DispatchQueue.main.async {
267           isGenerating = false
268           selectedImage = nil
269         }
270       }
271 
272       if useLlama {
273         runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
274       } else {
275         runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
276       }
277 
278       guard !shouldStopGenerating else { return }
279       if useLlama {
280         if let runner = runnerHolder.runner, !runner.isLoaded() {
281           var error: Error?
282           let startLoadTime = Date()
283           do {
284             try runner.load()
285           } catch let loadError {
286             error = loadError
287           }
288 
289           let loadTime = Date().timeIntervalSince(startLoadTime)
290           DispatchQueue.main.async {
291             withAnimation {
292               var message = messages.removeLast()
293               message.type = .info
294               if let error {
295                 message.text = "Model loading failed: error \((error as NSError).code)"
296               } else {
297                 message.text = "Model loaded in \(String(format: "%.2f", loadTime)) s"
298               }
299               messages.append(message)
300               if error == nil {
301                 messages.append(Message(type: .llamagenerated))
302               }
303             }
304           }
305           if error != nil {
306             return
307           }
308         }
309       } else {
310         if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
311           var error: Error?
312           let startLoadTime = Date()
313           do {
314             try runner.load()
315           } catch let loadError {
316             error = loadError
317           }
318 
319           let loadTime = Date().timeIntervalSince(startLoadTime)
320           DispatchQueue.main.async {
321             withAnimation {
322               var message = messages.removeLast()
323               message.type = .info
324               if let error {
325                 message.text = "Model loading failed: error \((error as NSError).code)"
326               } else {
327                 message.text = "Model loaded in \(String(format: "%.2f", loadTime)) s"
328               }
329               messages.append(message)
330               if error == nil {
331                 messages.append(Message(type: .llavagenerated))
332               }
333             }
334           }
335           if error != nil {
336             return
337           }
338         }
339       }
340 
341       guard !shouldStopGenerating else {
342         DispatchQueue.main.async {
343           withAnimation {
344             _ = messages.removeLast()
345           }
346         }
347         return
348       }
349       do {
350         var tokens: [String] = []
351         var rgbArray: [UInt8]?
352         let MAX_WIDTH = 336.0
353         var newHeight = 0.0
354         var imageBuffer: UnsafeMutableRawPointer?
355 
356         if let img = selectedImage {
357           let llava_prompt = "\(text) ASSISTANT"
358 
359           newHeight = MAX_WIDTH * img.size.height / img.size.width
360           let resizedImage = img.resized(to: CGSize(width: MAX_WIDTH, height: newHeight))
361           rgbArray = resizedImage.toRGBArray()
362           imageBuffer = UnsafeMutableRawPointer(mutating: rgbArray)
363 
364           try runnerHolder.llavaRunner?.generate(imageBuffer!, width: MAX_WIDTH, height: newHeight, prompt: llava_prompt, sequenceLength: seq_len) { token in
365 
366             if token != llava_prompt {
367               if token == "</s>" {
368                 shouldStopGenerating = true
369                 runnerHolder.llavaRunner?.stop()
370               } else {
371                 tokens.append(token)
372                 if tokens.count > 2 {
373                   let text = tokens.joined()
374                   let count = tokens.count
375                   tokens = []
376                   DispatchQueue.main.async {
377                     var message = messages.removeLast()
378                     message.text += text
379                     message.tokenCount += count
380                     message.dateUpdated = Date()
381                     messages.append(message)
382                   }
383                 }
384                 if shouldStopGenerating {
385                   runnerHolder.llavaRunner?.stop()
386                 }
387               }
388             }
389           }
390         } else {
391           let llama3_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\(text)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
392 
393           try runnerHolder.runner?.generate(llama3_prompt, sequenceLength: seq_len) { token in
394 
395             NSLog(">>> token={\(token)}")
396             if token != llama3_prompt {
397               // hack to fix the issue that extension/llm/runner/text_token_generator.h
398               // keeps generating after <|eot_id|>
399               if token == "<|eot_id|>" {
400                 shouldStopShowingToken = true
401               } else {
402                 tokens.append(token.trimmingCharacters(in: .newlines))
403                 if tokens.count > 2 {
404                   let text = tokens.joined()
405                   let count = tokens.count
406                   tokens = []
407                   DispatchQueue.main.async {
408                     var message = messages.removeLast()
409                     message.text += text
410                     message.tokenCount += count
411                     message.dateUpdated = Date()
412                     messages.append(message)
413                   }
414                 }
415                 if shouldStopGenerating {
416                   runnerHolder.runner?.stop()
417                 }
418               }
419             }
420           }
421         }
422       } catch {
423         DispatchQueue.main.async {
424           withAnimation {
425             var message = messages.removeLast()
426             message.type = .info
427             message.text = "Text generation failed: error \((error as NSError).code)"
428             messages.append(message)
429           }
430         }
431       }
432     }
433   }
434 
stopnull435   private func stop() {
436     shouldStopGenerating = true
437   }
438 
allowedContentTypesnull439   private func allowedContentTypes() -> [UTType] {
440     guard let pickerType else { return [] }
441     switch pickerType {
442     case .model:
443       return [UTType(filenameExtension: "pte")].compactMap { $0 }
444     case .tokenizer:
445       return [UTType(filenameExtension: "bin"), UTType(filenameExtension: "model")].compactMap { $0 }
446     }
447   }
448 
handleFileImportResultnull449   private func handleFileImportResult(_ pickerType: PickerType?, _ result: Result<[URL], Error>) {
450     switch result {
451     case .success(let urls):
452       guard let url = urls.first, let pickerType else {
453         withAnimation {
454           messages.append(Message(type: .info, text: "Failed to select a file"))
455         }
456         return
457       }
458       runnerQueue.async {
459         runnerHolder.runner = nil
460         runnerHolder.llavaRunner = nil
461       }
462       switch pickerType {
463       case .model:
464         resourceManager.modelPath = url.path
465       case .tokenizer:
466         resourceManager.tokenizerPath = url.path
467       }
468     case .failure(let error):
469       withAnimation {
470         messages.append(Message(type: .info, text: "Failed to select a file: \(error.localizedDescription)"))
471       }
472     }
473   }
474 }
475 
extension View {
  /// Dismisses the keyboard by sending `resignFirstResponder` through the
  /// application with no explicit target.
  func hideKeyboard() {
    let resign = #selector(UIResponder.resignFirstResponder)
    UIApplication.shared.sendAction(resign, to: nil, from: nil, for: nil)
  }
}
481