1 #include <ATen/core/ivalue.h> 2 #include <c10/util/Exception.h> 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/api/module.h> 5 #include <torch/script.h> 6 7 namespace torch { 8 namespace jit { 9 10 #ifdef _MSC_VER 11 #define JIT_TEST_API 12 #else 13 #define JIT_TEST_API TORCH_API 14 #endif 15 16 namespace { 17 isSandcastle()18bool isSandcastle() { 19 return ( 20 (std::getenv("SANDCASTLE")) || 21 (std::getenv("TW_JOB_USER") && 22 std::string(std::getenv("TW_JOB_USER")) == "sandcastle")); 23 } 24 testEvalModeForLoadedModule()25void testEvalModeForLoadedModule() { 26 if (isSandcastle()) 27 return; // The module file to load is not generated in Sandcastle 28 std::string module_path = "dropout_model.pt"; 29 torch::jit::Module module = torch::jit::load(module_path); 30 AT_ASSERT(module.attr("dropout").toModule().is_training()); 31 module.eval(); 32 AT_ASSERT(!module.attr("dropout").toModule().is_training()); 33 module.train(); 34 AT_ASSERT(module.attr("dropout").toModule().is_training()); 35 } 36 37 // TODO: this test never ran before and is broken. 38 // void testSerializationInterop() { 39 // if (isSandcastle()) { 40 // // The module file to load is not generated in Sandcastle 41 // return; 42 // } 43 44 // // This should be generated by `test/cpp/jit/tests_setup.py` 45 // std::ifstream input_stream("ivalue.pt"); 46 // std::vector<char> input; 47 // input.insert( 48 // input.begin(), 49 // std::istream_iterator<char>(input_stream), 50 // std::istream_iterator<char>()); 51 // IValue ivalue = pickle_load(input); 52 53 // auto elements = ivalue.toTupleRef().elements(); 54 // auto ones = torch::ones({2, 2}); 55 // AT_ASSERT(ones.equal(elements.at(0).toTensor())); 56 57 // // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) 58 // auto twos = torch::ones({3, 5}) * 2; 59 // AT_ASSERT(twos.equal(elements.at(1).toTensor())); 60 // } 61 testTorchSaveError()62void testTorchSaveError() { 63 if (isSandcastle()) { 64 // The file to load is not generated in Sandcastle 65 return; 66 } 67 68 // This should be generated by `test/cpp/jit/tests_setup.py` 69 bool passed = true; 70 try { 71 torch::jit::load("eager_value.pt"); 72 passed = false; 73 } catch (const std::exception& c) { 74 } 75 // Ensure torch::jit::load did not run 76 AT_ASSERT(passed); 77 } 78 } // namespace 79 runJITCPPTests()80JIT_TEST_API void runJITCPPTests() { 81 // TODO: this test never ran before and is broken. 82 // testSerializationInterop(); 83 testEvalModeForLoadedModule(); 84 testTorchSaveError(); 85 } 86 } // namespace jit 87 } // namespace torch 88