# Summary
This example demonstrates how to run [Llama models](https://www.llama.com/) on mobile via ExecuTorch. We use XNNPACK to accelerate performance and 4-bit groupwise quantization to fit the model on a phone.

The following models are supported:

- Llama 3.2 1B and 3B
- Llama 3.2 Quantized 1B and 3B
- Llama 3.1 8B
- Llama 3 8B
- [Llama 2 7B](../llama2/README.md)

Pretrained models are not included in this repo. You can download them [here](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).

This page contains the basic recipe for running Llama. See the [Llama utils page](./UTILS.md) for more advanced use cases, such as fine-tuning and running smaller models for educational purposes.

# What is Llama?
Llama is a collection of large language models trained on publicly available data. The models are decoder-only transformers: they process an input sequence of up to their maximum context length and generate output one token at a time. A key feature of Llama models is their ability to generate coherent and contextually relevant text. This is achieved through attention mechanisms, which let the model focus on different parts of the input sequence as it generates output. Llama models are pre-trained on a large corpus of text with a causal (autoregressive) language modeling objective, which teaches them to predict the next token from all of the preceding tokens.
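
As a rough illustration of the attention mechanism described above, here is a minimal pure-Python sketch of single-head scaled dot-product attention with a causal mask, the variant used by decoder-only models such as Llama. It is a teaching aid under simplifying assumptions (one head, no learned projections, no KV cache), not ExecuTorch's implementation; the function name `causal_attention` is hypothetical.

```python
import math


def causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v: lists of vectors, one per token position, all of dimension d.
    Position i may only attend to positions 0..i, which is what allows a
    decoder-only model to generate text one token at a time.
    """
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # Scores against the current and all earlier positions only.
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in range(i + 1)]
        # Numerically stable softmax over the allowed positions.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Output is the attention-weighted sum of the value vectors.
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))])
    return out


if __name__ == "__main__":
    q = k = [[1.0, 0.0], [0.0, 1.0]]
    v = [[1.0, 2.0], [3.0, 4.0]]
    print(causal_attention(q, k, v))
```

Because of the mask, changing a later token's value vector never changes the outputs for earlier positions.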

Llama models have been shown to perform well on a variety of natural language processing tasks, including language translation, question answering, and text summarization. They can also generate human-like text, which makes them useful for creative writing and other applications where natural language generation is important.

Overall, Llama models are powerful and versatile language models that can be used for a wide range of natural language processing tasks. Their ability to generate coherent and contextually relevant text makes them particularly useful for applications such as chatbots, virtual assistants, and language translation.

Please note that the models are subject to the [Llama 2 Acceptable Use Policy](https://github.com/facebookresearch/llama/blob/main/USE_POLICY.md), [Llama 3 Acceptable Use Policy](https://github.com/meta-llama/llama3/blob/main/USE_POLICY.md), and [Responsible Use Guide](https://ai.meta.com/static-resource/responsible-use-guide/).


# Results

## Llama 3.2 1B/3B and quantized 1B/3B models

For Llama 3.2 1B/3B models, we have enabled the original BF16 format as well as 4-bit quantization using SpinQuant and QAT+LoRA for enhanced performance.

The quantized models were optimized primarily for the Arm CPU architecture by leveraging XNNPACK and the KleidiAI library. Work is underway to specifically enable quantization on mobile accelerators for Llama 1B/3B.

### Enablement

We have successfully verified performance on the following devices: iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+ and S22, and OnePlus 12 (featuring 16GB RAM).

Note that the Llama 3.2 3B unquantized BF16 model was only tested on the OnePlus 12, which has sufficient memory (16GB RAM) to support its size requirements.

### Quantization

The 1B/3B models are sensitive to accuracy loss when regular post-training quantization (PTQ) is applied. To balance accuracy, performance, and memory, we use 4-bit quantization with the [SpinQuant](https://github.com/facebookresearch/SpinQuant/tree/main) and QAT+LoRA methods.

Our quantization scheme has three parts, which apply to both methods:

- All linear layers in all transformer blocks are quantized to a 4-bit groupwise scheme (group size 32) for weights, with 8-bit per-token dynamic quantization for activations.
- The classification layer is quantized to 8-bit per-channel for weights, with 8-bit per-token dynamic quantization for activations.
- The embedding is quantized to 8-bit per-channel.
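
To make the weight part of this scheme concrete, the sketch below quantizes one weight row to signed 4-bit integers with one scale per group of 32 values. It assumes symmetric quantization (no zero point) on plain Python lists; the actual schemes are defined with torchao and executed by XNNPACK kernels, and the function names here are hypothetical.

```python
def quantize_groupwise_int4(row, group_size=32):
    """Symmetric 4-bit groupwise quantization of one weight row.

    Every contiguous group of `group_size` values gets its own scale, so a
    single large value only reduces precision within its own group.
    Returns (quantized_ints, per_group_scales).
    """
    if len(row) % group_size != 0:
        raise ValueError("row length must be a multiple of group_size")
    qmin, qmax = -8, 7  # signed 4-bit range
    q, scales = [], []
    for start in range(0, len(row), group_size):
        group = row[start:start + group_size]
        max_abs = max(abs(x) for x in group)
        scale = max_abs / qmax if max_abs > 0 else 1.0
        scales.append(scale)
        q.extend(max(qmin, min(qmax, round(x / scale))) for x in group)
    return q, scales


def dequantize_groupwise(q, scales, group_size=32):
    """Recover approximate float weights from 4-bit ints and per-group scales."""
    return [v * scales[i // group_size] for i, v in enumerate(q)]


if __name__ == "__main__":
    # 64 weights = 2 groups; the second group contains one outlier (5.0).
    row = [0.01 * i for i in range(32)] + [0.01] * 31 + [5.0]
    q, scales = quantize_groupwise_int4(row)
    print(scales)  # only the second group's scale is enlarged by the outlier
    print(q[32:40])
```

In this example, the 31 small values that share a group with the outlier all round to zero, while the first group keeps fine resolution. Limiting that damage to one group is the motivation for a small group size.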

We use [torchao](https://github.com/pytorch/ao) library APIs to define these schemes.

#### SpinQuant

The SpinQuant method takes the original weights and produces optimized quantized weights with minimal outliers, resulting in higher accuracy. This requires no fine-tuning of the weights and only 100 iterations on a single A100 node.

SpinQuant can generate quantized weights that are [compatible with ExecuTorch](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch). Specifically, they can be integrated with the existing optimized XNNPACK kernels (e.g., groupwise 4-bit weights and 8-bit dynamic activations). This lets developers benefit from the higher accuracy of SpinQuant while also taking advantage of the strong performance of ExecuTorch acceleration.
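
The core idea behind rotation-based methods like SpinQuant can be checked numerically: for an orthogonal matrix R, replacing W and x with W R and R^T x leaves the layer output W x unchanged, while the rotation spreads a large outlier across channels. SpinQuant learns its rotations; the sketch below instead uses a fixed, normalized 4x4 Hadamard matrix as a stand-in (an assumption for illustration only), and all names are hypothetical.

```python
# Normalized 4x4 Hadamard matrix: orthogonal (H @ H^T = I) and symmetric.
# Stand-in for a learned SpinQuant rotation.
H = [[0.5 * s for s in row] for row in [
    [1, 1, 1, 1],
    [1, -1, 1, -1],
    [1, 1, -1, -1],
    [1, -1, -1, 1],
]]


def matvec(m, x):
    return [sum(a * b for a, b in zip(row, x)) for row in m]


def matmul(a, b):
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m):
    return [list(r) for r in zip(*m)]


def rotate_linear(w, x, r):
    """Return (W R, R^T x). Because R is orthogonal, (W R)(R^T x) == W x."""
    return matmul(w, r), matvec(transpose(r), x)


if __name__ == "__main__":
    W = [[0.5, -0.2, 0.1, 0.3], [0.05, 0.4, -0.3, 0.2]]
    x = [8.0, 0.1, 0.2, 0.1]  # one large outlier channel
    W_rot, x_rot = rotate_linear(W, x, H)
    print(matvec(W, x), matvec(W_rot, x_rot))  # same output (up to rounding)
    print(max(map(abs, x)), max(map(abs, x_rot)))  # outlier magnitude shrinks
```

Here the largest activation magnitude drops from 8.0 to about 4.2 while the layer output is preserved, which makes the rotated tensors easier to quantize with a shared scale.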

#### Quantization-Aware Training and LoRA (QAT+LoRA)

Quantization-Aware Training (QAT) simulates the effects of quantization during training of the Llama 3.2 models, which optimizes their performance in low-precision environments. QAT is initialized from the BF16 Llama 3.2 checkpoints obtained after supervised fine-tuning (SFT), and an additional full round of SFT training with QAT is performed. The backbone of the QAT model is then frozen, and another round of SFT is performed with low-rank adaptation (LoRA) adaptors applied to all layers within the transformer block. The LoRA adaptors' weights and activations are kept in BF16.
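
A minimal sketch of the forward pass of one linear layer in this setup, on plain Python lists: the frozen backbone weight goes through fake quantization (quantize, then dequantize) so the forward pass sees 4-bit rounding error, and a trainable low-rank update B(Ax) is added on top in higher precision. The per-row scale, the `alpha` scaling factor, and all function names are illustrative assumptions; the actual training recipe, including how gradients pass through the rounding (typically a straight-through estimator), is not shown.

```python
def fake_quantize(values, bits=4):
    """Symmetric quantize-then-dequantize with one scale for the whole list.

    The forward pass sees these rounded values, so training learns to
    tolerate quantization error.
    """
    qmax = 2 ** (bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]


def lora_linear(x, w_frozen, a, b, alpha=1.0):
    """y = fake_quant(W) x + alpha * B (A x).

    w_frozen: d_out x d_in backbone weight (frozen, fake-quantized per row).
    a: r x d_in and b: d_out x r are the trainable low-rank adaptor matrices,
    kept unquantized (BF16 in the README; plain floats here).
    """
    w_q = [fake_quantize(row) for row in w_frozen]
    base = [sum(wi * xi for wi, xi in zip(row, x)) for row in w_q]
    ax = [sum(ai * xi for ai, xi in zip(row, x)) for row in a]
    bax = [sum(bi * t for bi, t in zip(row, ax)) for row in b]
    return [u + alpha * v for u, v in zip(base, bax)]


if __name__ == "__main__":
    w = [[0.3, -0.7, 0.2], [0.1, 0.05, -0.4]]  # frozen backbone, 2 x 3
    a = [[0.1, 0.0, 0.2]]                       # rank-1 adaptor, 1 x 3
    b = [[1.0], [2.0]]                          # 2 x 1
    print(lora_linear([1.0, 2.0, -1.0], w, a, b))
```

Only `a` and `b` would be updated in the LoRA round; the quantized backbone stays fixed.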

### Accuracy

Please see the [Llama 3.2 model card](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/MODEL_CARD.md) for accuracy evaluations.

### Performance

Llama 3.2 1B and 3B performance was measured on an Android OnePlus 12 device. Performance is expressed in tokens per second, measured with the [adb binary-based approach](#step-4-run-benchmark-on-android-phone) using a prompt length of 64. The measurements were taken with the KleidiAI library, which is not yet enabled by default; use `-DEXECUTORCH_XNNPACK_ENABLE_KLEIDI=ON` to enable it in the build.

|Model | Decode (tokens/s) | Time-to-first-token (sec) | Prefill (tokens/s) | Model size (PTE file size in MiB) | Memory size (RSS in MiB) |
|-------|------------------:|--------------------------:| ------------------:|----------------------------------:| ------------------------:|
|1B BF16 (baseline) | 19.2 | 1.0 | 60.3 | 2,358 | 3,185 |
|1B SpinQuant | 50.2 (2.6x) | 0.3 (-76.9%) | 260.5 (4.3x) | 1,083 (-54.1%) | 1,921 (-39.7%) |
|1B QAT+LoRA | 45.8 (2.4x) | 0.3 (-76.0%) | 252.0 (4.2x) | 1,127 (-52.2%) | 2,255 (-29.2%) |
|3B BF16 (baseline) | 7.6 | 3.0 | 21.2 | 6,129 | 7,419 |
|3B SpinQuant | 19.7 (2.6x) | 0.7 (-76.4%) | 89.7 (4.2x) | 2,435 (-60.3%) | 3,726 (-49.8%) |
|3B QAT+LoRA | 18.5 (2.4x) | 0.7 (-76.1%) | 88.8 (4.2x) | 2,529 (-58.7%) | 4,060 (-45.3%) |
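
To read the parenthesized figures: throughput columns show a ratio against the BF16 baseline, and size/memory columns show a relative percentage change. The short check below reproduces the 1B SpinQuant row from the table's own numbers; the helper names are hypothetical.

```python
# 1B BF16 baseline vs. 1B SpinQuant, copied from the table above.
baseline = {"decode": 19.2, "prefill": 60.3, "pte_mib": 2358, "rss_mib": 3185}
spinquant = {"decode": 50.2, "prefill": 260.5, "pte_mib": 1083, "rss_mib": 1921}


def speedup(new, old):
    """Throughput ratio (higher is better), e.g. 2.6x."""
    return new / old


def pct_change(new, old):
    """Relative change in percent; negative values are reductions."""
    return (new - old) / old * 100


print(f"decode  {speedup(spinquant['decode'], baseline['decode']):.1f}x")
print(f"prefill {speedup(spinquant['prefill'], baseline['prefill']):.1f}x")
print(f"PTE     {pct_change(spinquant['pte_mib'], baseline['pte_mib']):.1f}%")
print(f"RSS     {pct_change(spinquant['rss_mib'], baseline['rss_mib']):.1f}%")
```

This prints 2.6x, 4.3x, -54.1%, and -39.7%, matching the table. The time-to-first-token percentages cannot be reproduced this way, since they appear to be computed from unrounded measurements rather than the rounded seconds shown.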


<table>
  <tr>
    <td>
      <img src="./Android3_2_1B_bf16.gif" width="300">
      <br>
      <em>Llama 3.2 1B, unquantized, BF16 on Android phone.</em>
    </td>
    <td>
      <img src="./Android3_2_3B_SpinQuant.gif" width="300">
      <br>
      <em>
        Llama 3.2 3B, 4-bit quantized (SpinQuant) on Android phone.
      </em>
    </td>
  </tr>
</table>

## Llama 3/3.1 8B
Since the Llama 3 8B model needs at least 4-bit quantization to fit even on some high-end phones, the results presented here correspond to a 4-bit groupwise post-training quantized (PTQ) model.

### Enablement

For Llama 3 8B and Llama 3.1 8B, we have so far verified 4-bit quantized models on iPhone 15 Pro, iPhone 15 Pro Max, Samsung Galaxy S24+, and OnePlus 12 (with 16GB RAM).

### Quantization

We applied PTQ to all linear layers of the model, using 4-bit groupwise weights and 8-bit per-token dynamic quantization for activations. Dynamic quantization means that activations are quantized at runtime: their quantization parameters are computed from the observed min/max range. Here, activations are quantized to 8-bit signed integers. Weights, in contrast, are statically quantized; in our case, they were quantized per channel and groupwise to 4-bit signed integers. Because of Llama 3's large vocabulary, we also had to quantize the embedding lookup table. For these results, the embedding lookup table was quantized groupwise to 4 bits with a group size of 32.
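
A minimal sketch of the activation side described above: each token's activation row gets its own scale and zero point, computed at runtime from that row's min/max and mapped onto the signed 8-bit range [-128, 127]. It assumes asymmetric quantization with the range widened to include 0.0 (a common convention, not confirmed by this README); the function names are hypothetical.

```python
def quantize_per_token_int8(activations):
    """Dynamic per-token asymmetric int8 quantization.

    activations: list of rows, one row per token. The scale and zero point
    of each row are computed at runtime from that row's own min/max range.
    Returns a list of (int_row, scale, zero_point) tuples.
    """
    qmin, qmax = -128, 127
    result = []
    for row in activations:
        # Widen the range to include 0.0 so that zero is exactly representable.
        lo, hi = min(min(row), 0.0), max(max(row), 0.0)
        scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
        zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
        q = [max(qmin, min(qmax, round(x / scale) + zero_point)) for x in row]
        result.append((q, scale, zero_point))
    return result


def dequantize_row(q, scale, zero_point):
    """Map int8 values back to floats: (q - zero_point) * scale."""
    return [(v - zero_point) * scale for v in q]


if __name__ == "__main__":
    acts = [[0.1, -0.5, 2.0, 0.3], [10.0, -3.0, 0.0, 4.5]]  # two tokens
    for q, scale, zp in quantize_per_token_int8(acts):
        print(q, round(scale, 5), zp)
```

Because each token row is scaled independently, a token with unusually large activations does not reduce the precision of the other tokens in the batch.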

We use [torchao](https://github.com/pytorch/ao) library APIs to define these schemes.

### Accuracy

We evaluated WikiText perplexity using [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness). Below are the results for two different group sizes, with `max_seq_length` 2048 and `limit` 1000.

|Model | Baseline (FP32) | Groupwise 4-bit (128) | Groupwise 4-bit (256) |
|--------|-----------------| ---------------------- | --------------- |
|Llama 3 8B | 7.9 | 9.4 | 9.7 |

Please note that LM Eval reports perplexity normalized by word count instead of token count. You may see different WikiText perplexity numbers from other sources if they compute it differently. More details can be found [here](https://github.com/EleutherAI/lm-evaluation-harness/issues/2301).
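
Why the normalization matters: perplexity is the exponential of the average negative log-likelihood, and that average can be taken per word or per token. For the same total loss, dividing by the smaller word count gives a larger number. The counts below are made up purely for illustration.

```python
import math


def perplexity(total_nll, count):
    """exp(average negative log-likelihood per unit), with NLL in nats."""
    return math.exp(total_nll / count)


# Hypothetical numbers: summed NLL over a text of 1,000 words that the
# tokenizer splits into 1,300 tokens.
total_nll = 2000.0
word_ppl = perplexity(total_nll, 1000)   # normalized by word count (LM Eval style)
token_ppl = perplexity(total_nll, 1300)  # normalized by token count
print(f"word-level: {word_ppl:.2f}, token-level: {token_ppl:.2f}")
```

The two numbers describe the same model and the same loss, so they should only be compared with results that use the same normalization.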

### Performance

Llama 3 8B performance was measured on the Samsung Galaxy S22, S24, and OnePlus 12 devices. Performance is expressed in tokens per second, measured with the [adb binary-based approach](#step-4-run-benchmark-on-android-phone).

|Device | Groupwise 4-bit (128) | Groupwise 4-bit (256) |
|--------| ---------------------- | --------------- |
|Galaxy S22 | 7.85 tokens/second | 8.4 tokens/second |
|Galaxy S24 | 10.91 tokens/second | 11.21 tokens/second |
|OnePlus 12 | 10.85 tokens/second | 11.02 tokens/second |

<p align="center">
  <br>
  <img src="./llama_via_xnnpack.gif" width=300>
  <br>
  <em>
    Llama 3.1 8B, 4-bit quantized on Android phone
  </em>
</p>

To try Llama on non-CPU backends, including CoreML, MPS, Qualcomm HTP, or MediaTek, please visit [this section](non_cpu_backends.md).

# Instructions

## Tested on

- macOS (M1/M2) and Linux.
- For Llama 3 8B, your device may require at least 32GB RAM. If this is a constraint for you, please try the [smaller stories model](./UTILS.md).

## Step 1: Setup
> :warning: **Double-check your Python environment**: make sure `conda activate <VENV>` is run before all the bash and Python scripts.

1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. For installation, run `./install_requirements.sh --pybind xnnpack`.
2. Run `examples/models/llama/install_requirements.sh` to install a few dependencies.


## Step 2: Prepare model

### Option A: Download and export Llama 3.2 1B/3B model.

1. Download `consolidated.00.pth`, `params.json` and `tokenizer.model` from the [Llama website](https://www.llama.com/llama-downloads/) or [Hugging Face](https://huggingface.co/meta-llama/Llama-3.2-1B). For chat use cases, download the instruct models.

2. Export the model and generate a `.pte` file.

- Use the **original BF16** version, without any quantization.
```
# No quantization
# Set these paths to point to the downloaded files
LLAMA_CHECKPOINT=path/to/checkpoint.pth
LLAMA_PARAMS=path/to/params.json

python -m examples.models.llama.export_llama \
  --checkpoint "${LLAMA_CHECKPOINT:?}" \
  --params "${LLAMA_PARAMS:?}" \
  -kv \
  --use_sdpa_with_kv_cache \
  -X \
  -d bf16 \
  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
  --output_name="llama3_2.pte"
```

- To use **SpinQuant**, there are two options:
  - Download the model directly from the [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to a `.pte` file directly.
  - Follow its [instructions](https://github.com/facebookresearch/SpinQuant/tree/main?tab=readme-ov-file#3-export-to-executorch) to export the checkpoint for ExecuTorch, and then export the SpinQuant checkpoint.

```
# SpinQuant
# Set these paths to point to the exported files
LLAMA_QUANTIZED_CHECKPOINT=path/to/spinquant/checkpoint.pth
LLAMA_PARAMS=path/to/spinquant/params.json

python -m examples.models.llama.export_llama \
  --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
  --params "${LLAMA_PARAMS:?}" \
  --use_sdpa_with_kv_cache \
  -X \
  --xnnpack-extended-ops \
  --preq_mode 8da4w_output_8da8w \
  --preq_group_size 32 \
  --max_seq_length 2048 \
  --output_name "llama3_2.pte" \
  -kv \
  -d fp32 \
  --preq_embedding_quantize 8,0 \
  --use_spin_quant native \
  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```

- To use **QAT+LoRA**, download the model directly from the [Llama website](https://www.llama.com/llama-downloads). The model weights are prequantized and can be exported to a `.pte` file directly with:

```
# QAT+LoRA
# Set these paths to point to the downloaded files
LLAMA_QUANTIZED_CHECKPOINT=path/to/qlora/checkpoint.pth
LLAMA_PARAMS=path/to/qlora/params.json

python -m examples.models.llama.export_llama \
  --checkpoint "${LLAMA_QUANTIZED_CHECKPOINT:?}" \
  --params "${LLAMA_PARAMS:?}" \
  -qat \
  -lora 16 \
  --preq_mode 8da4w_output_8da8w \
  --preq_group_size 32 \
  --preq_embedding_quantize 8,0 \
  --use_sdpa_with_kv_cache \
  -kv \
  -X \
  --xnnpack-extended-ops \
  -d fp32 \
  --max_seq_length 2048 \
  --output_name "llama3_2.pte" \
  --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
```
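
All three export commands pass the same `--metadata` JSON, which (judging by its key names) tells the runner which token IDs mark the beginning and end of a sequence. If you need to adapt it, generating the string programmatically avoids JSON and shell-quoting mistakes. This sketch assumes the standard Llama 3 special-token IDs listed in its comments.

```python
import json
import shlex

# Llama 3 tokenizer special-token IDs used in the export commands above:
#   128000 = <|begin_of_text|>  (BOS)
#   128009 = <|eot_id|>         (end of turn; emitted by instruct models)
#   128001 = <|end_of_text|>    (EOS)
metadata = {"get_bos_id": 128000, "get_eos_ids": [128009, 128001]}

# json.dumps always produces valid JSON, and shlex.quote turns it into a
# single shell-safe argument that can be pasted after --metadata.
arg = shlex.quote(json.dumps(metadata))
print(f"--metadata {arg}")
```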

### Option B: Download and export Llama 3 8B instruct model

You can export and run the original Llama 3 8B instruct model.

1. Llama 3 pretrained parameters can be downloaded from [Meta's official Llama 3 repository](https://github.com/meta-llama/llama3/).

2. Export the model and generate a `.pte` file:
   ```
   python -m examples.models.llama.export_llama \
     --checkpoint <consolidated.00.pth> \
     -p <params.json> \
     -kv \
     --use_sdpa_with_kv_cache \
     -X \
     -qmode 8da4w \
     --group_size 128 \
     -d fp32 \
     --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
     --embedding-quantize 4,32 \
     --output_name="llama3_kv_sdpa_xnn_qe_4_32.pte"
   ```
   Due to the larger vocabulary size of Llama 3, we recommend quantizing the embeddings with `--embedding-quantize 4,32` as shown above to further reduce the model size.

   If you're interested in deploying on non-CPU backends, please refer to the [non-CPU backend section](non_cpu_backends.md).
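
A back-of-the-envelope estimate of why the embedding table matters for Llama 3 8B. The shapes below (vocabulary size 128,256, model dimension 4,096) are assumed from the published Llama 3 8B configuration, so check them against your `params.json`; the storage cost of the per-group scales (one 2-byte scale per group, no zero points) is also an assumption.

```python
# Assumed Llama 3 8B embedding shape (verify against params.json).
VOCAB_SIZE = 128_256
DIM = 4_096
MIB = 1024 * 1024


def embedding_mib(bits, group_size=None, scale_bytes=2):
    """Approximate embedding-table size in MiB.

    bits: bits per weight (32 for FP32, 4 for 4-bit).
    group_size: if set, add one scale of `scale_bytes` bytes per group
    (assumption: 2-byte scales, zero points ignored).
    """
    n = VOCAB_SIZE * DIM
    total_bytes = n * bits / 8
    if group_size:
        total_bytes += (n // group_size) * scale_bytes
    return total_bytes / MIB


print(f"FP32:                 {embedding_mib(32):,.1f} MiB")
print(f"4-bit, group size 32: {embedding_mib(4, group_size=32):,.1f} MiB")
```

Under these assumptions, the table shrinks from about 2,004 MiB in FP32 to roughly 282 MiB with `--embedding-quantize 4,32`.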

## Step 3: Run on your computer to validate

1. Build ExecuTorch with optimized CPU performance as follows. Build options are available [here](https://github.com/pytorch/executorch/blob/main/CMakeLists.txt#L59).
   ```
   cmake -DPYTHON_EXECUTABLE=python \
       -DCMAKE_INSTALL_PREFIX=cmake-out \
       -DEXECUTORCH_ENABLE_LOGGING=1 \
       -DCMAKE_BUILD_TYPE=Release \
       -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
       -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
       -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
       -DEXECUTORCH_BUILD_XNNPACK=ON \
       -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
       -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
       -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
       -Bcmake-out .

   cmake --build cmake-out -j16 --target install --config Release
   ```
   Note for Mac users: there is a known linking issue with Xcode 15.1. Refer to the Common Issues and Mitigations section below for solutions.

2. Build the llama runner.
   ```
   cmake -DPYTHON_EXECUTABLE=python \
       -DCMAKE_INSTALL_PREFIX=cmake-out \
       -DCMAKE_BUILD_TYPE=Release \
       -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
       -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
       -DEXECUTORCH_BUILD_XNNPACK=ON \
       -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
       -Bcmake-out/examples/models/llama \
       examples/models/llama

   cmake --build cmake-out/examples/models/llama -j16 --config Release
   ```

3. Run the model. Run options are available [here](https://github.com/pytorch/executorch/blob/main/examples/models/llama/main.cpp#L18-L40).
   ```
   cmake-out/examples/models/llama/llama_main --model_path=<model pte file> --tokenizer_path=<tokenizer.model> --prompt=<prompt>
   ```

To build for the CoreML backend and validate on Mac, replace `-DEXECUTORCH_BUILD_XNNPACK=ON` with `-DEXECUTORCH_BUILD_COREML=ON`.

## Step 4: Run benchmark on Android phone

**1. Build llama runner binary for Android**

*Prerequisite*: Android NDK (tested with r27b), which can be downloaded from [here](https://developer.android.com/ndk/downloads). On macOS, the downloaded package can be unpacked to locate the NDK folder inside it.

**1.1 Set Android NDK**
```
export ANDROID_NDK=<path-to-android-ndk>
```
**1.2 Build ExecuTorch and associated libraries for Android**
```
cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
    -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
    -DEXECUTORCH_ENABLE_LOGGING=1 \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android .

cmake --build cmake-out-android -j16 --target install --config Release
```

**1.3 Build llama runner for Android**
```
cmake -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \
    -DANDROID_ABI=arm64-v8a \
    -DANDROID_PLATFORM=android-23 \
    -DCMAKE_INSTALL_PREFIX=cmake-out-android \
    -DCMAKE_BUILD_TYPE=Release \
    -DPYTHON_EXECUTABLE=python \
    -DEXECUTORCH_BUILD_XNNPACK=ON \
    -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
    -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
    -Bcmake-out-android/examples/models/llama \
    examples/models/llama

cmake --build cmake-out-android/examples/models/llama -j16 --config Release
```

**2. Run on Android via adb shell**

*Prerequisite*: Make sure you enable USB debugging via developer options on your phone.

**2.1 Connect your Android phone**

**2.2 Upload model, tokenizer and llama runner binary to phone**
```
adb shell mkdir -p /data/local/tmp/llama
adb push <model.pte> /data/local/tmp/llama/
adb push <tokenizer.model> /data/local/tmp/llama/
adb push cmake-out-android/examples/models/llama/llama_main /data/local/tmp/llama/
```

**2.3 Run model**
```
adb shell "cd /data/local/tmp/llama && ./llama_main --model_path <model.pte> --tokenizer_path <tokenizer.model> --prompt \"What is the capital of France?\" --seq_len 120 --warmup=1"
```
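
If you script this benchmark, the main pitfall is quoting: everything meant for `llama_main`, including `--warmup=1`, must stay inside the single string passed to `adb shell`. A minimal Python sketch follows; the helper name, its defaults, and the example file names are hypothetical, while the flags are the ones used in the command above.

```python
import shlex

DEVICE_DIR = "/data/local/tmp/llama"  # directory used in step 2.2


def build_adb_run_command(model, tokenizer, prompt, seq_len=120, warmup=True):
    """Return an argv list that runs llama_main on the phone through adb.

    `adb shell` joins its arguments into one command string that the device's
    shell parses again, so every runner argument (including the prompt and
    --warmup=1) is quoted with shlex.quote and kept inside that single string.
    """
    runner_args = [
        "./llama_main",
        "--model_path", model,
        "--tokenizer_path", tokenizer,
        "--prompt", prompt,
        "--seq_len", str(seq_len),
    ]
    if warmup:
        runner_args.append("--warmup=1")
    remote = f"cd {shlex.quote(DEVICE_DIR)} && " + " ".join(shlex.quote(a) for a in runner_args)
    return ["adb", "shell", remote]


if __name__ == "__main__":
    cmd = build_adb_run_command("llama3_2.pte", "tokenizer.model", "What is the capital of France?")
    print(cmd[2])
    # To actually run it (requires adb and a connected device):
    # import subprocess; subprocess.run(cmd, check=True)
```

Passing the argv list to `subprocess.run` avoids a second layer of quoting on the host side.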

## Step 5: Build mobile apps

### iOS

Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-demo-ios.html) for full instructions on building the iOS Llama demo app. Rename the `tokenizer.model` file to `tokenizer.bin`, because the demo app looks for a tokenizer file with the `.bin` extension.

### Android
Please refer to [this tutorial](https://pytorch.org/executorch/main/llm/llama-demo-android.html) for full instructions on building the Android Llama demo app.


## Utility tools for Llama enablement

### Evaluate model accuracy

> Forewarning: Model evaluation without a GPU may take a long time, especially on larger models.

We use [LM Eval](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate model accuracy.

For base models, use the following example command to calculate perplexity on WikiText.
```
python -m examples.models.llama.eval_llama \
  -c <checkpoint.pth> \
  -p <params.json> \
  -t <tokenizer.model/bin> \
  -kv \
  -d <checkpoint dtype> \
  --max_seq_len <max sequence length> \
  --limit <number of samples>
```

For instruct models, use the following example command to calculate the MMLU score.
```
python -m examples.models.llama.eval_llama \
  -c <checkpoint.pth> \
  -p <params.json> \
  -t <tokenizer.model/bin> \
  -kv \
  -d <checkpoint dtype> \
  --tasks mmlu \
  --num_fewshot 5 \
  --max_seq_len <max sequence length>
```

See the [Llama utils page](./UTILS.md) for more advanced use cases, such as fine-tuning, running smaller models for educational purposes, and quick iteration and verification.

# What is coming next?
## Quantization
- Enabling FP16 models to leverage smaller group sizes for 4-bit quantization.
- Enabling GPTQ for 4-bit groupwise quantization.
- Enabling custom quantization.
- Lower-bit quantization.
## Models
- Enabling more generative AI models and architectures.
## Performance
- Performance improvement via techniques such as speculative decoding.
- Enabling Llama and other architectures via Vulkan.
- Enabling performant execution of widely used quantization schemes.

# Notes
This example tries to reuse the Python code, with minimal modifications to make it compatible with current ExecuTorch:
1. Since ExecuTorch does not support the complex Tensor data type, we use customized functions that implement rotary embedding with real numbers. Please see [GitHub issue: Support complex data type in ExecuTorch](https://github.com/pytorch/executorch/issues/886).
2. No CUDA. ExecuTorch is focused on edge use cases, and CUDA is not available on most edge devices.
3. No dependencies on fairscale. `ColumnParallelLinear`, `ParallelEmbedding`, and training are neither needed nor supported in ExecuTorch.
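
To illustrate item 1, the sketch below applies rotary position embedding to a single vector using only real arithmetic: each adjacent pair (x[2i], x[2i+1]) is treated as a complex number and multiplied by e^(i * pos * theta_i), with the complex product expanded into cos/sin terms. The adjacent-pair layout and theta_i = base^(-2i/d) with base 10000 follow Meta's original reference code as we understand it; the function itself is a hypothetical illustration, not ExecuTorch's implementation.

```python
import math


def rope_real(x, position, base=10000.0):
    """Apply rotary position embedding to one vector with real numbers only.

    Adjacent pairs (x[2i], x[2i+1]) play the role of the real and imaginary
    parts of a complex number, rotated by angle position * theta_i where
    theta_i = base ** (-2i / d).
    """
    d = len(x)
    if d % 2 != 0:
        raise ValueError("head dimension must be even")
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        angle = position * theta
        c, s = math.cos(angle), math.sin(angle)
        re, im = x[2 * i], x[2 * i + 1]
        # (re + i*im) * (c + i*s) = (re*c - im*s) + i*(re*s + im*c)
        out.extend([re * c - im * s, re * s + im * c])
    return out


if __name__ == "__main__":
    print(rope_real([0.5, -1.0, 2.0, 0.25], position=3))
```

Each pair is only rotated, so vector norms are preserved, and position 0 leaves the vector unchanged.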


# Common Issues and Mitigations:
- To clean your build:
```
git clean -xfd
pip uninstall executorch
./install_requirements.sh --pybind xnnpack

rm -rf cmake-out
```
- If you encounter `pthread`-related issues at link time, add `pthread` to `target_link_libraries` in `CMakeLists.txt`.
- On Mac, if you see a linking error in Step 3 with an error message like:
```
0 0x100823648 __assert_rtn + 72
1 0x10074bc5c ld::Fixup::applyFixup(ld::Atom const*, ld::LayoutLinkedImage const&, unsigned char*) const + 8268
2 0x1007de7d8 ___ZN2ld16LayoutExecutable27writeContentWithoutLinkEditENSt3__14spanIhLm18446744073709551615EEEy_block_invoke + 332
3 0x188cca428 _dispatch_client_callout2 + 20
4 0x188cde850 _dispatch_apply_invoke3 + 336
5 0x188cca3e8 _dispatch_client_callout + 20
6 0x188ccbc68 _dispatch_once_callout + 32
7 0x188cdeeec _dispatch_apply_invoke_and_wait + 372
8 0x188cdde9c _dispatch_apply_with_attr_f + 1212
9 0x188cde08c dispatch_apply + 96
10 0x1007de9e4 void mapReduce<ld::Atom const*, mach_o::Error>(std::__1::span<ld::Atom const*, 18446744073709551615ul>, unsigned long, void (unsigned long, mach_o::Error&, std::__1::span<ld::Atom const*, 18446744073709551615ul>) block_pointer, void (std::__1::span<mach_o::Error, 18446744073709551615ul>) block_pointer) + 336
11 0x1007de594 ld::LayoutExecutable::writeContentWithoutLinkEdit(std::__1::span<unsigned char, 18446744073709551615ul>, unsigned long long) + 1180
12 0x1007e4020 ld::LayoutExecutable::writeToFile(char const*) + 15248
13 0x1007962e8 main + 9424
ld: Assertion failed: (extras.otherInstrOffset != 0 && "Kind::arm64_adrp_ldr missing extra info"), function applyFixup, file Fixup.cpp, line 793.
clang: error: linker command failed with exit code 1 (use -v to see invocation)
```
This is a known issue with Xcode 15.1.
Mitigation: update to the most recent Xcode version, then clean and rebuild.
466