/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#import "ResourceTestCase.h"

#import <CoreML/CoreML.h>

// XCTFail() reports the failure against `self`. The test blocks below are created inside a
// class method, where `self` is the CoreMLTests class object rather than the running test
// case, so failures are routed explicitly to the XCTestCase instance each block receives.
#define CoreMLTestFail(testCase, ...) _XCTPrimitiveFail(testCase, __VA_ARGS__)

/// Creates a multi-array shaped like @p feature's constraint with every element set to 1.
///
/// An Int32 constraint yields an Int32 array; any other constraint data type falls back to a
/// Double array.
///
/// @param feature A multi-array feature description; its multiArrayConstraint is read.
/// @param error Receives the MLMultiArray allocation error on failure.
/// @return The filled array, or nil if the array could not be allocated.
static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) {
  MLMultiArrayConstraint *constraint = feature.multiArrayConstraint;
  const BOOL isInt32 = constraint.dataType == MLMultiArrayDataTypeInt32;
  MLMultiArray *array = [[MLMultiArray alloc] initWithShape:constraint.shape
                                                   dataType:isInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
                                                      error:error];
  if (!array) {
    return nil;
  }
  NSNumber *one = isInt32 ? @1 : @1.0;
  // NSInteger matches the type of `count`; an `int` index could overflow on very large arrays.
  const NSInteger count = array.count;
  for (NSInteger index = 0; index < count; ++index) {
    array[index] = one;
  }
  return array;
}

/// Builds a placeholder feature value for every input of @p model.
///
/// Multi-array inputs get arrays of ones (see DummyMultiArrayForFeature); Int64, Double and
/// String inputs get 1, 1.0 and @"1" respectively. Inputs of any other feature type are not
/// synthesized and are absent from the returned dictionary.
///
/// @param model The loaded model whose input descriptions are read.
/// @param error Receives the MLMultiArray allocation error if a multi-array input fails.
/// @return The feature values keyed by input name, or nil if a multi-array could not be created.
static NSMutableDictionary *DummyInputsForModel(MLModel *model, NSError **error) {
  NSDictionary<NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription.inputDescriptionsByName;
  NSMutableDictionary<NSString *, MLFeatureValue *> *inputs =
      [NSMutableDictionary dictionaryWithCapacity:inputDescriptions.count];

  for (NSString *inputName in inputDescriptions) {
    MLFeatureDescription *feature = inputDescriptions[inputName];

    switch (feature.type) {
      case MLFeatureTypeMultiArray: {
        MLMultiArray *array = DummyMultiArrayForFeature(feature, error);
        if (!array) {
          // Bail out rather than handing nil to +featureValueWithMultiArray:, which expects a
          // non-nil array; *error was already filled in by MLMultiArray.
          return nil;
        }
        inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array];
        break;
      }
      case MLFeatureTypeInt64:
        inputs[inputName] = [MLFeatureValue featureValueWithInt64:1];
        break;
      case MLFeatureTypeDouble:
        inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0];
        break;
      case MLFeatureTypeString:
        inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"];
        break;
      default:
        // Other feature types (e.g. images, sequences, dictionaries) are not synthesized.
        break;
    }
  }
  return inputs;
}

/// Benchmarks Core ML prediction on model packages shipped as test resources.
///
/// NOTE(review): ResourceTestCase presumably scans +directories for files matching
/// +predicates and generates one test per entry of +dynamicTestsForResources: — confirm
/// against ResourceTestCase.
@interface CoreMLTests : ResourceTestCase
@end

@implementation CoreMLTests

/// Resource directories searched for models.
+ (NSArray<NSString *> *)directories {
  return @[ @"Resources" ];
}

/// Matches the "model" resource against any file name ending in ".mlpackage".
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
  return @{
    @"model" : ^BOOL(NSString *filename) {
      return [filename hasSuffix:@".mlpackage"];
    }
  };
}

/// Returns a "prediction" test that compiles and loads the model at resources[@"model"],
/// feeds it placeholder inputs, and measures clock time and memory of a single prediction.
+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:(NSDictionary<NSString *, NSString *> *)resources {
  NSString *modelPath = resources[@"model"];

  return @{
    @"prediction" : ^(XCTestCase *testCase) {
      // Cocoa convention: success is signalled by the return value; `error` is only
      // meaningful once a call has reported failure that way.
      NSError *error = nil;
      NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&error];
      if (!compiledModelURL) {
        CoreMLTestFail(testCase, @"Failed to compile model: %@", error.localizedDescription);
        return;
      }
      MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL error:&error];
      if (!model) {
        CoreMLTestFail(testCase, @"Failed to load model: %@", error.localizedDescription);
        return;
      }
      NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
      if (!inputs) {
        CoreMLTestFail(testCase, @"Failed to prepare inputs: %@", error.localizedDescription);
        return;
      }
      MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs
                                                                                                      error:&error];
      if (!featureProvider) {
        CoreMLTestFail(testCase, @"Failed to create input provider: %@", error.localizedDescription);
        return;
      }
      [testCase measureWithMetrics:@[ [[XCTClockMetric alloc] init], [[XCTMemoryMetric alloc] init] ]
                             block:^{
                               NSError *predictionError = nil;
                               id<MLFeatureProvider> prediction = [model predictionFromFeatures:featureProvider
                                                                                          error:&predictionError];
                               if (!prediction) {
                                 CoreMLTestFail(testCase, @"Prediction failed: %@", predictionError.localizedDescription);
                               }
                             }];
    }
  };
}

@end