/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

import SwiftUI
import UniformTypeIdentifiers

import LLaMARunner

/// Reference-type container for the text (`Runner`) and image+text (`LLaVARunner`) runners.
/// `ContentView` owns one instance as a `@StateObject` and creates the runners lazily.
class RunnerHolder: ObservableObject {
  var runner: Runner?
  var llavaRunner: LLaVARunner?
}

extension UIImage {
  /// Returns the image redrawn into a bitmap of `newSize`, using a renderer scale of 1.
  func resized(to newSize: CGSize) -> UIImage {
    let format = UIGraphicsImageRendererFormat.default()
    format.scale = 1
    return UIGraphicsImageRenderer(size: newSize, format: format).image { _ in
      draw(in: CGRect(origin: .zero, size: newSize))
    }
  }

  /// Converts the image into a planar RGB byte array.
  ///
  /// The image is drawn into an 8-bit RGBA bitmap; the alpha channel is then dropped and
  /// the colour channels are laid out as three consecutive planes: all red bytes, then all
  /// green bytes, then all blue bytes (each plane is `width * height` bytes, row-major).
  ///
  /// - Returns: `width * height * 3` bytes, or `nil` when the image has no `cgImage` or
  ///   the bitmap context cannot be created.
  func toRGBArray() -> [UInt8]? {
    guard let cgImage = self.cgImage else { return nil }

    let width = cgImage.width
    let height = cgImage.height
    let totalPixels = width * height
    let bytesPerPixel = 4
    let bytesPerRow = bytesPerPixel * width
    var pixelData = [UInt8](repeating: 0, count: totalPixels * bytesPerPixel)

    // The bitmap context writes through the pointer it is given, so the pointer has to stay
    // valid for the whole draw. Creating the context and drawing inside
    // `withUnsafeMutableBytes` guarantees that; passing `&pixelData` to the initializer only
    // guarantees validity for the duration of the initializer call itself.
    let didDraw = pixelData.withUnsafeMutableBytes { (buffer: UnsafeMutableRawBufferPointer) -> Bool in
      guard
        let baseAddress = buffer.baseAddress,
        let context = CGContext(
          data: baseAddress, width: width, height: height, bitsPerComponent: 8,
          bytesPerRow: bytesPerRow, space: CGColorSpaceCreateDeviceRGB(),
          bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue | CGBitmapInfo.byteOrder32Big.rawValue
        )
      else { return false }
      context.draw(cgImage, in: CGRect(x: 0, y: 0, width: width, height: height))
      return true
    }
    guard didDraw else { return nil }

    // Interleaved RGBA -> planar RGB.
    var rgbValues = [UInt8](repeating: 0, count: totalPixels * 3)
    for pixel in 0..<totalPixels {
      let source = pixel * bytesPerPixel
      rgbValues[pixel] = pixelData[source]
      rgbValues[pixel + totalPixels] = pixelData[source + 1]
      rgbValues[pixel + totalPixels * 2] = pixelData[source + 2]
    }
    return rgbValues
  }
}

/// Main chat screen: model/tokenizer selection, message list, prompt input, image
/// attachment, memory indicator and log viewer.
struct ContentView: View {
  @State private var prompt = ""
  @State private var messages: [Message] = []
  @State private var showingLogs = false
  @State private var pickerType: PickerType?
  @State private var isGenerating = false
  @State private var shouldStopGenerating = false
  @State private var shouldStopShowingToken = false
  // Serial queue on which runners are created, loaded, run and released.
  private let runnerQueue = DispatchQueue(label: "org.pytorch.executorch.llama")
  @StateObject private var runnerHolder = RunnerHolder()
  @StateObject private var resourceManager = ResourceManager()
  @StateObject private var resourceMonitor = ResourceMonitor()
  @StateObject private var logManager = LogManager()

  @State private var isImagePickerPresented = false
  @State private var selectedImage: UIImage?
  @State private var imagePickerSourceType: UIImagePickerController.SourceType = .photoLibrary

  @State private var showingSettings = false

  /// Which kind of file the file importer is currently being shown for.
  enum PickerType {
    case model
    case tokenizer
  }

  /// Prompt field placeholder; points the user at the next missing resource.
  private var placeholder: String {
    resourceManager.isModelValid ? resourceManager.isTokenizerValid ? "Prompt..." : "Select Tokenizer..." : "Select Model..."
  }

  /// Navigation bar title: the model name once both resources are valid.
  private var title: String {
    resourceManager.isModelValid ? resourceManager.isTokenizerValid ? resourceManager.modelName : "Select Tokenizer..." : "Select Model..."
  }

  private var modelTitle: String {
    resourceManager.isModelValid ? resourceManager.modelName : "Select Model..."
  }

  private var tokenizerTitle: String {
    resourceManager.isTokenizerValid ? resourceManager.tokenizerName : "Select Tokenizer..."
  }

  /// The prompt can only be edited and sent when both a model and a tokenizer are valid.
  private var isInputEnabled: Bool { resourceManager.isModelValid && resourceManager.isTokenizerValid }

  var body: some View {
    NavigationView {
      VStack {
        if showingSettings {
          VStack(spacing: 20) {
            Form {
              Section(header: Text("Model and Tokenizer")
                .font(.headline)
                .foregroundColor(.primary)) {
                Button(action: { pickerType = .model }) {
                  Label(resourceManager.modelName.isEmpty ? modelTitle : resourceManager.modelName, systemImage: "doc")
                }
                Button(action: { pickerType = .tokenizer }) {
                  Label(resourceManager.tokenizerName.isEmpty ? tokenizerTitle : resourceManager.tokenizerName, systemImage: "doc")
                }
              }
            }
          }
        }

        MessageListView(messages: $messages)
          .gesture(
            // Dragging the message list downwards dismisses the keyboard.
            DragGesture().onChanged { value in
              if value.translation.height > 10 {
                hideKeyboard()
              }
            }
          )
        HStack {
          Button(action: {
            imagePickerSourceType = .photoLibrary
            isImagePickerPresented = true
          }) {
            Image(systemName: "photo.on.rectangle")
              .resizable()
              .scaledToFit()
              .frame(width: 24, height: 24)
          }
          .background(Color.clear)
          .cornerRadius(8)

          Button(action: {
            if UIImagePickerController.isSourceTypeAvailable(.camera) {
              imagePickerSourceType = .camera
              isImagePickerPresented = true
            } else {
              print("Camera not available")
            }
          }) {
            Image(systemName: "camera")
              .resizable()
              .scaledToFit()
              .frame(width: 24, height: 24)
          }
          .background(Color.clear)
          .cornerRadius(8)

          TextField(placeholder, text: $prompt, axis: .vertical)
            .padding(8)
            .background(Color.gray.opacity(0.1))
            .cornerRadius(20)
            .lineLimit(1...10)
            .overlay(
              RoundedRectangle(cornerRadius: 20)
                .stroke(isInputEnabled ? Color.blue : Color.gray, lineWidth: 1)
            )
            .disabled(!isInputEnabled)

          // The same button sends the prompt or, while generating, requests a stop.
          Button(action: isGenerating ? stop : generate) {
            Image(systemName: isGenerating ? "stop.circle" : "arrowshape.up.circle.fill")
              .resizable()
              .aspectRatio(contentMode: .fit)
              .frame(height: 28)
          }
          .disabled(isGenerating ? shouldStopGenerating : (!isInputEnabled || prompt.isEmpty))
        }
        .padding([.leading, .trailing, .bottom], 10)
        .sheet(isPresented: $isImagePickerPresented, onDismiss: addSelectedImageMessage) {
          ImagePicker(selectedImage: $selectedImage, sourceType: imagePickerSourceType)
            // Changing the identity recreates the picker when the source type changes.
            .id(imagePickerSourceType.rawValue)
        }
      }
      .navigationBarTitle(title, displayMode: .inline)
      .navigationBarItems(
        leading:
          Button(action: {
            showingSettings.toggle()
          }) {
            Image(systemName: "gearshape")
              .imageScale(.large)
          },
        trailing:
          HStack {
            Menu {
              Section(header: Text("Memory")) {
                Text("Used: \(resourceMonitor.usedMemory) Mb")
                // NOTE(review): this row is labelled "Available" but prints `usedMemory`.
                // It presumably should read an available-memory value from
                // `ResourceMonitor` - confirm the property name there and switch it over.
                Text("Available: \(resourceMonitor.usedMemory) Mb")
              }
            } label: {
              Text("\(resourceMonitor.usedMemory) Mb")
            }
            .onAppear {
              resourceMonitor.start()
            }
            .onDisappear {
              resourceMonitor.stop()
            }
            Button(action: { showingLogs = true }) {
              Image(systemName: "list.bullet.rectangle")
            }
          }
      )
      .sheet(isPresented: $showingLogs) {
        NavigationView {
          LogView(logManager: logManager)
        }
      }
      .fileImporter(
        // The importer is shown for as long as a picker type is set.
        isPresented: Binding<Bool>(
          get: { pickerType != nil },
          set: { if !$0 { pickerType = nil } }
        ),
        allowedContentTypes: allowedContentTypes(),
        allowsMultipleSelection: false
      ) { [pickerType] result in
        handleFileImportResult(pickerType, result)
      }
      .onAppear {
        do {
          try resourceManager.createDirectoriesIfNeeded()
        } catch {
          withAnimation {
            messages.append(Message(type: .info, text: "Error creating content directories: \(error.localizedDescription)"))
          }
        }
      }
    }
    .navigationViewStyle(StackNavigationViewStyle())
  }

  /// Appends the currently selected image (if any) to the conversation.
  /// Called when the image picker sheet is dismissed.
  private func addSelectedImageMessage() {
    if let selectedImage {
      messages.append(Message(image: selectedImage))
    }
  }

  /// Applies `update` to the last message in place. Does nothing when the list is empty,
  /// so callers cannot crash on an unexpected empty conversation. Call on the main queue.
  private func updateLastMessage(_ update: (inout Message) -> Void) {
    guard !messages.isEmpty else { return }
    update(&messages[messages.count - 1])
  }

  /// Schedules removal of the trailing placeholder message on the main queue.
  private func removePlaceholderMessage() {
    DispatchQueue.main.async {
      withAnimation {
        if !messages.isEmpty {
          messages.removeLast()
        }
      }
    }
  }

  /// Runs `load`, then turns the trailing placeholder into an info message reporting either
  /// the load time or the error code. On success a fresh placeholder produced by
  /// `nextPlaceholder` is appended for the generation that follows.
  ///
  /// - Returns: `true` when `load` completed without throwing.
  private func loadModel(_ load: () throws -> Void, nextPlaceholder: @escaping () -> Message) -> Bool {
    let startLoadTime = Date()
    let loadError: Error?
    do {
      try load()
      loadError = nil
    } catch {
      loadError = error
    }
    let loadTime = Date().timeIntervalSince(startLoadTime)

    DispatchQueue.main.async {
      withAnimation {
        updateLastMessage { message in
          message.type = .info
          if let loadError {
            message.text = "Model loading failed: error \((loadError as NSError).code)"
          } else {
            message.text = "Model loaded in \(String(format: "%.2f", loadTime)) s"
          }
        }
        if loadError == nil {
          messages.append(nextPlaceholder())
        }
      }
    }
    return loadError == nil
  }

  /// Sends the current prompt to the model and streams the answer into the conversation.
  ///
  /// The prompt and an empty "generated" placeholder are appended immediately. The rest
  /// runs on `runnerQueue`: create the runner if needed, load the model if needed
  /// (reporting the outcome as an info message), then generate. An image+text generation
  /// is used when an image is selected, otherwise a text-only one. Tokens are appended to
  /// the placeholder in small batches on the main queue.
  private func generate() {
    guard !prompt.isEmpty else { return }
    isGenerating = true
    shouldStopGenerating = false
    shouldStopShowingToken = false
    let text = prompt.trimmingCharacters(in: .whitespacesAndNewlines)
    let sequenceLength = 768 // text: 256, vision: 768
    let modelPath = resourceManager.modelPath
    let tokenizerPath = resourceManager.tokenizerPath
    // The runner kind is chosen from the model file path.
    let useLlama = modelPath.lowercased().contains("llama")

    prompt = ""
    hideKeyboard()
    showingSettings = false

    messages.append(Message(text: text))
    messages.append(Message(type: useLlama ? .llamagenerated : .llavagenerated))

    runnerQueue.async {
      defer {
        DispatchQueue.main.async {
          isGenerating = false
          selectedImage = nil
        }
      }

      if useLlama {
        runnerHolder.runner = runnerHolder.runner ?? Runner(modelPath: modelPath, tokenizerPath: tokenizerPath)
      } else {
        runnerHolder.llavaRunner = runnerHolder.llavaRunner ?? LLaVARunner(modelPath: modelPath, tokenizerPath: tokenizerPath)
      }

      // Stopped before anything was produced: drop the empty placeholder, same as the
      // stop check after loading below.
      guard !shouldStopGenerating else {
        removePlaceholderMessage()
        return
      }

      if useLlama {
        if let runner = runnerHolder.runner, !runner.isLoaded() {
          let didLoad = loadModel({ try runner.load() }, nextPlaceholder: { Message(type: .llamagenerated) })
          guard didLoad else { return }
        }
      } else {
        if let runner = runnerHolder.llavaRunner, !runner.isLoaded() {
          let didLoad = loadModel({ try runner.load() }, nextPlaceholder: { Message(type: .llavagenerated) })
          guard didLoad else { return }
        }
      }

      guard !shouldStopGenerating else {
        removePlaceholderMessage()
        return
      }

      do {
        // Tokens are batched so the UI is not updated once per token.
        var pendingTokens: [String] = []

        // Moves all buffered tokens into the trailing message on the main queue.
        func flushPendingTokens() {
          guard !pendingTokens.isEmpty else { return }
          let text = pendingTokens.joined()
          let count = pendingTokens.count
          pendingTokens = []
          DispatchQueue.main.async {
            updateLastMessage { message in
              message.text += text
              message.tokenCount += count
              message.dateUpdated = Date()
            }
          }
        }

        // Replaces the placeholder with an info message when the image cannot be used.
        func reportImageFailure() {
          DispatchQueue.main.async {
            withAnimation {
              updateLastMessage { message in
                message.type = .info
                message.text = "Image processing failed"
              }
            }
          }
        }

        if let image = selectedImage {
          let llavaPrompt = "\(text) ASSISTANT"
          let maxWidth = 336.0

          guard image.size.width > 0 else {
            reportImageFailure()
            return
          }
          // Keep the aspect ratio, but use a whole number of rows so the bitmap produced by
          // `resized(to:)` has exactly the height that is passed to the runner.
          let scaledHeight: Double = maxWidth * image.size.height / image.size.width
          let targetHeight = max(1, scaledHeight.rounded())

          let resizedImage = image.resized(to: CGSize(width: maxWidth, height: targetHeight))
          guard var rgbArray = resizedImage.toRGBArray() else {
            reportImageFailure()
            return
          }

          // The pixel buffer pointer is only valid inside `withUnsafeMutableBytes`, so the
          // whole (blocking) generate call happens within the closure.
          try rgbArray.withUnsafeMutableBytes { (buffer: UnsafeMutableRawBufferPointer) throws -> Void in
            guard let imageBuffer = buffer.baseAddress else { return }
            try runnerHolder.llavaRunner?.generate(imageBuffer, width: maxWidth, height: targetHeight, prompt: llavaPrompt, sequenceLength: sequenceLength) { token in
              // The prompt itself can be reported as a token; it is not part of the answer.
              guard token != llavaPrompt else { return }
              if token == "</s>" {
                shouldStopGenerating = true
                runnerHolder.llavaRunner?.stop()
              } else {
                pendingTokens.append(token)
                if pendingTokens.count > 2 {
                  flushPendingTokens()
                }
                if shouldStopGenerating {
                  runnerHolder.llavaRunner?.stop()
                }
              }
            }
          }
        } else {
          let llama3Prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\(text)<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
          // Set once <|eot_id|> has been seen; later tokens are not shown.
          var didReachEndOfTurn = false

          try runnerHolder.runner?.generate(llama3Prompt, sequenceLength: sequenceLength) { token in
            // The token is model output: pass it as an argument, never as the format string.
            NSLog("%@", ">>> token={\(token)}")
            guard token != llama3Prompt else { return }
            // hack to fix the issue that extension/llm/runner/text_token_generator.h
            // keeps generating after <|eot_id|>
            if token == "<|eot_id|>" {
              shouldStopShowingToken = true
              didReachEndOfTurn = true
            } else {
              if !didReachEndOfTurn {
                pendingTokens.append(token.trimmingCharacters(in: .newlines))
                if pendingTokens.count > 2 {
                  flushPendingTokens()
                }
              }
              if shouldStopGenerating {
                runnerHolder.runner?.stop()
              }
            }
          }
        }

        // Show whatever is still buffered (fewer than three tokens) at the end.
        flushPendingTokens()
      } catch {
        DispatchQueue.main.async {
          withAnimation {
            updateLastMessage { message in
              message.type = .info
              message.text = "Text generation failed: error \((error as NSError).code)"
            }
          }
        }
      }
    }
  }

  /// Requests that the running generation stops; the token callbacks poll this flag.
  private func stop() {
    shouldStopGenerating = true
  }

  /// File types offered by the file importer for the active picker:
  /// `.pte` for models, `.bin` / `.model` for tokenizers. Empty when no picker is active.
  private func allowedContentTypes() -> [UTType] {
    guard let pickerType else { return [] }
    let fileExtensions: [String]
    switch pickerType {
    case .model:
      fileExtensions = ["pte"]
    case .tokenizer:
      fileExtensions = ["bin", "model"]
    }
    return fileExtensions.compactMap { UTType(filenameExtension: $0) }
  }

  /// Stores the picked file path as the model or tokenizer path and discards the existing
  /// runners so they are recreated from the new files. Failures are reported as info
  /// messages in the conversation.
  private func handleFileImportResult(_ pickerType: PickerType?, _ result: Result<[URL], Error>) {
    switch result {
    case .success(let urls):
      guard let url = urls.first, let pickerType else {
        withAnimation {
          messages.append(Message(type: .info, text: "Failed to select a file"))
        }
        return
      }
      // Runners are released on the runner queue, i.e. after any work already queued there.
      runnerQueue.async {
        runnerHolder.runner = nil
        runnerHolder.llavaRunner = nil
      }
      switch pickerType {
      case .model:
        resourceManager.modelPath = url.path
      case .tokenizer:
        resourceManager.tokenizerPath = url.path
      }
    case .failure(let error):
      withAnimation {
        messages.append(Message(type: .info, text: "Failed to select a file: \(error.localizedDescription)"))
      }
    }
  }
}

extension View {
  /// Dismisses the keyboard by asking the current first responder to resign.
  func hideKeyboard() {
    UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
  }
}