// xref: /aosp_15_r20/external/pytorch/ios/TestApp/TestAppTests/TestFullJIT.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#import <XCTest/XCTest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker#include <torch/script.h>
4*da0073e9SAndroid Build Coastguard Worker
/// XCTest case for the iOS TestApp; its implementation exercises TorchScript
/// model loading and inference through the full JIT.
@interface TestAppTests : XCTestCase
@end
8*da0073e9SAndroid Build Coastguard Worker
@implementation TestAppTests

/// Smoke test for the full-JIT path: loads the TorchScript archive `model.pt`
/// from the test bundle, runs a single forward pass on an all-ones float tensor
/// of shape {1, 3, 224, 224}, and checks that the output holds 1000 elements.
///
/// A missing resource or a C++ exception raised by libtorch is recorded as a
/// test failure instead of being allowed to escape the test method.
- (void)testFullJIT {
  NSString* modelPath = [[NSBundle bundleForClass:[self class]] pathForResource:@"model"
                                                                         ofType:@"pt"];
  // pathForResource:ofType: returns nil when the resource is absent; messaging
  // nil for UTF8String yields NULL, which must never reach torch::jit::load.
  if (modelPath == nil) {
    XCTFail(@"model.pt was not found in the test bundle");
    return;
  }

  // libtorch signals failures (unreadable archive, non-tensor output, ...) with
  // C++ exceptions, so translate them into XCTest failures here.
  try {
    auto module = torch::jit::load(modelPath.UTF8String);
    c10::InferenceMode mode;
    auto input = torch::ones({1, 3, 224, 224}, at::kFloat);
    auto outputTensor = module.forward({input}).toTensor();
    // XCTAssertEqual reports both values on mismatch, unlike XCTAssertTrue(a == b).
    XCTAssertEqual(outputTensor.numel(), static_cast<int64_t>(1000));
  } catch (const c10::Error& error) {
    XCTFail(@"TorchScript load/forward failed: %s", error.what());
  } catch (...) {
    XCTFail(@"TorchScript load/forward threw an unexpected C++ exception");
  }
}

@end
23