#include <torch/cuda.h>
#include <torch/script.h>

#include <iostream>
#include <string>

#include "custom_backend.h"

// Load a module lowered for the custom backend from \p path and test that
// it can be executed and produces correct results.
void load_serialized_lowered_module_and_execute(const std::string& path) {
  torch::jit::Module module = torch::jit::load(path);
  // The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
  auto tensor = torch::ones(5);
  std::vector<torch::jit::IValue> inputs{tensor, tensor};
  auto output = module.forward(inputs);
  AT_ASSERT(output.isTuple());
  auto output_elements = output.toTupleRef().elements();
  for (auto& e : output_elements) {
    AT_ASSERT(e.isTensor());
  }
  AT_ASSERT(output_elements.size() == 2);
  AT_ASSERT(output_elements[0].toTensor().allclose(tensor + tensor));
  AT_ASSERT(output_elements[1].toTensor().allclose(tensor - tensor));
}

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr
        << "usage: test_custom_backend <path-to-exported-script-module>\n";
    return -1;
  }
  const std::string path_to_exported_script_module = argv[1];

  std::cout << "Testing " << torch::custom_backend::getBackendName() << "\n";
  load_serialized_lowered_module_and_execute(path_to_exported_script_module);

  std::cout << "OK\n";
  return 0;
}
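
// Example invocation, assuming a module has already been lowered to the
// custom backend and serialized ahead of time (e.g. from Python, where a
// ScriptModule can be lowered with torch._C._jit_to_backend and then saved);
// the path below is purely illustrative:
//
//   ./test_custom_backend /tmp/lowered_model.pt
//
// On success the program prints "Testing <backend name>" followed by "OK".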