xref: /aosp_15_r20/external/pytorch/test/cpp/jit/torch_python_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <c10/util/Exception.h>
3 #include <torch/csrc/Export.h>
4 #include <torch/csrc/jit/api/module.h>
5 #include <torch/script.h>
6 
7 namespace torch {
8 namespace jit {
9 
// Visibility annotation placed on runJITCPPTests in this file.
// Non-MSVC compilers get TORCH_API; under MSVC the macro expands to nothing.
// NOTE(review): the MSVC exclusion is presumably about dllexport/dllimport
// handling for this test-only symbol on Windows — confirm before changing.
#if !defined(_MSC_VER)
#define JIT_TEST_API TORCH_API
#else
#define JIT_TEST_API /* intentionally empty on MSVC */
#endif
15 
16 namespace {
17 
// Reports whether the process appears to be running under Sandcastle, where
// the fixture files these tests load are not generated.
// True when SANDCASTLE is set to any value (an empty string counts), or when
// TW_JOB_USER is set to exactly "sandcastle"; false otherwise.
bool isSandcastle() {
  if (std::getenv("SANDCASTLE") != nullptr) {
    return true;
  }
  const char* const job_user = std::getenv("TW_JOB_USER");
  if (job_user == nullptr) {
    return false;
  }
  return std::string(job_user) == "sandcastle";
}
24 
testEvalModeForLoadedModule()25 void testEvalModeForLoadedModule() {
26   if (isSandcastle())
27     return; // The module file to load is not generated in Sandcastle
28   std::string module_path = "dropout_model.pt";
29   torch::jit::Module module = torch::jit::load(module_path);
30   AT_ASSERT(module.attr("dropout").toModule().is_training());
31   module.eval();
32   AT_ASSERT(!module.attr("dropout").toModule().is_training());
33   module.train();
34   AT_ASSERT(module.attr("dropout").toModule().is_training());
35 }
36 
37 // TODO: this test never ran before and is broken.
38 // void testSerializationInterop() {
39 //   if (isSandcastle()) {
40 //     // The module file to load is not generated in Sandcastle
41 //     return;
42 //   }
43 
44 //   // This should be generated by `test/cpp/jit/tests_setup.py`
45 //   std::ifstream input_stream("ivalue.pt");
46 //   std::vector<char> input;
47 //   input.insert(
48 //       input.begin(),
49 //       std::istream_iterator<char>(input_stream),
50 //       std::istream_iterator<char>());
51 //   IValue ivalue = pickle_load(input);
52 
53 //   auto elements = ivalue.toTupleRef().elements();
54 //   auto ones = torch::ones({2, 2});
55 //   AT_ASSERT(ones.equal(elements.at(0).toTensor()));
56 
57 //   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
58 //   auto twos = torch::ones({3, 5}) * 2;
59 //   AT_ASSERT(twos.equal(elements.at(1).toTensor()));
60 // }
61 
testTorchSaveError()62 void testTorchSaveError() {
63   if (isSandcastle()) {
64     // The file to load is not generated in Sandcastle
65     return;
66   }
67 
68   // This should be generated by `test/cpp/jit/tests_setup.py`
69   bool passed = true;
70   try {
71     torch::jit::load("eager_value.pt");
72     passed = false;
73   } catch (const std::exception& c) {
74   }
75   // Ensure torch::jit::load did not run
76   AT_ASSERT(passed);
77 }
78 } // namespace
79 
runJITCPPTests()80 JIT_TEST_API void runJITCPPTests() {
81   // TODO: this test never ran before and is broken.
82   // testSerializationInterop();
83   testEvalModeForLoadedModule();
84   testTorchSaveError();
85 }
86 } // namespace jit
87 } // namespace torch
88