xref: /aosp_15_r20/external/pytorch/ios/TestApp/TestAppTests/TestLiteInterpreter.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#import <XCTest/XCTest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker#include <torch/csrc/jit/mobile/import.h>
4*da0073e9SAndroid Build Coastguard Worker#include <torch/csrc/jit/mobile/module.h>
5*da0073e9SAndroid Build Coastguard Worker#include <torch/script.h>
6*da0073e9SAndroid Build Coastguard Worker
/// XCTest suite for the PyTorch iOS test app: loads lite-interpreter (`.ptl`)
/// models from the test bundle and runs their `forward` method.
@interface TestAppTests : XCTestCase
@end
10*da0073e9SAndroid Build Coastguard Worker
@implementation TestAppTests

/// Runs the bundled `model_coreml.ptl` on a ones tensor of shape {1, 3, 224, 224}
/// and checks that the output tensor has exactly 1000 elements.
- (void)testCoreML {
  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_coreml"
                                                                         ofType:@"ptl"];
  XCTAssertNotNil(modelPath, @"model_coreml.ptl not found in the test bundle.");
  // XCTAssertNotNil does not stop the test; a nil path would become a NULL C
  // string passed to _load_for_mobile, so bail out instead of crashing.
  if (!modelPath) {
    return;
  }
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  c10::InferenceMode mode;
  auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
  auto outputTensor = module.forward({input}).toTensor();
  // XCTAssertEqual (rather than XCTAssertTrue on a comparison) reports the actual count.
  XCTAssertEqual(outputTensor.numel(), (int64_t)1000);
}

/// Loads `<modelName>.ptl` from the test bundle and runs it, then does the same
/// for the on-the-fly generated variant `<modelName>_temp.ptl`.
///
/// XCTest does not run this as a test case by itself because it takes an
/// argument; the `test...` methods below forward to it with a model name.
/// @param modelName Resource name (without extension) of the model in the test bundle.
- (void)testModel:(NSString*)modelName {
  NSBundle* bundle = [NSBundle bundleForClass:[self class]];

  NSString* modelPath = [bundle pathForResource:modelName ofType:@"ptl"];
  XCTAssertNotNil(modelPath, @"Model not found. See https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test.");
  // XCTAssertNotNil does not abort the test; never hand a nil path to the loader.
  if (modelPath) {
    [self runModel:modelPath];
  }

  // Model generated on the fly. It must use a distinct resource name: formatting
  // with a bare "%@" reproduced `modelName` and silently re-ran the model above.
  // NOTE(review): "_temp" matches the upstream generator's naming; confirm
  // against test/mobile/model_test/gen_test_model.py.
  NSString* onTheFlyModelName = [NSString stringWithFormat:@"%@_temp", modelName];
  NSString* onTheFlyModelPath = [bundle pathForResource:onTheFlyModelName ofType:@"ptl"];
  XCTAssertNotNil(onTheFlyModelPath, @"On-the-fly model not found. Follow https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test to generate them and run the setup.rb script again.");
  if (onTheFlyModelPath) {
    [self runModel:onTheFlyModelPath];
  }
}

/// Loads the lite-interpreter model at `modelPath` and calls `forward` once,
/// asserting that it does not throw.
///
/// If the model defines `get_all_bundled_inputs`, `forward` is called with the
/// elements of the first bundled input tuple; otherwise it is called with no
/// arguments.
/// @param modelPath Filesystem path of a `.ptl` model; must not be nil.
- (void)runModel:(NSString*)modelPath {
  c10::InferenceMode mode;
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);

  if (!module.find_method("get_all_bundled_inputs")) {
    XCTAssertNoThrow(module.forward({}));
    return;
  }

  c10::List<at::IValue> all_inputs = module.run_method("get_all_bundled_inputs").toList();
  // Indexing an empty list would be undefined behaviour; report it as a failure.
  if (all_inputs.empty()) {
    XCTFail(@"get_all_bundled_inputs returned an empty list.");
    return;
  }

  // Each bundled input is a tuple whose elements are the forward() arguments;
  // only the first one is exercised, so only that one is unpacked.
  at::IValue first_bundled = all_inputs.get(0);
  std::vector<at::IValue> first_input = first_bundled.toTupleRef().elements();
  XCTAssertNoThrow(module.forward(first_input));
}

// TODO remove this once updated test script
- (void)testLiteInterpreter {
  XCTAssertTrue(true);
}

- (void)testMobileNetV2 {
  [self testModel:@"mobilenet_v2"];
}

- (void)testPointwiseOps {
  [self testModel:@"pointwise_ops"];
}

- (void)testReductionOps {
  [self testModel:@"reduction_ops"];
}

- (void)testComparisonOps {
  [self testModel:@"comparison_ops"];
}

- (void)testOtherMathOps {
  [self testModel:@"other_math_ops"];
}

- (void)testSpectralOps {
  [self testModel:@"spectral_ops"];
}

- (void)testBlasLapackOps {
  [self testModel:@"blas_lapack_ops"];
}

- (void)testSamplingOps {
  [self testModel:@"sampling_ops"];
}

- (void)testTensorOps {
  [self testModel:@"tensor_general_ops"];
}

- (void)testTensorCreationOps {
  [self testModel:@"tensor_creation_ops"];
}

- (void)testTensorIndexingOps {
  [self testModel:@"tensor_indexing_ops"];
}

- (void)testTensorTypingOps {
  [self testModel:@"tensor_typing_ops"];
}

- (void)testTensorViewOps {
  [self testModel:@"tensor_view_ops"];
}

- (void)testConvolutionOps {
  [self testModel:@"convolution_ops"];
}

- (void)testPoolingOps {
  [self testModel:@"pooling_ops"];
}

- (void)testPaddingOps {
  [self testModel:@"padding_ops"];
}

- (void)testActivationOps {
  [self testModel:@"activation_ops"];
}

- (void)testNormalizationOps {
  [self testModel:@"normalization_ops"];
}

- (void)testRecurrentOps {
  [self testModel:@"recurrent_ops"];
}

- (void)testTransformerOps {
  [self testModel:@"transformer_ops"];
}

- (void)testLinearOps {
  [self testModel:@"linear_ops"];
}

- (void)testDropoutOps {
  [self testModel:@"dropout_ops"];
}

- (void)testSparseOps {
  [self testModel:@"sparse_ops"];
}

- (void)testDistanceFunctionOps {
  [self testModel:@"distance_function_ops"];
}

- (void)testLossFunctionOps {
  [self testModel:@"loss_function_ops"];
}

- (void)testVisionFunctionOps {
  [self testModel:@"vision_function_ops"];
}

- (void)testShuffleOps {
  [self testModel:@"shuffle_ops"];
}

- (void)testNNUtilsOps {
  [self testModel:@"nn_utils_ops"];
}

- (void)testQuantOps {
  [self testModel:@"general_quant_ops"];
}

- (void)testDynamicQuantOps {
  [self testModel:@"dynamic_quant_ops"];
}

- (void)testStaticQuantOps {
  [self testModel:@"static_quant_ops"];
}

- (void)testFusedQuantOps {
  [self testModel:@"fused_quant_ops"];
}

// NOTE(review): despite "Quant" in the name, this runs the torchscript_builtin_ops
// model; the name is kept so existing test filters keep matching.
- (void)testTorchScriptBuiltinQuantOps {
  [self testModel:@"torchscript_builtin_ops"];
}

// NOTE(review): despite "Quant" in the name, this runs the torchscript_collection_ops
// model; the name is kept so existing test filters keep matching.
- (void)testTorchScriptCollectionQuantOps {
  [self testModel:@"torchscript_collection_ops"];
}

@end
198