xref: /aosp_15_r20/external/executorch/backends/apple/coreml/runtime/delegate/ETCoreMLModel.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1//
2// ETCoreMLModel.mm
3//
4// Copyright © 2024 Apple Inc. All rights reserved.
5//
6// Please refer to the license found in the LICENSE file in the root directory of the source tree.
7
8#import <ETCoreMLModel.h>
9
10#import "ETCoreMLAsset.h"
11#import "ETCoreMLLogging.h"
12#import "multiarray.h"
13#import "objc_array_util.h"
14#import "MLModel_Prewarm.h"
15#import <functional>
16#import <numeric>
17
18#pragma mark - ETCoreMLMultiArrayDescriptor
/// Immutable value type identifying a multi-array backing storage by its shape and data type.
///
/// Instances are used as keys of the backing-storage cache, so equality and hashing are
/// value-based (shape elements and data type).
__attribute__((objc_subclassing_restricted))
@interface ETCoreMLMultiArrayDescriptor: NSObject <NSCopying>

- (instancetype)init NS_UNAVAILABLE;

+ (instancetype)new NS_UNAVAILABLE;

/// Creates a descriptor.
///
/// @param shape The multi-array shape. The array is copied.
/// @param dataType The multi-array element data type.
- (instancetype)initWithShape:(NSArray<NSNumber *> *)shape
                     dataType:(MLMultiArrayDataType)dataType NS_DESIGNATED_INITIALIZER;

/// Returns YES when `other` has an equal shape and the same data type as the receiver.
- (BOOL)isEqualToDescriptor:(ETCoreMLMultiArrayDescriptor *)other;

/// The multi-array shape.
@property (copy, readonly, nonatomic) NSArray<NSNumber *> *shape;

/// The multi-array element data type.
@property (assign, readonly, nonatomic) MLMultiArrayDataType dataType;

@end

@implementation ETCoreMLMultiArrayDescriptor

- (instancetype)initWithShape:(NSArray<NSNumber *> *)shape
                     dataType:(MLMultiArrayDataType)dataType {
    self = [super init];
    if (self) {
        // Honor the `copy` property attribute: a caller-supplied mutable array must not be able
        // to change this descriptor (e.g. while it is used as a cache key).
        _shape = [shape copy];
        _dataType = dataType;
    }

    return self;
}

- (BOOL)isEqualToDescriptor:(ETCoreMLMultiArrayDescriptor *)other {
    if (other == self) {
        return YES;
    }

    if (other == nil) {
        return NO;
    }

    return self.dataType == other.dataType && [self.shape isEqualToArray:other.shape];
}

- (BOOL)isEqual:(id)object {
    if (object == self) {
        return YES;
    }

    if (![object isKindOfClass:self.class]) {
        return NO;
    }

    return [self isEqualToDescriptor:(ETCoreMLMultiArrayDescriptor *)object];
}

- (NSUInteger)hash {
    // Fold every dimension into the hash. NSArray's own -hash is not guaranteed to depend on the
    // element values (in practice it reflects only the element count), which would make every
    // same-rank shape with the same data type collide.
    NSUInteger result = (NSUInteger)self.dataType;
    for (NSNumber *dimension in self.shape) {
        result = result * 31 + dimension.hash;
    }

    return result;
}

- (instancetype)copyWithZone:(NSZone *)zone {
    return [[ETCoreMLMultiArrayDescriptor allocWithZone:zone] initWithShape:self.shape
                                                                   dataType:self.dataType];
}

@end
71
namespace {

using namespace executorchcoreml;

/// Returns the size in bytes of one element of the given multi-array data type.
///
/// @param data_type The multi-array data type.
/// @return The element size in bytes, or 0 when `data_type` is not Float16, Float32, Int32 or Float64.
size_t get_number_of_bytes(MLMultiArrayDataType data_type) {
    switch (data_type) {
        case MLMultiArrayDataTypeFloat16: {
            return 2;
        }
        case MLMultiArrayDataTypeFloat32: {
            return 4;
        }
        case MLMultiArrayDataTypeInt32: {
            return 4;
        }
        case MLMultiArrayDataTypeFloat64: {
            return 8;
        }
        default: {
            return 0;
        }
    }
}

/// Computes contiguous row-major strides (in elements) for `shape`.
///
/// The last dimension has stride 1 and every earlier dimension's stride is the product of all
/// later dimensions, e.g. {2, 3, 4} -> {12, 4, 1}. An empty shape yields empty strides.
std::vector<size_t> calculate_strides(const std::vector<size_t>& shape) {
    std::vector<size_t> strides(shape.size(), 1);
    size_t product = 1;
    for (size_t i = shape.size(); i > 0; i--) {
        strides[i - 1] = product;
        product *= shape[i - 1];
    }

    return strides;
}

/// Creates a contiguous `MLMultiArray` whose backing storage is recycled through `cache`.
///
/// If `cache` holds storage for the same shape and data type, that storage is removed from the
/// cache and reused. Otherwise, new zero-filled storage is allocated. When the returned
/// multi-array is deallocated, its storage is put back into `cache` (if the cache is still alive)
/// so that a later call with the same shape and data type can reuse it.
///
/// @param shape The multi-array shape.
/// @param dataType The element data type of the multi-array.
/// @param cache The cache of backing storages, keyed by shape and data type.
/// @param error On failure, set to an error describing the failure.
/// @return The multi-array, or nil on failure (including an unsupported `dataType`).
MLMultiArray * _Nullable make_ml_multi_array(const std::vector<size_t>& shape,
                                             MLMultiArrayDataType dataType,
                                             NSCache<ETCoreMLMultiArrayDescriptor *, NSMutableData *> *cache,
                                             NSError * __autoreleasing *error) {
    const size_t element_size = get_number_of_bytes(dataType);
    if (element_size == 0) {
        // Without a known element size the storage would be zero-length, and any later copy
        // into the multi-array would write out of bounds.
        if (error) {
            NSString *description = [NSString stringWithFormat:@"Unsupported multi-array data type = %ld.", (long)dataType];
            *error = [NSError errorWithDomain:NSCocoaErrorDomain
                                         code:NSFeatureUnsupportedError
                                     userInfo:@{NSLocalizedDescriptionKey: description}];
        }
        return nil;
    }

    ETCoreMLMultiArrayDescriptor *descriptor = [[ETCoreMLMultiArrayDescriptor alloc] initWithShape:to_array(shape)
                                                                                          dataType:dataType];
    // Check the cache first, otherwise allocate a new backing storage.
    NSMutableData *backing_storage = [cache objectForKey:descriptor];
    if (backing_storage) {
        // Remove it so a subsequent lookup does not hand out the same storage again.
        [cache removeObjectForKey:descriptor];
    } else {
        // Seed with `size_t{1}`: std::accumulate accumulates in the type of its initial value,
        // so an `int` seed would truncate the element count of large tensors.
        const size_t n = std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<size_t>{});
        backing_storage = [[NSMutableData alloc] initWithLength:n * element_size];
    }

    __weak NSCache<ETCoreMLMultiArrayDescriptor *, NSMutableData *> *weakCache = cache;
    // Add the storage back to the cache when the multi-array gets deallocated, so the next
    // prediction can reuse it. The deallocator block retains `backing_storage`, keeping the bytes
    // alive as long as the multi-array; the cache itself is captured weakly.
    MLMultiArray *result = [[MLMultiArray alloc] initWithDataPointer:backing_storage.mutableBytes
                                                               shape:descriptor.shape
                                                            dataType:descriptor.dataType
                                                             strides:to_array(calculate_strides(shape))
                                                         deallocator:^(void * _Nonnull bytes) {[weakCache setObject:backing_storage forKey:descriptor];}
                                                               error:error];

    return result;
}

/// Maps every feature name to its multi-array constraint.
///
/// Features without a multi-array constraint are left out, because assigning nil through
/// keyed subscripting on a mutable dictionary stores nothing.
NSDictionary<NSString *, MLMultiArrayConstraint *> *
get_multi_array_constraints_by_name(NSDictionary<NSString *, MLFeatureDescription *> *feature_descriptions) {
    NSMutableDictionary<NSString *, MLMultiArrayConstraint *> *result = [NSMutableDictionary dictionaryWithCapacity:feature_descriptions.count];
    [feature_descriptions enumerateKeysAndObjectsUsingBlock:^(NSString *key, MLFeatureDescription *description, BOOL * _Nonnull stop) {
        result[key] = description.multiArrayConstraint;
    }];

    return result;
}

/// Returns the multi-array constraints of the model inputs, keyed by input name.
NSDictionary<NSString *, MLMultiArrayConstraint *> *get_multi_array_input_constraints_by_name(MLModelDescription *description) {
    return get_multi_array_constraints_by_name(description.inputDescriptionsByName);
}

/// Returns the multi-array constraints of the model outputs, keyed by output name.
NSDictionary<NSString *, MLMultiArrayConstraint *> *get_multi_array_output_constraints_by_name(MLModelDescription *description) {
    return get_multi_array_constraints_by_name(description.outputDescriptionsByName);
}

#if MODEL_STATE_IS_SUPPORTED
/// Sets every byte of the state buffer named `feature_name` in `state` to zero.
API_AVAILABLE(macos(15.0), ios(18.0), tvos(18.0), watchos(11.0))
void reset_state_for_feature_name(NSString *feature_name, MLState *state) {
    [state getMultiArrayForStateNamed:feature_name handler:^(MLMultiArray *buffer) {
        [buffer getMutableBytesWithHandler:^(void *mutableBytes, NSInteger size, NSArray<NSNumber *> * __unused strides) {
            uint8_t *start = reinterpret_cast<uint8_t *>(mutableBytes);
            uint8_t *end = start + size;
            std::fill(start, end, uint8_t(0));
        }];
    }];
}
#endif

} // namespace
174
175#pragma mark - ETCoreMLModel
/// Private state of `ETCoreMLModel`.
@interface ETCoreMLModel ()

/// Multi-array constraints of the model inputs, keyed by input name.
@property (nonatomic, copy, readonly) NSDictionary<NSString *, MLMultiArrayConstraint *> *inputConstraintsByName;

/// Multi-array constraints of the model outputs, keyed by output name.
@property (nonatomic, copy, readonly) NSDictionary<NSString *, MLMultiArrayConstraint *> *outputConstraintsByName;

/// Reusable backing storages for multi-arrays that cannot alias the caller's memory,
/// keyed by shape and data type.
@property (nonatomic, strong, readonly) NSCache<ETCoreMLMultiArrayDescriptor *, NSMutableData *> *cache;

@end
183
184
@implementation ETCoreMLModel

/// Loads the Core ML model stored in `asset`.
///
/// @param asset The asset whose content URL holds the compiled model.
/// @param configuration The configuration used to load the model.
/// @param orderedInputNames The model input names, in argument order.
/// @param orderedOutputNames The model output names, in argument order.
/// @param error On failure, set to an error describing the failure.
/// @return The model, or nil if the asset could not be kept alive or the model failed to load.
- (nullable instancetype)initWithAsset:(ETCoreMLAsset *)asset
                         configuration:(MLModelConfiguration *)configuration
                     orderedInputNames:(NSOrderedSet<NSString *> *)orderedInputNames
                    orderedOutputNames:(NSOrderedSet<NSString *> *)orderedOutputNames
                                 error:(NSError * __autoreleasing *)error {
    if (![asset keepAliveAndReturnError:error]) {
        return nil;
    }

    MLModel *mlModel = [MLModel modelWithContentsOfURL:asset.contentURL
                                         configuration:configuration
                                                 error:error];
    if (!mlModel) {
        return nil;
    }

    self = [super init];
    if (self) {
        _mlModel = mlModel;
        _asset = asset;
        _orderedInputNames = [orderedInputNames copy];
        _orderedOutputNames = [orderedOutputNames copy];
        _cache = [[NSCache alloc] init];
        _inputConstraintsByName = get_multi_array_input_constraints_by_name(mlModel.modelDescription);
        _outputConstraintsByName = get_multi_array_output_constraints_by_name(mlModel.modelDescription);
#if MODEL_STATE_IS_SUPPORTED
        if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) {
            // Only models that declare state get an MLState; otherwise `state` stays nil.
            _state = mlModel.modelDescription.stateDescriptionsByName.count > 0 ? [_mlModel newState] : nil;
        }
#endif
    }

    return self;
}

/// The identifier of the backing asset.
- (NSString *)identifier {
    return self.asset.identifier;
}

/// Wraps each argument in `args` into an `MLMultiArray` that satisfies the model's constraints.
///
/// When an argument's data type equals its constraint's data type, the multi-array points directly
/// at the argument's memory and never frees it. Otherwise, a multi-array with the constraint's
/// data type is created with storage from `self.cache`, and, if `copyData` is YES, the argument
/// is copied into it through `MultiArray::copy`.
///
/// @param args The arguments, in the same order as `argNames`.
/// @param argNames The argument names used to look up constraints.
/// @param argConstraintsByName The multi-array constraints keyed by argument name.
/// @param copyData Whether argument contents are copied into newly created storage.
/// @param error On failure, set to an error describing the failure.
/// @return The multi-arrays in argument order, or nil on failure.
- (nullable NSArray<MLMultiArray *> *)prepareArgs:(const std::vector<executorchcoreml::MultiArray>&)args
                                         argNames:(NSOrderedSet<NSString *> *)argNames
                             argConstraintsByName:(NSDictionary<NSString *, MLMultiArrayConstraint *> *)argConstraintsByName
                                         copyData:(const BOOL)copyData
                                            error:(NSError * __autoreleasing *)error {
    NSEnumerator *nameEnumerator = [argNames objectEnumerator];
    NSMutableArray<MLMultiArray *> *result = [NSMutableArray arrayWithCapacity:args.size()];
    for (const auto& arg : args) {
        NSString *argName = [nameEnumerator nextObject];
        MLMultiArrayConstraint *constraint = argConstraintsByName[argName];
        const auto& layout = arg.layout();
        auto dataType = to_ml_multiarray_data_type(layout.dataType());
        MLMultiArray *multiArrayArg = nil;
        // Copying is only meaningful when the multi-array has its own storage.
        BOOL shouldCopyData = NO;
        if (dataType == constraint.dataType) {
            // We can use the same data storage.
            multiArrayArg = [[MLMultiArray alloc] initWithDataPointer:arg.data()
                                                                shape:to_array(layout.shape())
                                                             dataType:constraint.dataType
                                                              strides:to_array(layout.strides())
                                                          deallocator:^(void * _Nonnull bytes) {}
                                                                error:error];
        } else {
            // We can't use the same data storage, data types are not the same.
            multiArrayArg = ::make_ml_multi_array(layout.shape(), constraint.dataType, self.cache, error);
            shouldCopyData = copyData;
        }

        if (!multiArrayArg) {
            return nil;
        }

        if (shouldCopyData) {
            [multiArrayArg getMutableBytesWithHandler:^(void *_Nonnull mutableBytes,
                                                        NSInteger __unused size,
                                                        NSArray<NSNumber *> *strides) {
                MultiArray buffer(mutableBytes, MultiArray::MemoryLayout(to_multiarray_data_type(constraint.dataType).value(),
                                                                         layout.shape(),
                                                                         to_vector<ssize_t>(strides)));
                arg.copy(buffer);
            }];
        }

        [result addObject:multiArrayArg];
    }

    return result;
}

/// Prepares the model inputs, copying data whenever new storage had to be created.
- (nullable NSArray<MLMultiArray *> *)prepareInputs:(const std::vector<executorchcoreml::MultiArray>&)inputs
                                              error:(NSError * __autoreleasing *)error {
    return [self prepareArgs:inputs
                    argNames:self.orderedInputNames
        argConstraintsByName:self.inputConstraintsByName
                    copyData:YES
                       error:error];
}

/// Prepares multi-arrays to back the model outputs; no data is copied into them.
- (nullable NSArray<MLMultiArray *> *)prepareOutputBackings:(const std::vector<executorchcoreml::MultiArray>&)outputs
                                                      error:(NSError * __autoreleasing *)error {
    return [self prepareArgs:outputs
                    argNames:self.orderedOutputNames
        argConstraintsByName:self.outputConstraintsByName
                    copyData:NO
                       error:error];
}

/// Runs a prediction, using the model state when one was created and the OS supports it.
- (nullable id<MLFeatureProvider>)predictionFromFeatures:(id<MLFeatureProvider>)input
                                                 options:(MLPredictionOptions *)options
                                                   error:(NSError **)error {
#if MODEL_STATE_IS_SUPPORTED
    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) {
        if (self.state != nil) {
            return [self.mlModel predictionFromFeatures:input
                                             usingState:(MLState *)self.state
                                                options:options
                                                  error:error];
        }
    }
#endif

    id<MLFeatureProvider> result = [self.mlModel predictionFromFeatures:input
                                                                options:options
                                                                  error:error];

    return result;
}

/// Prewarms the model and then zeroes every state buffer, so that prewarming leaves no residue
/// in the model state.
///
/// @param error On failure, set to the prewarm error.
/// @return YES if prewarming succeeded.
- (BOOL)prewarmAndReturnError:(NSError* __autoreleasing*)error {
    NSError *localError = nil;
    // Collect the failure locally so it can be both logged and reported to the caller.
    BOOL result = [self.mlModel prewarmUsingState:self.state error:&localError];
    if (!result) {
        ETCoreMLLogError(localError,
                         "%@: Failed to prewarm model with identifier = %@",
                         NSStringFromClass(self.class),
                         self.identifier);
    }

#if MODEL_STATE_IS_SUPPORTED
    if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, watchOS 11.0, *)) {
        NSDictionary<NSString *, MLFeatureDescription *> *stateDescriptions = self.mlModel.modelDescription.stateDescriptionsByName;
        [stateDescriptions enumerateKeysAndObjectsUsingBlock:^(NSString *featureName, MLFeatureDescription * __unused obj, BOOL * __unused stop) {
            reset_state_for_feature_name(featureName, (MLState *) self.state);
        }];
    }
#endif

    // Cocoa convention: only write the error out-parameter on failure.
    if (!result && error) {
        *error = localError;
    }

    return result;
}

@end
344