xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/nnc/backend.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <vector>
2 
3 #include <torch/csrc/jit/backends/backend.h>
4 #include <torch/csrc/jit/mobile/nnc/context.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace mobile {
9 namespace nnc {
10 
11 class NNCBackend : public PyTorchBackendInterface {
12  public:
13   explicit NNCBackend() = default;
14   ~NNCBackend() override = default;
15 
is_available()16   bool is_available() override {
17     return true;
18   }
19 
compile(c10::IValue processed,c10::impl::GenericDict method_compile_spec)20   c10::impl::GenericDict compile(
21       c10::IValue processed,
22       c10::impl::GenericDict method_compile_spec) override {
23     cu_ = std::make_shared<CompilationUnit>(processed);
24 
25     // Input method_compile_spec:
26     //   Key: method name
27     //   Value: compile spec for each method
28     // Output:
29     //   Key: method name
30     //   Value: a backend handle for each method
31     auto spec =
32         c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
33     auto handles = c10::Dict<std::string, std::string>();
34     for (const auto& it : spec) {
35       // The handle for each method is the key (method name) itself.
36       handles.insert(it.key(), it.key());
37     }
38     return c10::impl::toGenericDict(handles);
39   }
40 
execute(c10::IValue handle,c10::impl::GenericList inputs)41   c10::impl::GenericList execute(
42       c10::IValue handle,
43       c10::impl::GenericList inputs) override {
44     const std::string& method_name = handle.toStringRef();
45     auto function_name = c10::QualifiedName(method_name);
46     return cu_->run(function_name, inputs);
47   }
48 
49  private:
50   std::shared_ptr<CompilationUnit> cu_;
51 };
52 
namespace {
// TODO(mvz): temporarily disable NNC backend in mobile builds.
//
// The statement below is intentionally commented out. It presumably
// registers NNCBackend with the backend machinery under the name "nnc";
// restore it to re-enable the backend:
//
// static const auto cls = torch::jit::backend<NNCBackend>("nnc");
} // namespace
57 
58 } // namespace nnc
59 } // namespace mobile
60 } // namespace jit
61 } // namespace torch
62