xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/coreml/cpp/preprocess.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <pybind11/pybind11.h>
7 #include <torch/csrc/jit/backends/backend.h>
8 #include <torch/csrc/jit/backends/backend_preprocess.h>
9 #include <torch/csrc/jit/python/pybind_utils.h>
10 #include <torch/csrc/utils/pybind.h>
11 #include <torch/script.h>
12 
13 namespace py = pybind11;
14 
15 namespace {
16 
preprocess(const torch::jit::Module & mod,const c10::Dict<c10::IValue,c10::IValue> & method_compile_spec,const torch::jit::BackendDebugHandleGenerator & generate_debug_handles)17 c10::IValue preprocess(
18     const torch::jit::Module& mod,
19     const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
20     const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
21   py::object pyModule =
22       py::module_::import("torch.backends._coreml.preprocess");
23   py::object pyMethod = pyModule.attr("preprocess");
24 
25   py::dict modelDict =
26       pyMethod(mod, torch::jit::toPyObject(method_compile_spec));
27 
28   c10::Dict<std::string, std::string> modelData;
29   for (auto item : modelDict) {
30     modelData.insert(
31         item.first.cast<std::string>(), item.second.cast<std::string>());
32   }
33   return modelData;
34 }
35 
36 static auto pre_reg =
37     torch::jit::backend_preprocess_register("coreml", preprocess);
38 
39 } // namespace
40