xref: /aosp_15_r20/external/pytorch/test/mobile/nnc/test_nnc_backend.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/Functions.h>
2 #include <gtest/gtest.h>
3 #include <torch/csrc/jit/backends/backend.h>
4 #include <torch/csrc/jit/backends/backend_detail.h>
5 #include <torch/csrc/jit/backends/backend_preprocess.h>
6 #include <torch/csrc/jit/frontend/resolver.h>
7 #include <torch/csrc/jit/mobile/import.h>
8 #include <torch/csrc/jit/mobile/module.h>
9 #include <torch/csrc/jit/mobile/nnc/context.h>
10 #include <torch/csrc/jit/mobile/nnc/registry.h>
11 #include <torch/csrc/jit/passes/freeze_module.h>
12 #include <torch/custom_class.h>
13 #include <torch/script.h>
14 
15 namespace torch {
16 namespace jit {
17 namespace mobile {
18 namespace nnc {
19 
20 namespace {
21 
create_compile_spec(const std::string & method_name,const std::string & model_name,const std::string & input_shapes,const std::string & input_types,const std::string & memory_formats,const std::string & dynamic_sizes)22 c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
23     const std::string& method_name,
24     const std::string& model_name,
25     const std::string& input_shapes,
26     const std::string& input_types,
27     const std::string& memory_formats,
28     const std::string& dynamic_sizes) {
29   c10::Dict<c10::IValue, c10::IValue> method_spec(
30       c10::StringType::get(), c10::AnyType::get());
31 
32   method_spec.insert("sizes", input_shapes);
33   method_spec.insert("types", input_types);
34   method_spec.insert("model_name", model_name);
35   method_spec.insert("model_version", "v1");
36   method_spec.insert("asmfile", "fake_nnc_model.s");
37   method_spec.insert("arch", "x86-64");
38   method_spec.insert("memory_formats", memory_formats);
39   method_spec.insert("dynamic_sizes", dynamic_sizes);
40 
41   c10::Dict<c10::IValue, c10::IValue> compile_spec(
42       c10::StringType::get(), c10::AnyType::get());
43   compile_spec.insert(method_name, method_spec);
44   return compile_spec;
45 }
46 
47 } // namespace
48 
49 extern "C" {
50 
51 // The test kernels are supposed to be generated by the NNC compiler ahead-of-
52 // time. For integration test purpose we manually wrote instead.
add_kernel(void ** args)53 int add_kernel(void** args) {
54   // out = input + param
55   at::Tensor input = at::from_blob(args[0], {4, 4}, at::kFloat);
56   at::Tensor out = at::from_blob(args[1], {4, 4}, at::kFloat);
57   at::Tensor param = at::from_blob(args[2], {1}, at::kFloat);
58   out.copy_(at::add(input, param));
59   return 0;
60 }
61 
62 } // extern "C"
63 
64 REGISTER_NNC_KERNEL(
65     "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN",
66     add_kernel)
67 
// End-to-end check of the "nnc" backend: lower a small TorchScript module,
// save it in the mobile format, load it back and compare its output with the
// output of the original module. The DISABLED_ prefix keeps gtest from running
// this test by default.
TEST(DISABLED_NNCBackendTest, AOTCompileThenExecute) {
  // A module whose forward() adds a one-element parameter to its input.
  torch::jit::Module scripted("m");
  auto bias = torch::ones({1});
  scripted.register_parameter("param", bias, false);
  scripted.define(R"(
    def forward(self, input):
        return input + self.param
  )");

  // Reference result: run the original (uncompiled) module on a 4x4 tensor
  // filled with 2.0.
  std::vector<IValue> args;
  args.emplace_back(2.0 * torch::ones({4, 4}));
  auto expected = scripted.forward(args);

  // Freeze the module and lower it through the "nnc" backend.
  auto spec_type =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto spec = create_compile_spec(
      "forward", "_add_kernel_nnc_fake_model", "4,4", "float", "", "");
  auto frozen = torch::jit::freeze_module(scripted.clone());
  auto lowered = torch::jit::detail::codegen_backend_module(
      "nnc", frozen, spec, spec_type);

  // Round-trip the lowered module through an in-memory mobile archive.
  std::stringstream archive;
  lowered._save_for_mobile(archive);
  auto reloaded = _load_for_mobile(archive);

  // The reloaded module must yield 2.0 + 1.0 everywhere and agree with the
  // reference result.
  auto actual = reloaded.forward(args);
  EXPECT_TRUE(actual.toTensor().equal(3.0 * torch::ones({4, 4})));
  EXPECT_TRUE(actual.toTensor().equal(expected.toTensor()));

  // The file named by the compile spec's "asmfile" entry must exist and be
  // removable (presumably it is written during compilation).
  EXPECT_EQ(remove("fake_nnc_model.s"), 0);
}
102 
103 } // namespace nnc
104 } // namespace mobile
105 } // namespace jit
106 } // namespace torch
107