xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2// ETCoreMLModelManager.mm
3//
4//  Copyright © 2023 Apple Inc. All rights reserved.
5//
6// Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8#import "ETCoreMLAsset.h"
9#import "ETCoreMLAssetManager.h"
10#import "ETCoreMLDefaultModelExecutor.h"
11#import "ETCoreMLLogging.h"
12#import "ETCoreMLModel.h"
13#import "ETCoreMLModelCompiler.h"
14#import "ETCoreMLModelExecutor.h"
15#import "ETCoreMLModelLoader.h"
16#import "ETCoreMLModelManager.h"
17#import "ETCoreMLStrings.h"
18#import "MLModel_Prewarm.h"
19#import "MLMultiArray_Copy.h"
20#import <filesystem>
21#import "inmemory_filesystem_utils.hpp"
22#import <iostream>
23#import <memory>
24#import "model_metadata.h"
25#import "multiarray.h"
26#import "objc_array_util.h"
27#import <optional>
28#import <os/lock.h>
29#import "serde_json.h"
30#import <string>
31#import <system_error>
32#import <vector>
33
34#if ET_EVENT_TRACER_ENABLED
35#import "ETCoreMLModelAnalyzer.h"
36#import "ETCoreMLModelDebugInfo.h"
37#import "ETCoreMLModelStructurePath.h"
38#import "objc_safe_cast.h"
39#endif
40
41namespace {
42
43using namespace executorchcoreml;
44
/// Kind of model payload found inside the in-memory file system.
enum class ModelAssetType: uint8_t {
    /// A model stored at `ETCoreMLStrings.compiledModelFileRelativePath`.
    CompiledModel = 0,
    /// A model stored at `ETCoreMLStrings.modelFileRelativePath`.
    Model = 1
};
49
/// Splits `path` into its path components and returns them as UTF-8 strings,
/// in the order reported by `-[NSString pathComponents]`.
std::vector<std::string> canonical_path(NSString *path) {
    NSArray<NSString *> *parts = [path pathComponents];
    const NSUInteger part_count = parts.count;

    std::vector<std::string> components;
    components.reserve(part_count);
    for (NSUInteger index = 0; index < part_count; ++index) {
        components.emplace_back(parts[index].UTF8String);
    }

    return components;
}
60
/// Builds a feature provider that pairs each array in `inputs` with the name at the same
/// position in `input_names`.
///
/// @param inputs       Input arrays, in the same order as `input_names`.
/// @param input_names  Ordered input names.
/// @param error        On failure, receives the error.
/// @return The feature provider, or `nil` if there are more inputs than names or if the
///         provider could not be created.
id<MLFeatureProvider> _Nullable get_feature_provider(NSArray<MLMultiArray *> *inputs,
                                                     NSOrderedSet<NSString *> *input_names,
                                                     NSError * __autoreleasing *error) {
    NSEnumerator<NSString *> *enumerator = [input_names objectEnumerator];
    NSMutableDictionary<NSString *, MLFeatureValue *> *features = [NSMutableDictionary dictionaryWithCapacity:inputs.count];
    for (MLMultiArray *input in inputs) {
        NSString *input_name = [enumerator nextObject];
        if (!input_name) {
            // The enumerator ran out of names; inserting under a nil key would raise an exception.
            ETCoreMLLogErrorAndSetNSError(error,
                                          ETCoreMLErrorCorruptedModel,
                                          "%@: Model is broken, an input has no matching input name.",
                                          NSStringFromClass(ETCoreMLModelManager.class));
            return nil;
        }
        features[input_name] = [MLFeatureValue featureValueWithMultiArray:input];
    }

    return [[MLDictionaryFeatureProvider alloc] initWithDictionary:features error:error];
}
73
/// Returns `YES` when the byte pointers handed out by `getBytesWithHandler:` for both
/// arrays are identical, `NO` otherwise (including when either handler is never invoked).
BOOL is_backed_by_same_buffer(MLMultiArray *array1, MLMultiArray *array2) {
    __block BOOL shares_buffer = NO;
    [array1 getBytesWithHandler:^(const void *lhs_bytes, NSInteger __unused lhs_size) {
        // The pointers are only compared while both handlers are active.
        [array2 getBytesWithHandler:^(const void *rhs_bytes, NSInteger __unused rhs_size) {
            shares_buffer = (lhs_bytes == rhs_bytes);
        }];
    }];

    return shares_buffer;
}
84
/// Creates prediction options whose output backings map each name in `output_names`
/// to the array at the same position in `outputs`.
///
/// @return The options, or `nil` (with `error` set) when an output has no non-empty name.
MLPredictionOptions *get_prediction_options(NSArray<MLMultiArray *> *outputs,
                                            NSOrderedSet<NSString *> *output_names,
                                            NSError * __autoreleasing *error) {
    NSMutableDictionary<NSString *, id> *backings_by_name = [[NSMutableDictionary alloc] init];
    NSEnumerator<NSString *> *name_enumerator = [output_names objectEnumerator];
    for (MLMultiArray *backing in outputs) {
        NSString *name = [name_enumerator nextObject];
        if (name.length == 0) {
            ETCoreMLLogErrorAndSetNSError(error, 0, "%@: Model is broken.", NSStringFromClass(ETCoreMLModelManager.class));
            return nil;
        }
        backings_by_name[name] = backing;
    }

    MLPredictionOptions *prediction_options = [[MLPredictionOptions alloc] init];
    prediction_options.outputBackings = backings_by_name;

    return prediction_options;
}
103
/// Copies `src` into `dst` unless both arrays report the same backing buffer.
void copy(MLMultiArray *src, MLMultiArray *dst) {
    if (!::is_backed_by_same_buffer(src, dst)) {
        [src copyInto:dst];
    }
}
111
/// Copies each array of `model_outputs` into the array at the same position in `outputs`.
void set_outputs(NSArray<MLMultiArray *> *outputs, NSArray<MLMultiArray *> *model_outputs) {
    NSEnumerator<MLMultiArray *> *produced = [model_outputs objectEnumerator];
    for (MLMultiArray *destination in outputs) {
        ::copy([produced nextObject], destination);
    }
}
119
/// Maps a Core ML multi-array data type to the matching `MultiArray::DataType`.
/// Returns `std::nullopt` for any data type other than Float16/Float32/Float64/Int32.
std::optional<MultiArray::DataType> get_data_type(MLMultiArrayDataType data_type) {
    if (data_type == MLMultiArrayDataTypeFloat16) {
        return MultiArray::DataType::Float16;
    }
    if (data_type == MLMultiArrayDataTypeFloat32) {
        return MultiArray::DataType::Float32;
    }
    if (data_type == MLMultiArrayDataTypeFloat64) {
        return MultiArray::DataType::Float64;
    }
    if (data_type == MLMultiArrayDataTypeInt32) {
        return MultiArray::DataType::Int32;
    }

    return std::nullopt;
}
139
/// Copies the contents of `src` into `dst`, skipping the copy when `src`'s bytes are
/// already `dst`'s data pointer.
void copy(MLMultiArray *src, executorchcoreml::MultiArray& dst) {
    [src getBytesWithHandler:^(const void * _Nonnull bytes, NSInteger __unused size) {
        if (bytes != dst.data()) {
            // Wrap the source bytes in a MultiArray that mirrors src's data type, shape and strides.
            // `.value()` throws if the data type has no MultiArray equivalent.
            const MultiArray::DataType data_type = get_data_type(src.dataType).value();
            MultiArray::MemoryLayout layout(data_type, to_vector<size_t>(src.shape), to_vector<ssize_t>(src.strides));
            MultiArray source(const_cast<void *>(bytes), std::move(layout));
            source.copy(dst);
        }
    }];
}
150
/// Copies each array of `model_outputs` into the `MultiArray` at the same position in `outputs`.
void set_outputs(std::vector<executorchcoreml::MultiArray>& outputs,
                 NSArray<MLMultiArray *> *model_outputs) {
    NSEnumerator<MLMultiArray *> *produced = [model_outputs objectEnumerator];
    for (auto it = outputs.begin(); it != outputs.end(); ++it) {
        ::copy([produced nextObject], *it);
    }
}
159
/// Returns the content of the file named `fileName` in `inMemoryFS` as `NSData`,
/// without copying the bytes. Returns `nil` when the file has no content.
NSData * _Nullable get_file_data(const inmemoryfs::InMemoryFileSystem *inMemoryFS,
                                 NSString *fileName) {
    std::error_code error_code;
    const std::vector<std::string> components = ::canonical_path(fileName);
    // `__block` lets the deallocator below release the buffer once the NSData goes away.
    __block auto content = inMemoryFS->get_file_content(components, error_code);
    const bool has_content = content && content->size() > 0;
    if (!has_content) {
        return nil;
    }

    return [[NSData alloc] initWithBytesNoCopy:content->data()
                                        length:content->size()
                                   deallocator:^(void * _Nonnull __unused bytes, NSUInteger __unused length) {
        content.reset();
    }];
}
177
/// Reads the metadata file from `inMemoryFS` and parses it as JSON.
/// Returns `std::nullopt` when the file has no data or the parsed metadata is not valid.
std::optional<ModelMetadata> get_model_metadata(const inmemoryfs::InMemoryFileSystem *inMemoryFS) {
    NSData *metadata_data = get_file_data(inMemoryFS, ETCoreMLStrings.metadataFileRelativePath);
    if (metadata_data == nil) {
        return std::nullopt;
    }

    std::string json(static_cast<const char *>(metadata_data.bytes), metadata_data.length);
    ModelMetadata metadata;
    metadata.from_json_string(std::move(json));
    if (!metadata.is_valid()) {
        return std::nullopt;
    }

    return metadata;
}
194
/// Converts `values` into an ordered set of `NSString`s, preserving first-occurrence order.
NSOrderedSet<NSString *> *get_ordered_set(const std::vector<std::string>& values) {
    NSMutableOrderedSet<NSString *> *names = [[NSMutableOrderedSet alloc] initWithCapacity:values.size()];
    for (auto it = values.cbegin(); it != values.cend(); ++it) {
        [names addObject:@(it->c_str())];
    }

    return names;
}
203
/// Creates the directory `dst_url` and writes the model item from `inmemory_fs` into it.
///
/// @param dst_url           Directory to create; intermediate directories are not created.
/// @param fm                File manager used to create the directory.
/// @param identifier        Model identifier, used in error messages only.
/// @param model_asset_type  Selects which relative path is written and which extension the
///                          returned URL carries.
/// @param inmemory_fs       In-memory file system holding the model files.
/// @param error             On failure, receives the error.
/// @return `dst_url` extended with `model.<extension>`, or `nil` on failure.
NSURL * _Nullable write_model_files(NSURL *dst_url,
                                    NSFileManager *fm,
                                    NSString *identifier,
                                    ModelAssetType model_asset_type,
                                    const inmemoryfs::InMemoryFileSystem *inmemory_fs,
                                    NSError * __autoreleasing *error) {
    NSError *local_error = nil;
    // Capture the file manager's error locally so it can be reported as the underlying error.
    if (![fm createDirectoryAtURL:dst_url withIntermediateDirectories:NO attributes:@{} error:&local_error]) {
        ETCoreMLLogUnderlyingErrorAndSetNSError(error,
                                                ETCoreMLErrorModelSaveFailed,
                                                local_error,
                                                "%@: Failed to create directory when saving model with identifier = %@.",
                                                NSStringFromClass(ETCoreMLModelManager.class),
                                                identifier);
        return nil;
    }

    NSString *relative_path = nil;
    NSString *extension_name = nil;
    switch (model_asset_type) {
        case ModelAssetType::Model: {
            relative_path = ETCoreMLStrings.modelFileRelativePath;
            extension_name = ETCoreMLStrings.modelExtensionName;
            break;
        }

        case ModelAssetType::CompiledModel: {
            relative_path = ETCoreMLStrings.compiledModelFileRelativePath;
            extension_name = ETCoreMLStrings.compiledModelExtensionName;
            break;
        }
    }

    std::filesystem::path model_path(dst_url.fileSystemRepresentation);
    std::error_code ec;
    const std::vector<std::string> file_path = canonical_path(relative_path);
    if (!inmemory_fs->write_item_to_disk(file_path, model_path, true, ec)) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      ETCoreMLErrorModelSaveFailed,
                                      "%@: Failed to write model files to disk when saving model with identifier = %@.",
                                      NSStringFromClass(ETCoreMLModelManager.class),
                                      identifier);
        return nil;
    }

    return [dst_url URLByAppendingPathComponent:[NSString stringWithFormat:@"model.%@", extension_name]];
}
254
/// Determines which kind of model `inmemory_fs` contains.
/// The compiled model path is checked first; returns `std::nullopt` when neither the
/// compiled model path nor the model path exists.
std::optional<ModelAssetType> get_model_asset_type(const inmemoryfs::InMemoryFileSystem *inmemory_fs) {
    if (inmemory_fs->exists(canonical_path(ETCoreMLStrings.compiledModelFileRelativePath))) {
        return ModelAssetType::CompiledModel;
    }

    if (inmemory_fs->exists(canonical_path(ETCoreMLStrings.modelFileRelativePath))) {
        return ModelAssetType::Model;
    }

    return std::nullopt;
}
267
268
/// Creates an `ETCoreMLModel` from `asset`, using the input and output names recorded
/// in `metadata` as the model's ordered input/output names.
ETCoreMLModel * _Nullable get_model_from_asset(ETCoreMLAsset *asset,
                                               MLModelConfiguration *configuration,
                                               const ModelMetadata& metadata,
                                               NSError * __autoreleasing *error) {
    NSOrderedSet<NSString *> *input_names = ::get_ordered_set(metadata.input_names);
    NSOrderedSet<NSString *> *output_names = ::get_ordered_set(metadata.output_names);

    return [[ETCoreMLModel alloc] initWithAsset:asset
                                  configuration:configuration
                              orderedInputNames:input_names
                             orderedOutputNames:output_names
                                          error:error];
}
282
/// Returns the name associated with `compute_units`.
/// Any value that is not explicitly listed maps to the "all compute units" name.
std::string to_string(MLComputeUnits compute_units) {
    NSString *name = nil;
    switch (compute_units) {
        case MLComputeUnitsCPUOnly:
            name = ETCoreMLStrings.cpuComputeUnitName;
            break;
        case MLComputeUnitsCPUAndGPU:
            name = ETCoreMLStrings.cpuAndGpuComputeUnitsName;
            break;
        case MLComputeUnitsCPUAndNeuralEngine:
            name = ETCoreMLStrings.cpuAndNeuralEngineComputeUnitsName;
            break;
        case MLComputeUnitsAll:
        default:
            name = ETCoreMLStrings.allComputeUnitsName;
            break;
    }

    return name.UTF8String;
}
302
/// Appends "_" followed by the compute unit name to `identifier`.
void add_compute_unit(std::string& identifier, MLComputeUnits compute_units) {
    identifier += "_";
    identifier += to_string(compute_units);
}
307
308#if ET_EVENT_TRACER_ENABLED
/// Wraps the result of `executorchcoreml::Asset::make` in an `ETCoreMLAsset`.
/// Returns `nil` when the backing asset could not be created.
ETCoreMLAsset * _Nullable make_asset(NSURL *url,
                                     NSString *identifier,
                                     NSFileManager *fm,
                                     NSError * __autoreleasing *error) {
    auto backing_asset = executorchcoreml::Asset::make(url, identifier, fm, error);
    if (backing_asset) {
        return [[ETCoreMLAsset alloc] initWithBackingAsset:std::move(backing_asset.value())];
    }

    return nil;
}
320
/// Reads the debug info file from `inMemoryFS` and parses it.
/// Returns `nil` without touching `error` when the file has no data.
ETCoreMLModelDebugInfo * _Nullable get_model_debug_info(const inmemoryfs::InMemoryFileSystem *inMemoryFS,
                                                        NSError * __autoreleasing *error) {
    NSData *debug_info_data = get_file_data(inMemoryFS, ETCoreMLStrings.debugInfoFileRelativePath);
    if (debug_info_data == nil) {
        return nil;
    }

    return [ETCoreMLModelDebugInfo modelDebugInfoFromData:debug_info_data error:error];
}
330
331#endif
332} //namespace
333
@interface ETCoreMLModelManager () {
    // Held while reading or writing the handle, loading-queue and prewarmed-asset maps below.
    os_unfair_lock _lock;
}

/// File manager used when writing model files to disk.
@property (nonatomic, strong, readonly) NSFileManager *fileManager;
/// Asset manager supplied at initialization.
@property (nonatomic, strong, readonly) ETCoreMLAssetManager *assetManager;
/// Executors of loaded models, keyed by the model handle pointer.
@property (nonatomic, strong, readonly) NSMutableDictionary<NSValue *, id<ETCoreMLModelExecutor>> *handleToExecutorMap;
/// Per-identifier loading queues; the queues are held weakly.
@property (nonatomic, strong, readonly) NSMapTable<NSString *, dispatch_queue_t> *modelIdentifierToLoadingQueueMap;
/// Assets that were prewarmed successfully, keyed by asset identifier.
@property (nonatomic, strong, readonly) NSMutableDictionary<NSString *, ETCoreMLAsset *> *modelIdentifierToPrewarmedAssetMap;
/// Serial queue on which assets are prewarmed.
@property (nonatomic, strong, readonly) dispatch_queue_t prewarmQueue;

@end
346
347@implementation ETCoreMLModelManager
348
/// Initializes the manager with the asset manager used to look up and store model assets.
- (instancetype)initWithAssetManager:(ETCoreMLAssetManager *)assetManager {
    self = [super init];
    if (!self) {
        return nil;
    }

    _lock = OS_UNFAIR_LOCK_INIT;
    _assetManager = assetManager;
    _fileManager = [[NSFileManager alloc] init];
    _handleToExecutorMap = [[NSMutableDictionary alloc] init];
    _modelIdentifierToPrewarmedAssetMap = [[NSMutableDictionary alloc] init];
    // Queues are held weakly so an idle loading queue can go away.
    _modelIdentifierToLoadingQueueMap = [NSMapTable strongToWeakObjectsMapTable];
    dispatch_queue_attr_t prewarmAttributes = dispatch_queue_attr_make_with_qos_class(DISPATCH_QUEUE_SERIAL, QOS_CLASS_DEFAULT, -1);
    _prewarmQueue = dispatch_queue_create("com.executorchcoreml.modelmanager.prewarm", prewarmAttributes);

    return self;
}
364
/// Returns the executor registered for `handle`, or `nil` if there is none.
- (nullable id<ETCoreMLModelExecutor>)executorWithHandle:(ModelHandle *)handle {
    NSValue *handleKey = [NSValue valueWithPointer:handle];

    os_unfair_lock_lock(&_lock);
    id<ETCoreMLModelExecutor> registeredExecutor = [self.handleToExecutorMap objectForKey:handleKey];
    os_unfair_lock_unlock(&_lock);

    return registeredExecutor;
}
376
/// Returns the model of the executor registered for `handle`, or `nil` if there is none.
- (nullable ETCoreMLModel *)modelWithHandle:(ModelHandle *)handle {
    return [[self executorWithHandle:handle] model];
}
381
/// Looks up the asset for `identifier`: first among the prewarmed assets, then through
/// the asset manager. An error reported by the asset manager is logged, not returned.
- (nullable ETCoreMLAsset *)assetWithIdentifier:(NSString *)identifier {
    os_unfair_lock_lock(&_lock);
    ETCoreMLAsset *prewarmedAsset = self.modelIdentifierToPrewarmedAssetMap[identifier];
    os_unfair_lock_unlock(&_lock);
    if (prewarmedAsset != nil) {
        return prewarmedAsset;
    }

    NSError *lookupError = nil;
    ETCoreMLAsset *storedAsset = [self.assetManager assetWithIdentifier:identifier error:&lookupError];
    if (lookupError != nil) {
        ETCoreMLLogError(lookupError,
                         "%@: Failed to retrieve asset with identifier = %@",
                         NSStringFromClass(self.assetManager.class),
                         identifier);
    }

    return storedAsset;
}
405
/// Writes the model contained in `inMemoryFS` to a uniquely named directory inside the
/// asset manager's trash directory and returns the URL of the compiled model.
///
/// @param identifier    Model identifier, used in error messages.
/// @param inMemoryFS    In-memory file system holding the model files.
/// @param assetManager  Not used by this method; `self.assetManager` supplies the trash directory.
/// @param error         On failure, receives the error.
/// @return URL of the compiled model, or `nil` if the model file is missing, writing fails
///         or compilation fails.
- (nullable NSURL *)compiledModelURLWithIdentifier:(NSString *)identifier
                                        inMemoryFS:(const inmemoryfs::InMemoryFileSystem*)inMemoryFS
                                      assetManager:(ETCoreMLAssetManager *)assetManager
                                             error:(NSError * __autoreleasing *)error {
    auto modelAssetType = get_model_asset_type(inMemoryFS);
    if (!modelAssetType) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      ETCoreMLErrorCorruptedModel,
                                      "%@: AOT blob is missing model file.",
                                      NSStringFromClass(ETCoreMLModelManager.class));
        return nil;
    }

    NSURL *dstURL = [self.assetManager.trashDirectoryURL URLByAppendingPathComponent:[NSUUID UUID].UUIDString];
    NSURL *modelURL = ::write_model_files(dstURL, self.fileManager, identifier, modelAssetType.value(), inMemoryFS, error);
    if (!modelURL) {
        // Writing failed and `error` is already populated; there is nothing to compile.
        return nil;
    }

    if (modelAssetType.value() == ModelAssetType::CompiledModel) {
        return modelURL;
    }

    // We need to compile the model.
    return [ETCoreMLModelCompiler compileModelAtURL:modelURL
                               maxWaitTimeInSeconds:(5 * 60)
                                              error:error];
}
436
437#if ET_EVENT_TRACER_ENABLED
/// Creates a model analyzer for the model described by `metadata`
/// (variant built when the event tracer is enabled).
///
/// @param metadata       Model metadata; its identifier is used to look up a cached asset.
/// @param inMemoryFS     In-memory file system holding the model files.
/// @param configuration  Configuration handed to the analyzer.
/// @param error          On failure, receives the error.
/// @return The analyzer, or `nil` if no compiled model asset could be obtained.
- (nullable id<ETCoreMLModelExecutor>)modelExecutorWithMetadata:(const ModelMetadata&)metadata
                                                     inMemoryFS:(const inmemoryfs::InMemoryFileSystem*)inMemoryFS
                                                  configuration:(MLModelConfiguration *)configuration
                                                          error:(NSError * __autoreleasing *)error {
    NSString *identifier = @(metadata.identifier.c_str());
    // Try to retrieve an already available compiled asset.
    ETCoreMLAsset *compiledModelAsset = [self assetWithIdentifier:identifier];
    auto modelAssetType = get_model_asset_type(inMemoryFS);
    ETCoreMLAsset *modelAsset = nil;
    // Write the model files when the blob contains a (non-compiled) model.
    if (modelAssetType == ModelAssetType::Model) {
        // Create a unique directory for writing model files.
        NSURL *dstURL = [self.assetManager.trashDirectoryURL URLByAppendingPathComponent:[NSUUID UUID].UUIDString];
        NSURL *modelURL = ::write_model_files(dstURL, self.fileManager, identifier, modelAssetType.value(), inMemoryFS, error);
        if (modelURL) {
            modelAsset = make_asset(modelURL,
                                    identifier,
                                    self.fileManager,
                                    error);
        }
    }

    if (!compiledModelAsset) {
        // Compile the model.
        NSURL *compiledModelURL = [self compiledModelURLWithIdentifier:identifier
                                                            inMemoryFS:inMemoryFS
                                                          assetManager:self.assetManager
                                                                 error:error];
        if (!compiledModelURL) {
            // Compilation failed and `error` is already populated; do not build an asset from a nil URL.
            return nil;
        }

        compiledModelAsset = make_asset(compiledModelURL,
                                        identifier,
                                        self.fileManager,
                                        error);
    }

    if (!compiledModelAsset) {
        return nil;
    }

    // Debug info is optional: a parse failure is logged and the analyzer is created without it.
    NSError *localError = nil;
    ETCoreMLModelDebugInfo *debug_info = get_model_debug_info(inMemoryFS, &localError);
    if (localError) {
        ETCoreMLLogError(localError, "Failed to parse debug info file");
    }

    return [[ETCoreMLModelAnalyzer alloc] initWithCompiledModelAsset:compiledModelAsset
                                                          modelAsset:modelAsset
                                                      modelDebugInfo:debug_info
                                                            metadata:metadata
                                                       configuration:configuration
                                                        assetManager:self.assetManager
                                                               error:error];
}
491
492#else
/// Creates the default executor for the model described by `metadata`
/// (variant built when the event tracer is disabled).
///
/// An already available asset is tried first; if no model can be created from it, the model
/// in `inMemoryFS` is written to disk, compiled if needed, and loaded.
///
/// @param metadata       Model metadata; its identifier is used to look up a cached asset.
/// @param inMemoryFS     In-memory file system holding the model files.
/// @param configuration  Configuration used to load the model.
/// @param error          On failure, receives the error.
/// @return The executor, or `nil` if the model could not be compiled or loaded.
- (nullable id<ETCoreMLModelExecutor>)modelExecutorWithMetadata:(const ModelMetadata&)metadata
                                                     inMemoryFS:(const inmemoryfs::InMemoryFileSystem*)inMemoryFS
                                                  configuration:(MLModelConfiguration *)configuration
                                                          error:(NSError * __autoreleasing *)error {
    NSString *identifier = @(metadata.identifier.c_str());
    // Try to retrieve an already available compiled asset.
    ETCoreMLAsset *asset = [self assetWithIdentifier:identifier];
    ETCoreMLModel *model = asset ? get_model_from_asset(asset, configuration, metadata, error) : nil;
    if (model) {
        return [[ETCoreMLDefaultModelExecutor alloc] initWithModel:model];
    }

    // Compile the model.
    NSURL *compiledModelURL = [self compiledModelURLWithIdentifier:identifier
                                                        inMemoryFS:inMemoryFS
                                                      assetManager:self.assetManager
                                                             error:error];
    if (!compiledModelURL) {
        return nil;
    }

    model = [ETCoreMLModelLoader loadModelWithContentsOfURL:compiledModelURL
                                              configuration:configuration
                                                   metadata:metadata
                                               assetManager:self.assetManager
                                                      error:error];
    if (!model) {
        // Loading failed; do not hand out an executor that wraps a nil model.
        return nil;
    }

    return [[ETCoreMLDefaultModelExecutor alloc] initWithModel:model];
}
522#endif
523
/// Creates an executor from an AOT blob.
///
/// The blob is interpreted as an in-memory file system without copying `data`. The model
/// identifier is the metadata identifier with the compute unit name appended, and loading
/// runs synchronously on the per-identifier loading queue.
///
/// @param data           AOT blob; must stay alive for the duration of the call.
/// @param configuration  Configuration used to load the model.
/// @param error          On failure, receives the error.
/// @return The executor, or `nil` if the blob or its metadata is invalid or loading fails.
- (nullable id<ETCoreMLModelExecutor>)_modelExecutorWithAOTData:(NSData *)data
                                                  configuration:(MLModelConfiguration *)configuration
                                                          error:(NSError * __autoreleasing *)error {
    using namespace inmemoryfs;

    auto buffer = MemoryBuffer::make_unowned(const_cast<void *>(data.bytes), data.length);
    std::unique_ptr<InMemoryFileSystem> inMemoryFS = inmemoryfs::make_from_buffer(std::move(buffer));
    if (!inMemoryFS) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      ETCoreMLErrorCorruptedModel,
                                      "%@: Model data is corrupted.",
                                      NSStringFromClass(ETCoreMLModelManager.class));
        return nil;
    }

    std::optional<ModelMetadata> metadata = ::get_model_metadata(inMemoryFS.get());
    if (!metadata) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      ETCoreMLErrorCorruptedMetadata,
                                      "%@: Metadata is invalid or missing.",
                                      NSStringFromClass(ETCoreMLModelManager.class));
        return nil;
    }

    auto metadataValue = metadata.value();
    add_compute_unit(metadataValue.identifier, configuration.computeUnits);
    NSString *identifier = @(metadataValue.identifier.c_str());
    // If there are multiple calls to load the same model, we only want to compile it once.
    __block id<ETCoreMLModelExecutor> executor = nil;
    // The loading queue drains its own autorelease pool after each work item, so an error
    // written through the `__autoreleasing` out-parameter inside the block could be released
    // before the caller reads it. Keep a strong reference and hand it over afterwards.
    __block NSError *loadError = nil;
    dispatch_queue_t loadingQueue = [self queueForLoadingModelWithIdentifier:identifier];
    auto inMemoryFSPtr = inMemoryFS.get();
    dispatch_sync(loadingQueue, ^{
        NSError *blockError = nil;
        executor = [self modelExecutorWithMetadata:metadataValue
                                        inMemoryFS:inMemoryFSPtr
                                     configuration:configuration
                                             error:&blockError];
        loadError = blockError;
    });

    if (error && loadError) {
        *error = loadError;
    }

    return executor;
}
564
/// Returns the serial loading queue for `identifier`, creating and registering one
/// if the map has no queue for it.
- (dispatch_queue_t)queueForLoadingModelWithIdentifier:(NSString *)identifier {
    os_unfair_lock_lock(&_lock);
    NSMapTable<NSString *, dispatch_queue_t> *queuesByIdentifier = self.modelIdentifierToLoadingQueueMap;
    dispatch_queue_t loadingQueue = [queuesByIdentifier objectForKey:identifier];
    if (loadingQueue == nil) {
        loadingQueue = dispatch_queue_create("com.executorchcoreml.modelmanager.loader", DISPATCH_QUEUE_SERIAL_WITH_AUTORELEASE_POOL);
        [queuesByIdentifier setObject:loadingQueue forKey:identifier];
    }
    os_unfair_lock_unlock(&_lock);

    return loadingQueue;
}
576
/// Loads a model from an AOT blob and registers its executor.
///
/// @param data           AOT blob.
/// @param configuration  Configuration used to load the model.
/// @param error          On failure, receives the error.
/// @return The handle of the loaded model (the model's pointer), or `NULL` on failure.
- (ModelHandle *)loadModelFromAOTData:(NSData*)data
                        configuration:(MLModelConfiguration*)configuration
                                error:(NSError* __autoreleasing*)error {
    id<ETCoreMLModelExecutor> executor = [self _modelExecutorWithAOTData:data
                                                           configuration:configuration
                                                                   error:error];
    ETCoreMLModel *model = executor.model;
    if (!model) {
        // No executor, or an executor without a model: registering it would store an entry
        // under a NULL-pointer key that no valid handle refers to.
        return NULL;
    }

    NSValue *key = [NSValue valueWithPointer:(__bridge void *)model];
    {
        os_unfair_lock_lock(&_lock);
        self.handleToExecutorMap[key] = executor;
        os_unfair_lock_unlock(&_lock);
    }

    return (__bridge ModelHandle *)model;
}
594
/// Prewarms the model registered for `handle`.
///
/// @param handle  Handle of a loaded model.
/// @param error   On failure, receives the error.
/// @return `YES` if prewarming succeeded; `NO` if no model is registered for `handle`
///         or prewarming failed.
- (BOOL)prewarmModelWithHandle:(ModelHandle *)handle
                         error:(NSError * __autoreleasing *)error {
    ETCoreMLModel *model = [self modelWithHandle:handle];
    if (!model) {
        // Report the failure the same way the execute methods do, so `NO` always comes with an error.
        ETCoreMLLogErrorAndSetNSError(error,
                                      0,
                                      "%@: Model is already unloaded.",
                                      NSStringFromClass(self.class));
        return NO;
    }

    return [model prewarmAndReturnError:error];
}
604
/// Asynchronously prewarms up to `maxCount` of the most recently used assets on the prewarm
/// queue. Each asset that prewarms successfully is recorded as a prewarmed asset; failures
/// are logged.
- (void)prewarmRecentlyUsedAssetsWithMaxCount:(NSUInteger)maxCount {
    NSError *fetchError = nil;
    NSArray<ETCoreMLAsset *> *recentAssets = [self.assetManager mostRecentlyUsedAssetsWithMaxCount:maxCount error:&fetchError];
    if (fetchError != nil) {
        ETCoreMLLogError(fetchError,
                         "%@: Failed to retrieve recently used assets.",
                         NSStringFromClass(self.assetManager.class));
    }

    // Nothing is enqueued when there are no assets.
    __weak __typeof(self) weakSelf = self;
    for (ETCoreMLAsset *recentAsset in recentAssets) {
        dispatch_async(self.prewarmQueue, ^{
            __strong __typeof(self) strongSelf = weakSelf;
            if (strongSelf == nil) {
                return;
            }

            NSError *prewarmError = nil;
            if ([recentAsset prewarmAndReturnError:&prewarmError]) {
                [strongSelf addPrewarmedAsset:recentAsset];
                return;
            }

            ETCoreMLLogError(prewarmError,
                             "%@: Failed to prewarm asset with identifier = %@",
                             NSStringFromClass(strongSelf.assetManager.class),
                             recentAsset.identifier);
        });
    }
}
640
/// Records `asset` as prewarmed, keyed by its identifier.
- (void)addPrewarmedAsset:(ETCoreMLAsset *)asset {
    os_unfair_lock_lock(&_lock);
    NSMutableDictionary<NSString *, ETCoreMLAsset *> *prewarmedAssets = self.modelIdentifierToPrewarmedAssetMap;
    [prewarmedAssets setObject:asset forKey:asset.identifier];
    os_unfair_lock_unlock(&_lock);
}
646
/// Runs the executor's model on `inputs`, offering `outputBackings` as output buffers.
///
/// If execution fails while output backings are set, the executor is switched to ignore
/// output backings and execution is attempted once more.
///
/// @return The model outputs, or `nil` on failure. When `error` is non-NULL it receives the
///         error of the last execution attempt (`nil` if that attempt succeeded).
- (nullable NSArray<MLMultiArray *> *)executeModelUsingExecutor:(id<ETCoreMLModelExecutor>)executor
                                                         inputs:(NSArray<MLMultiArray *> *)inputs
                                                 outputBackings:(NSArray<MLMultiArray *> *)outputBackings
                                                 loggingOptions:(const executorchcoreml::ModelLoggingOptions&)loggingOptions
                                                    eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable)eventLogger
                                                          error:(NSError * __autoreleasing *)error {
    ETCoreMLModel *model = executor.model;
    MLPredictionOptions *predictionOptions = ::get_prediction_options(outputBackings, model.orderedOutputNames, error);
    if (predictionOptions == nil) {
        return nil;
    }

    id<MLFeatureProvider> inputFeatures = ::get_feature_provider(inputs, model.orderedInputNames, error);
    if (inputFeatures == nil) {
        return nil;
    }

    NSError *executionError = nil;
    NSArray<MLMultiArray *> *modelOutputs = [executor executeModelWithInputs:inputFeatures
                                                           predictionOptions:predictionOptions
                                                              loggingOptions:loggingOptions
                                                                 eventLogger:eventLogger
                                                                       error:&executionError];

    // Try without output backings.
    const BOOL shouldRetry = (modelOutputs == nil) && (predictionOptions.outputBackings.count > 0);
    if (shouldRetry) {
        executor.ignoreOutputBackings = YES;
        executionError = nil;
        modelOutputs = [executor executeModelWithInputs:inputFeatures
                                      predictionOptions:predictionOptions
                                         loggingOptions:loggingOptions
                                            eventLogger:eventLogger
                                                  error:&executionError];
    }

    if (error != NULL) {
        *error = executionError;
    }

    return modelOutputs;
}
687
/// Executes the model registered for `handle`.
///
/// @param handle          Handle of a loaded model.
/// @param args            Input arrays followed by output arrays; the count must equal the
///                        model's input name count plus its output name count.
/// @param loggingOptions  Logging options forwarded to the executor.
/// @param eventLogger     Optional event logger forwarded to the executor.
/// @param error           On failure, receives the error.
/// @return `YES` if the model ran and the results were copied into the output arrays.
- (BOOL)executeModelWithHandle:(ModelHandle *)handle
                          args:(NSArray<MLMultiArray *> *)args
                loggingOptions:(const executorchcoreml::ModelLoggingOptions&)loggingOptions
                   eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable)eventLogger
                         error:(NSError * __autoreleasing *)error {
    id<ETCoreMLModelExecutor> executor = [self executorWithHandle:handle];
    if (!executor) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      0,
                                      "%@: Model is already unloaded.",
                                      NSStringFromClass(self.class));
        return NO;
    }

    ETCoreMLModel *model = executor.model;
    const NSUInteger inputCount = model.orderedInputNames.count;
    const NSUInteger expectedArgsCount = inputCount + model.orderedOutputNames.count;
    if (args.count != expectedArgsCount) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      ETCoreMLErrorCorruptedModel,
                                      "%@: Model is invalid, expected args count to be %lu but got %lu.",
                                      NSStringFromClass(self.class),
                                      static_cast<unsigned long>(expectedArgsCount),
                                      static_cast<unsigned long>(args.count));
        return NO;
    }

    // The error is kept in a strong local declared outside the pool: an object stored into the
    // `__autoreleasing` out-parameter inside the pool could be released when the pool drains.
    NSError *localError = nil;
    BOOL succeeded = NO;
    @autoreleasepool {
        NSArray<MLMultiArray *> *inputs = [args subarrayWithRange:NSMakeRange(0, inputCount)];
        NSArray<MLMultiArray *> *outputs = [args subarrayWithRange:NSMakeRange(inputCount, args.count - inputCount)];
        NSArray<MLMultiArray *> *outputBackings = @[];
        if (!executor.ignoreOutputBackings) {
            outputBackings = outputs;
        }

        NSArray<MLMultiArray *> *modelOutputs = [self executeModelUsingExecutor:executor
                                                                         inputs:inputs
                                                                 outputBackings:outputBackings
                                                                 loggingOptions:loggingOptions
                                                                    eventLogger:eventLogger
                                                                          error:&localError];
        if (modelOutputs) {
            ::set_outputs(outputs, modelOutputs);
            succeeded = YES;
        }
    }

    if (error) {
        *error = localError;
    }

    return succeeded;
}
735
/// Executes the model identified by `handle` using `argsVec`, a vector holding the model
/// inputs followed by the output arrays.
///
/// @param handle The handle used to look up the executor.
/// @param argsVec The inputs (first `orderedInputNames.count` entries) followed by the outputs.
/// @param loggingOptions Forwarded unchanged to the executor call.
/// @param eventLogger Forwarded unchanged to the executor call.
/// @param error On failure, receives the error describing what went wrong. Left untouched on success.
/// @return `YES` if the model executed and the outputs were populated, otherwise `NO`.
- (BOOL)executeModelWithHandle:(ModelHandle *)handle
                       argsVec:(const std::vector<executorchcoreml::MultiArray>&)argsVec
                loggingOptions:(const executorchcoreml::ModelLoggingOptions&)loggingOptions
                   eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable)eventLogger
                         error:(NSError * __autoreleasing *)error {
    id<ETCoreMLModelExecutor> executor = [self executorWithHandle:handle];
    if (!executor) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      0,
                                      "%@: Model is already unloaded.",
                                      NSStringFromClass(self.class));
        return NO;
    }

    ETCoreMLModel *model = executor.model;
    const NSUInteger inputCount = model.orderedInputNames.count;
    const NSUInteger expectedArgsCount = inputCount + model.orderedOutputNames.count;
    if (argsVec.size() != expectedArgsCount) {
        ETCoreMLLogErrorAndSetNSError(error,
                                      ETCoreMLErrorCorruptedModel,
                                      "%@: Model is invalid, expected args count to be %lu but got %lu.",
                                      NSStringFromClass(self.class),
                                      static_cast<unsigned long>(expectedArgsCount),
                                      static_cast<unsigned long>(argsVec.size()));
        return NO;
    }

    // Split the arguments into the leading inputs and the trailing outputs.
    const auto firstOutput = argsVec.begin() + inputCount;
    std::vector<executorchcoreml::MultiArray> inputArgs(argsVec.begin(), firstOutput);
    std::vector<executorchcoreml::MultiArray> outputArgs(firstOutput, argsVec.end());

    // The error is captured in a strong local declared outside the pool. An NSError stored
    // through the autoreleasing out-parameter while the pool is active would be released
    // when the pool drains, leaving the caller with a dangling pointer.
    NSError *localError = nil;
    BOOL succeeded = NO;
    @autoreleasepool {
        NSArray<MLMultiArray *> *inputs = [model prepareInputs:inputArgs error:&localError];
        if (inputs) {
            NSArray<MLMultiArray *> *outputBackings = @[];
            if (!executor.ignoreOutputBackings) {
                outputBackings = [model prepareOutputBackings:outputArgs error:&localError];
            }

            // A nil result means preparing the output backings failed.
            if (outputBackings) {
                NSArray<MLMultiArray *> *modelOutputs = [self executeModelUsingExecutor:executor
                                                                                 inputs:inputs
                                                                         outputBackings:outputBackings
                                                                         loggingOptions:loggingOptions
                                                                            eventLogger:eventLogger
                                                                                  error:&localError];
                if (modelOutputs) {
                    ::set_outputs(outputArgs, modelOutputs);
                    succeeded = YES;
                }
            }
        }
    }

    // Publish the error only after the pool has drained, and only on failure.
    if (!succeeded && error) {
        *error = localError;
    }

    return succeeded;
}
792
/// Removes the executor registered for `handle` from the handle-to-executor map.
///
/// @param handle The handle whose executor should be removed.
/// @return `YES` if an executor was registered for `handle` before removal, otherwise `NO`.
- (BOOL)unloadModelWithHandle:(ModelHandle *)handle {
    BOOL wasRegistered = NO;
    @autoreleasepool {
        NSValue *handleKey = [NSValue valueWithPointer:handle];
        // The lookup and the removal happen under the same lock acquisition.
        os_unfair_lock_lock(&_lock);
        wasRegistered = ([self.handleToExecutorMap objectForKey:handleKey] != nil);
        [self.handleToExecutorMap removeObjectForKey:handleKey];
        os_unfair_lock_unlock(&_lock);
    }

    return wasRegistered;
}
805
/// Forwards the purge request to the asset manager.
///
/// @param error Passed through to the asset manager's `purgeAndReturnError:`.
/// @return The value returned by the asset manager's `purgeAndReturnError:`.
- (BOOL)purgeModelsCacheAndReturnError:(NSError *__autoreleasing *)error {
    const BOOL didPurge = [self.assetManager purgeAndReturnError:error];
    return didPurge;
}
809
810@end
811