1 #include <vector> 2 3 #include <torch/csrc/jit/backends/backend.h> 4 #include <torch/csrc/jit/mobile/nnc/context.h> 5 6 namespace torch { 7 namespace jit { 8 namespace mobile { 9 namespace nnc { 10 11 class NNCBackend : public PyTorchBackendInterface { 12 public: 13 explicit NNCBackend() = default; 14 ~NNCBackend() override = default; 15 is_available()16 bool is_available() override { 17 return true; 18 } 19 compile(c10::IValue processed,c10::impl::GenericDict method_compile_spec)20 c10::impl::GenericDict compile( 21 c10::IValue processed, 22 c10::impl::GenericDict method_compile_spec) override { 23 cu_ = std::make_shared<CompilationUnit>(processed); 24 25 // Input method_compile_spec: 26 // Key: method name 27 // Value: compile spec for each method 28 // Output: 29 // Key: method name 30 // Value: a backend handle for each method 31 auto spec = 32 c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec); 33 auto handles = c10::Dict<std::string, std::string>(); 34 for (const auto& it : spec) { 35 // The handle for each method is the key (method name) itself. 36 handles.insert(it.key(), it.key()); 37 } 38 return c10::impl::toGenericDict(handles); 39 } 40 execute(c10::IValue handle,c10::impl::GenericList inputs)41 c10::impl::GenericList execute( 42 c10::IValue handle, 43 c10::impl::GenericList inputs) override { 44 const std::string& method_name = handle.toStringRef(); 45 auto function_name = c10::QualifiedName(method_name); 46 return cu_->run(function_name, inputs); 47 } 48 49 private: 50 std::shared_ptr<CompilationUnit> cu_; 51 }; 52 53 namespace { 54 // TODO(mvz): temporarily disable NNC backend in mobile builds. 55 // static const auto cls = torch::jit::backend<NNCBackend>("nnc"); 56 } // namespace 57 58 } // namespace nnc 59 } // namespace mobile 60 } // namespace jit 61 } // namespace torch 62