xref: /aosp_15_r20/external/executorch/extension/benchmark/apple/Benchmark/Tests/GenericTests.mm (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9#import "ResourceTestCase.h"
10
11#import <executorch/extension/module/module.h>
12#import <executorch/extension/tensor/tensor.h>
13
14using namespace ::executorch::extension;
15using namespace ::executorch::runtime;
16
// Records an XCTest failure and returns from the enclosing void function or
// block when `value__` is not ok. `value__` must expose ok() and error(), and
// error() must return ::executorch::runtime::Error (e.g. an executorch Result).
//
// The argument is evaluated exactly once and bound to a local reference, so a
// call expression passed as the argument is not run twice. The original
// argument text is kept in the failure message. do/while(0) makes the macro a
// single statement that requires a trailing semicolon, without relying on the
// GNU statement-expression extension.
#define ASSERT_OK_OR_RETURN(value__)                               \
  do {                                                             \
    const auto &assert_ok_value_ = (value__);                      \
    XCTAssertEqual(assert_ok_value_.error(),                       \
                   ::executorch::runtime::Error::Ok,               \
                   @"%s is not ok", #value__);                     \
    if (!assert_ok_value_.ok()) {                                  \
      return;                                                      \
    }                                                              \
  } while (0)
24
/// Benchmarks loading ExecuTorch models and running their `forward` method.
///
/// The concrete tests come from the blocks returned by
/// +dynamicTestsForResources:. How ResourceTestCase discovers resources and
/// invokes those blocks is defined outside this file.
@interface GenericTests : ResourceTestCase
@end

@implementation GenericTests

/// Directories that ResourceTestCase presumably searches for resource files.
+ (NSArray<NSString *> *)directories {
  return @[
    @"Resources",
    @"aatp/data", // AWS Farm devices look for resources here.
  ];
}

/// Maps each resource key to a filename predicate. A file qualifies as the
/// "model" resource when its name ends in ".pte".
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
  return @{
    @"model" : ^BOOL(NSString *filename) {
      return [filename hasSuffix:@".pte"];
    },
  };
}

/// Builds the benchmark tests for one set of matched resources.
///
/// @param resources Resource paths keyed like +predicates. Only @"model" is
///        read here; its path is handed to Module.
/// @return Test name -> block that receives the running XCTestCase:
///         - "load": measures constructing a Module and calling load_forward().
///         - "forward": loads the model, fills every tensor input with ones of
///           the declared sizes and scalar type, then measures 20 iterations
///           of forward().
///
/// NOTE(review): the XCTAssert*/XCTFail macros used below (including those
/// inside ASSERT_OK_OR_RETURN) expand against `self`. In this class method,
/// `self` is the GenericTests class object, not `testCase`. Confirm that
/// ResourceTestCase/XCTest attributes these failures to the running test.
+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:
    (NSDictionary<NSString *, NSString *> *)resources {
  NSString *modelPath = resources[@"model"];
  return @{
    @"load" : ^(XCTestCase *testCase) {
      [testCase
          measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ]
                       block:^{
                         XCTAssertEqual(
                             Module(modelPath.UTF8String).load_forward(),
                             Error::Ok);
                       }];
    },
    @"forward" : ^(XCTestCase *testCase) {
      // A std::unique_ptr cannot be copy-captured by a block, so it is
      // declared __block (captured by reference) for the measure block below.
      auto __block module = std::make_unique<Module>(modelPath.UTF8String);

      const auto method_meta = module->method_meta("forward");
      ASSERT_OK_OR_RETURN(method_meta);

      const auto num_inputs = method_meta->num_inputs();
      XCTAssertGreaterThan(num_inputs, 0);

      // Owns the input tensors so they outlive every forward() call measured
      // below.
      std::vector<TensorPtr> tensors;
      tensors.reserve(num_inputs);

      // size_t matches num_inputs and the index parameters passed to
      // input_tag / input_tensor_meta / set_input, so there is no
      // signed/unsigned comparison.
      for (size_t index = 0; index < num_inputs; ++index) {
        const auto input_tag = method_meta->input_tag(index);
        ASSERT_OK_OR_RETURN(input_tag);

        switch (*input_tag) {
          case Tag::Tensor: {
            const auto tensor_meta = method_meta->input_tensor_meta(index);
            ASSERT_OK_OR_RETURN(tensor_meta);

            const auto sizes = tensor_meta->sizes();
            tensors.emplace_back(
                ones({sizes.begin(), sizes.end()}, tensor_meta->scalar_type()));
            XCTAssertEqual(module->set_input(tensors.back(), index), Error::Ok);
          } break;
          default:
            // Stop here rather than benchmarking forward() with an input that
            // was never set (matches the early returns above).
            XCTFail("Unsupported tag %i at input %zu",
                    static_cast<int>(*input_tag), index);
            return;
        }
      }

      XCTMeasureOptions *options = [[XCTMeasureOptions alloc] init];
      options.iterationCount = 20;
      [testCase measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ]
                           options:options
                             block:^{
                               XCTAssertEqual(module->forward().error(), Error::Ok);
                             }];
    },
  };
}

@end
100