/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#import "ResourceTestCase.h"

#import <executorch/extension/module/module.h>
#import <executorch/extension/tensor/tensor.h>

using namespace ::executorch::extension;
using namespace ::executorch::runtime;

// Asserts that `value__` (a Result-like object exposing `error()` and `ok()`)
// reports Error::Ok. If it does not, the assertion failure is recorded and the
// enclosing void function or block returns immediately.
// `value__` is evaluated more than once, so pass a plain variable rather than
// an expression with side effects.
#define ASSERT_OK_OR_RETURN(value__)              \
  ({                                              \
    XCTAssertEqual((value__).error(), Error::Ok); \
    if (!(value__).ok()) {                        \
      return;                                     \
    }                                             \
  })

@interface GenericTests : ResourceTestCase
@end

@implementation GenericTests

// Directories to search for test resources. Presumably consumed by
// ResourceTestCase (not visible here) — confirm against ResourceTestCase.h.
+ (NSArray<NSString *> *)directories {
  return @[
    @"Resources",
    @"aatp/data",  // AWS Farm devices look for resources here.
  ];
}

// Resource-matching predicates keyed by resource name: "model" accepts any
// file whose name ends in ".pte".
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
  return @{
    @"model" : ^BOOL(NSString *filename) {
      return [filename hasSuffix:@".pte"];
    },
  };
}

// Builds the benchmark tests for one set of matched resources.
//
// `resources[@"model"]` is used as the path of the program to load. Returns
// two tests keyed by name:
//   - "load":    measures constructing a Module and loading its "forward"
//                method (clock + memory metrics).
//   - "forward": fills every tensor input of "forward" with ones of the
//                declared sizes and scalar type, then measures 20 forward()
//                calls (clock + memory metrics).
//
// NOTE(review): this is a class method, so `self` inside the XCTAssert*/XCTFail
// macros below is the GenericTests class, not `testCase`; confirm XCTest
// attributes those failures to the running test.
+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:
    (NSDictionary<NSString *, NSString *> *)resources {
  NSString *modelPath = resources[@"model"];
  return @{
    @"load" : ^(XCTestCase *testCase) {
      [testCase
          measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ]
                       block:^{
                         XCTAssertEqual(
                             Module(modelPath.UTF8String).load_forward(),
                             Error::Ok);
                       }];
    },
    @"forward" : ^(XCTestCase *testCase) {
      // `__block` so the measure block below refers to this move-only
      // std::unique_ptr instead of trying to capture a copy of it.
      auto __block module = std::make_unique<Module>(modelPath.UTF8String);

      const auto method_meta = module->method_meta("forward");
      ASSERT_OK_OR_RETURN(method_meta);

      const auto num_inputs = method_meta->num_inputs();
      // Compare against a size_t zero to avoid a signed/unsigned comparison
      // inside the XCTest macro.
      XCTAssertGreaterThan(num_inputs, size_t{0});

      // Owns the input tensors until this test block ends; presumably
      // set_input does not take ownership, so they must outlive forward().
      std::vector<TensorPtr> tensors;
      tensors.reserve(num_inputs);

      // size_t index matches num_inputs and the index parameters of
      // MethodMeta/Module, avoiding a signed/unsigned comparison.
      for (size_t index = 0; index < num_inputs; ++index) {
        const auto input_tag = method_meta->input_tag(index);
        ASSERT_OK_OR_RETURN(input_tag);

        switch (*input_tag) {
          case Tag::Tensor: {
            const auto tensor_meta = method_meta->input_tensor_meta(index);
            ASSERT_OK_OR_RETURN(tensor_meta);

            const auto sizes = tensor_meta->sizes();
            tensors.emplace_back(
                ones({sizes.begin(), sizes.end()}, tensor_meta->scalar_type()));
            XCTAssertEqual(module->set_input(tensors.back(), index), Error::Ok);
          } break;
          default:
            // Tag is a scoped enum, so cast it explicitly for the varargs
            // format. Stop here: benchmarking forward() with this input left
            // unset would only measure a failing call.
            XCTFail("Unsupported tag %i at input %zu",
                    static_cast<int>(*input_tag),
                    index);
            return;
        }
      }

      XCTMeasureOptions *options = [[XCTMeasureOptions alloc] init];
      options.iterationCount = 20;
      [testCase
          measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ]
                     options:options
                       block:^{
                         XCTAssertEqual(module->forward().error(), Error::Ok);
                       }];
    },
  };
}

@end