#import <XCTest/XCTest.h>

#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/script.h>

@interface TestAppTests : XCTestCase

@end

@implementation TestAppTests {
}

#pragma mark - CoreML

/// Loads `model_coreml.ptl` from the test bundle, runs it on an all-ones
/// 1x3x224x224 float tensor, and expects an output tensor with 1000 elements.
- (void)testCoreML {
  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model_coreml"
                                                                          ofType:@"ptl"];
  // A nil path would turn into a NULL C string below (messaging nil returns 0),
  // which the C++ loader cannot handle; fail this test cleanly instead.
  if (modelPath == nil) {
    XCTFail(@"model_coreml.ptl not found in the test bundle.");
    return;
  }
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  c10::InferenceMode mode;
  auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
  auto outputTensor = module.forward({input}).toTensor();
  XCTAssertTrue(outputTensor.numel() == 1000);
}

#pragma mark - Helpers

/// Runs the bundled model `<modelName>.ptl` and its on-the-fly counterpart
/// `<modelName>_temp.ptl` through -runModel:.
///
/// Takes an argument, so XCTest does not run it as a test case by itself; the
/// per-category `test...` methods below call it.
/// @param modelName Resource name of the model in the test bundle, without extension.
- (void)testModel:(NSString*)modelName {
  NSBundle* bundle = [NSBundle bundleForClass:[self class]];

  NSString* modelPath = [bundle pathForResource:modelName ofType:@"ptl"];
  XCTAssertNotNil(modelPath, @"Model not found. See https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test.");
  // XCTAssertNotNil does not stop execution; never hand a nil path to the loader.
  if (modelPath != nil) {
    [self runModel:modelPath];
  }

  // Model generated on the fly. It must use a resource name distinct from the
  // bundled model: a plain @"%@" format resolves to the same file as `modelPath`
  // and silently re-tests the bundled model.
  // NOTE(review): assumes the generator emits `<name>_temp.ptl` — confirm
  // against test/mobile/model_test.
  NSString* onTheFlyModelName = [NSString stringWithFormat:@"%@_temp", modelName];
  NSString* onTheFlyModelPath = [bundle pathForResource:onTheFlyModelName ofType:@"ptl"];
  XCTAssertNotNil(onTheFlyModelPath, @"On-the-fly model not found. Follow https://github.com/pytorch/pytorch/tree/master/test/mobile/model_test#diagnose-failed-test to generate them and run the setup.rb script again.");
  if (onTheFlyModelPath != nil) {
    [self runModel:onTheFlyModelPath];
  }
}

/// Loads the lite-interpreter model at `modelPath` under inference mode and
/// runs `forward` once.
///
/// If the model exposes `get_all_bundled_inputs`, the first bundled input is used;
/// otherwise `forward` is called with no arguments.
/// @param modelPath Filesystem path to a `.ptl` model. Must not be nil.
- (void)runModel:(NSString*)modelPath {
  c10::InferenceMode mode;
  auto module = torch::jit::_load_for_mobile(modelPath.UTF8String);
  if (module.find_method("get_all_bundled_inputs")) {
    c10::IValue bundledInputs = module.run_method("get_all_bundled_inputs");
    c10::List<at::IValue> allInputs = bundledInputs.toList();
    // Indexing an empty list would be undefined behavior; report it instead.
    if (allInputs.empty()) {
      XCTFail(@"get_all_bundled_inputs returned no inputs.");
      return;
    }
    // Run with the first bundled input only; each bundled input is a tuple of
    // forward() arguments.
    std::vector<at::IValue> firstInput = allInputs.get(0).toTupleRef().elements();
    XCTAssertNoThrow(module.forward(firstInput));
  } else {
    XCTAssertNoThrow(module.forward({}));
  }
}

#pragma mark - Placeholder

// TODO remove this once updated test script
/// Always passes; kept only until the test script stops referencing it.
- (void)testLiteInterpreter {
  XCTAssertTrue(true);
}

#pragma mark - Per-category model tests

// Each test below delegates to -testModel: with the bundled model's resource name.

- (void)testMobileNetV2 {
  [self testModel:@"mobilenet_v2"];
}

- (void)testPointwiseOps {
  [self testModel:@"pointwise_ops"];
}

- (void)testReductionOps {
  [self testModel:@"reduction_ops"];
}

- (void)testComparisonOps {
  [self testModel:@"comparison_ops"];
}

- (void)testOtherMathOps {
  [self testModel:@"other_math_ops"];
}

- (void)testSpectralOps {
  [self testModel:@"spectral_ops"];
}

- (void)testBlasLapackOps {
  [self testModel:@"blas_lapack_ops"];
}

- (void)testSamplingOps {
  [self testModel:@"sampling_ops"];
}

- (void)testTensorOps {
  [self testModel:@"tensor_general_ops"];
}

- (void)testTensorCreationOps {
  [self testModel:@"tensor_creation_ops"];
}

- (void)testTensorIndexingOps {
  [self testModel:@"tensor_indexing_ops"];
}

- (void)testTensorTypingOps {
  [self testModel:@"tensor_typing_ops"];
}

- (void)testTensorViewOps {
  [self testModel:@"tensor_view_ops"];
}

- (void)testConvolutionOps {
  [self testModel:@"convolution_ops"];
}

- (void)testPoolingOps {
  [self testModel:@"pooling_ops"];
}

- (void)testPaddingOps {
  [self testModel:@"padding_ops"];
}

- (void)testActivationOps {
  [self testModel:@"activation_ops"];
}

- (void)testNormalizationOps {
  [self testModel:@"normalization_ops"];
}

- (void)testRecurrentOps {
  [self testModel:@"recurrent_ops"];
}

- (void)testTransformerOps {
  [self testModel:@"transformer_ops"];
}

- (void)testLinearOps {
  [self testModel:@"linear_ops"];
}

- (void)testDropoutOps {
  [self testModel:@"dropout_ops"];
}

- (void)testSparseOps {
  [self testModel:@"sparse_ops"];
}

- (void)testDistanceFunctionOps {
  [self testModel:@"distance_function_ops"];
}

- (void)testLossFunctionOps {
  [self testModel:@"loss_function_ops"];
}

- (void)testVisionFunctionOps {
  [self testModel:@"vision_function_ops"];
}

- (void)testShuffleOps {
  [self testModel:@"shuffle_ops"];
}

- (void)testNNUtilsOps {
  [self testModel:@"nn_utils_ops"];
}

- (void)testQuantOps {
  [self testModel:@"general_quant_ops"];
}

- (void)testDynamicQuantOps {
  [self testModel:@"dynamic_quant_ops"];
}

- (void)testStaticQuantOps {
  [self testModel:@"static_quant_ops"];
}

- (void)testFusedQuantOps {
  [self testModel:@"fused_quant_ops"];
}

- (void)testTorchScriptBuiltinQuantOps {
  [self testModel:@"torchscript_builtin_ops"];
}

- (void)testTorchScriptCollectionQuantOps {
  [self testModel:@"torchscript_collection_ops"];
}

@end