#include <ATen/Functions.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/nnc/context.h>
#include <torch/csrc/jit/mobile/nnc/registry.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

namespace torch {
namespace jit {
namespace mobile {
namespace nnc {

namespace {

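// Builds the compile spec handed to the NNC backend: an outer string-keyed
// dict mapping each method name ("forward" in this test) to an inner dict of
// per-method options (input sizes/types, model name/version, output assembly
// file, target arch, memory formats, dynamic sizes).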
c10::Dict<c10::IValue, c10::IValue> create_compile_spec(
    const std::string& method_name,
    const std::string& model_name,
    const std::string& input_shapes,
    const std::string& input_types,
    const std::string& memory_formats,
    const std::string& dynamic_sizes) {
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());

  method_spec.insert("sizes", input_shapes);
  method_spec.insert("types", input_types);
  method_spec.insert("model_name", model_name);
  method_spec.insert("model_version", "v1");
  method_spec.insert("asmfile", "fake_nnc_model.s");
  method_spec.insert("arch", "x86-64");
  method_spec.insert("memory_formats", memory_formats);
  method_spec.insert("dynamic_sizes", dynamic_sizes);

  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  compile_spec.insert(method_name, method_spec);
  return compile_spec;
}

} // namespace

extern "C" {

// The test kernels are supposed to be generated by the NNC compiler ahead-of-
// time. For integration test purposes we wrote this one by hand instead.
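// The kernel receives a flat array of raw buffer pointers; in this test
// args[0] is the (4, 4) float input, args[1] the (4, 4) float output, and
// args[2] the single-element parameter registered on the module below.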
int add_kernel(void** args) {
  // out = input + param
  at::Tensor input = at::from_blob(args[0], {4, 4}, at::kFloat);
  at::Tensor out = at::from_blob(args[1], {4, 4}, at::kFloat);
  at::Tensor param = at::from_blob(args[2], {1}, at::kFloat);
  out.copy_(at::add(input, param));
  return 0;
}

} // extern "C"

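// The registration key mirrors the values set in create_compile_spec above
// (model name, model version, method name), followed by a version token.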
REGISTER_NNC_KERNEL(
    "_add_kernel_nnc_fake_model:v1:forward:VERTOKEN",
    add_kernel)

TEST(DISABLED_NNCBackendTest, AOTCompileThenExecute) {
  torch::jit::Module m("m");
  auto param = torch::ones({1});
  m.register_parameter("param", param, false);
  m.define(R"(
    def forward(self, input):
        return input + self.param
  )");

  // Run the TorchScript module to get the reference result.
  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({4, 4}));
  auto reference = m.forward(inputs);

  // Compile the model with NNC.
  auto compile_spec = create_compile_spec(
      "forward", "_add_kernel_nnc_fake_model", "4,4", "float", "", "");
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto frozen_m = torch::jit::freeze_module(m.clone());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", frozen_m, compile_spec, any_dict_ty);

  // Save the compiled model.
  std::stringstream ss;
  compiled_module._save_for_mobile(ss);

  // Load and run the saved model.
  auto loaded_module = _load_for_mobile(ss);
  auto result = loaded_module.forward(inputs);
  EXPECT_TRUE(result.toTensor().equal(3.0 * torch::ones({4, 4})));
  EXPECT_TRUE(result.toTensor().equal(reference.toTensor()));
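  // An input of 2s plus a parameter of 1 gives 3s everywhere, so the compiled
  // kernel must match the eager result. The compile spec asked NNC to emit
  // "fake_nnc_model.s"; deleting it doubles as a check that the assembly file
  // was actually produced, and cleans up after the test.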
  EXPECT_EQ(remove("fake_nnc_model.s"), 0);
}

} // namespace nnc
} // namespace mobile
} // namespace jit
} // namespace torch