xref: /aosp_15_r20/external/pytorch/test/custom_backend/test_custom_backend.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/cuda.h>
2 #include <torch/script.h>
3 
4 #include <string>
5 
6 #include "custom_backend.h"
7 
8 // Load a module lowered for the custom backend from \p path and test that
9 // it can be executed and produces correct results.
load_serialized_lowered_module_and_execute(const std::string & path)10 void load_serialized_lowered_module_and_execute(const std::string& path) {
11   torch::jit::Module module = torch::jit::load(path);
12   // The custom backend is hardcoded to compute f(a, b) = (a + b, a - b).
13   auto tensor = torch::ones(5);
14   std::vector<torch::jit::IValue> inputs{tensor, tensor};
15   auto output = module.forward(inputs);
16   AT_ASSERT(output.isTuple());
17   auto output_elements = output.toTupleRef().elements();
18   for (auto& e : output_elements) {
19     AT_ASSERT(e.isTensor());
20   }
21   AT_ASSERT(output_elements.size(), 2);
22   AT_ASSERT(output_elements[0].toTensor().allclose(tensor + tensor));
23   AT_ASSERT(output_elements[1].toTensor().allclose(tensor - tensor));
24 }
25 
main(int argc,const char * argv[])26 int main(int argc, const char* argv[]) {
27   if (argc != 2) {
28     std::cerr
29         << "usage: test_custom_backend <path-to-exported-script-module>\n";
30     return -1;
31   }
32   const std::string path_to_exported_script_module = argv[1];
33 
34   std::cout << "Testing " << torch::custom_backend::getBackendName() << "\n";
35   load_serialized_lowered_module_and_execute(path_to_exported_script_module);
36 
37   std::cout << "OK\n";
38   return 0;
39 }
40