xref: /aosp_15_r20/external/executorch/extension/benchmark/apple/Benchmark/Tests/CoreMLTests.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#import "ResourceTestCase.h"
10
11#import <CoreML/CoreML.h>
12
/// Builds a placeholder MLMultiArray for a multi-array model input.
///
/// The array takes its shape from the feature's multi-array constraint. It is
/// allocated as Int32 when the constraint asks for Int32 and as Double for every
/// other constraint data type; every element is set to one.
///
/// @param feature Description of the model input whose multiArrayConstraint
///                supplies the shape and data type.
/// @param error   Forwarded to -[MLMultiArray initWithShape:dataType:error:].
/// @return The filled array, or nil when the array could not be allocated.
static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) {
  MLMultiArrayConstraint *constraint = feature.multiArrayConstraint;
  // Resolve the element type once instead of re-reading it on every iteration.
  const BOOL wantsInt32 = constraint.dataType == MLMultiArrayDataTypeInt32;
  MLMultiArray *array = [[MLMultiArray alloc] initWithShape:constraint.shape
                                                   dataType:wantsInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
                                                      error:error];
  if (!array) {
    return nil;
  }
  NSNumber *fillValue = wantsInt32 ? @1 : @1.0;
  // array.count is an NSInteger, so index with the same type.
  const NSInteger elementCount = array.count;
  for (NSInteger index = 0; index < elementCount; ++index) {
    array[index] = fillValue;
  }
  return array;
}
22
/// Creates one placeholder feature value for every supported input of a model.
///
/// Multi-array inputs get an array produced by DummyMultiArrayForFeature, Int64
/// inputs get 1, Double inputs get 1.0 and String inputs get @"1". Inputs of any
/// other feature type are left out of the returned dictionary.
///
/// @param model The model whose input descriptions are enumerated.
/// @param error Passed through to DummyMultiArrayForFeature for multi-array inputs.
/// @return A dictionary keyed by input name, or nil when a multi-array input
///         could not be created.
static NSMutableDictionary *DummyInputsForModel(MLModel *model, NSError **error) {
  NSMutableDictionary<NSString *, MLFeatureValue *> *inputs = [NSMutableDictionary dictionary];
  NSDictionary<NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription.inputDescriptionsByName;

  for (NSString *inputName in inputDescriptions) {
    MLFeatureDescription *feature = inputDescriptions[inputName];

    switch (feature.type) {
      case MLFeatureTypeMultiArray: {
        MLMultiArray *array = DummyMultiArrayForFeature(feature, error);
        if (!array) {
          // +featureValueWithMultiArray: takes a nonnull array; stop here and
          // let the caller report the failure instead of passing nil along.
          return nil;
        }
        inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array];
        break;
      }
      case MLFeatureTypeInt64:
        inputs[inputName] = [MLFeatureValue featureValueWithInt64:1];
        break;
      case MLFeatureTypeDouble:
        inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0];
        break;
      case MLFeatureTypeString:
        inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"];
        break;
      default:
        // No placeholder is generated for the remaining feature types.
        break;
    }
  }
  return inputs;
}
51
/// Benchmark test case that measures CoreML prediction for an .mlpackage model.
/// NOTE(review): how ResourceTestCase consumes the class methods below is
/// defined outside this file — confirm against that class.
@interface CoreMLTests : ResourceTestCase
@end

@implementation CoreMLTests

/// Returns the directory names used for this test case's resources.
+ (NSArray<NSString *> *)directories {
  return @[@"Resources"];
}

/// Maps the resource key "model" to a filter accepting file names that end in
/// ".mlpackage".
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
  return @{ @"model" : ^BOOL(NSString *filename) {
    return [filename hasSuffix:@".mlpackage"];
  }};
}

/// Returns a single dynamic test, "prediction", for the model at
/// resources[@"model"]. The test compiles and loads the model, builds
/// placeholder inputs, and measures clock time and memory of
/// -predictionFromFeatures:error:. Every failing step is reported with XCTFail.
+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:(NSDictionary<NSString *, NSString *> *)resources {
  NSString *modelPath = resources[@"model"];

  return @{
    @"prediction" : ^(XCTestCase *testCase) {
      // Success of the framework calls is judged by their return value; the
      // NSError is only read once a call has actually failed.
      NSError *compileError = nil;
      NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&compileError];
      if (!compiledModelURL) {
        XCTFail(@"Failed to compile model: %@", compileError.localizedDescription);
        return;
      }
      // The compiled model is only needed for the duration of this test, so it
      // is removed (best effort) on every path after this point.
      void (^removeCompiledModel)(void) = ^{
        [NSFileManager.defaultManager removeItemAtURL:compiledModelURL error:nil];
      };

      NSError *loadError = nil;
      MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL error:&loadError];
      if (!model) {
        XCTFail(@"Failed to load model: %@", loadError.localizedDescription);
        removeCompiledModel();
        return;
      }

      // Treat either a reported error or a missing dictionary as failure.
      NSError *inputsError = nil;
      NSMutableDictionary *inputs = DummyInputsForModel(model, &inputsError);
      if (inputsError || !inputs) {
        XCTFail(@"Failed to prepare inputs: %@", inputsError.localizedDescription);
        removeCompiledModel();
        return;
      }

      NSError *providerError = nil;
      MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&providerError];
      if (!featureProvider) {
        XCTFail(@"Failed to create input provider: %@", providerError.localizedDescription);
        removeCompiledModel();
        return;
      }

      [testCase measureWithMetrics:@[[XCTClockMetric new], [XCTMemoryMetric new]]
                             block:^{
        NSError *predictionError = nil;
        id<MLFeatureProvider> prediction = [model predictionFromFeatures:featureProvider error:&predictionError];
        if (!prediction) {
          XCTFail(@"Prediction failed: %@", predictionError.localizedDescription);
        }
      }];
      removeCompiledModel();
    }
  };
}

@end
106