// This is a simple predictor binary that loads a TorchScript CV model and runs
// a forward pass with the fixed input `torch::ones({1, 3, 224, 224})`.
// It's used as an end-to-end integration test for the custom mobile build.
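//
// Example invocation (the model filename is illustrative; the model's ops
// must be included in the custom build for the call to succeed):
//
//   ./predictor MobileNetV2.pt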

#include <iomanip> // std::setprecision / std::fixed
#include <iostream>
#include <string>

#include <c10/util/irange.h>
#include <torch/script.h>

namespace {

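// Scoped guard bundling the settings applied around every TorchScript call
// in this binary.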
struct MobileCallGuard {
  // Set InferenceMode for the inference-only use case.
  c10::InferenceMode guard;
  // Disable the graph optimizer to ensure the list of unused ops is not
  // changed for the custom mobile build.
  torch::jit::GraphOptimizerEnabledGuard no_optimizer_guard{false};
};

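// Loads the TorchScript module from `path` under the guard and switches it to
// eval mode, so layers such as dropout and batch norm use inference behavior.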
torch::jit::Module loadModel(const std::string& path) {
  MobileCallGuard guard;
  auto module = torch::jit::load(path);
  module.eval();
  return module;
}

} // namespace

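// Entry point: expects the model path as the only command-line argument.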
int main(int argc, const char* argv[]) {
  if (argc < 2) {
    std::cerr << "Usage: " << argv[0] << " <model_path>\n";
    return 1;
  }
  auto module = loadModel(argv[1]);
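  // Fixed all-ones input: a batch of one 3-channel 224x224 image (NCHW).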
  auto input = torch::ones({1, 3, 224, 224});
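  // Run the forward pass inside an immediately-invoked lambda so the guard is
  // scoped to just this call.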
  auto output = [&]() {
    MobileCallGuard guard;
    return module.forward({input}).toTensor();
  }();

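  // Print the first five output values with fixed three-decimal precision.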
  std::cout << std::setprecision(3) << std::fixed;
  for (const auto i : c10::irange(5)) {
    std::cout << output.data_ptr<float>()[i] << std::endl;
  }
  return 0;
}