1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <pybind11/pybind11.h>
7 #include <torch/csrc/jit/backends/backend.h>
8 #include <torch/csrc/jit/backends/backend_preprocess.h>
9 #include <torch/csrc/jit/python/pybind_utils.h>
10 #include <torch/csrc/utils/pybind.h>
11 #include <torch/script.h>
12
13 namespace py = pybind11;
14
15 namespace {
16
preprocess(const torch::jit::Module & mod,const c10::Dict<c10::IValue,c10::IValue> & method_compile_spec,const torch::jit::BackendDebugHandleGenerator & generate_debug_handles)17 c10::IValue preprocess(
18 const torch::jit::Module& mod,
19 const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
20 const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
21 py::object pyModule =
22 py::module_::import("torch.backends._coreml.preprocess");
23 py::object pyMethod = pyModule.attr("preprocess");
24
25 py::dict modelDict =
26 pyMethod(mod, torch::jit::toPyObject(method_compile_spec));
27
28 c10::Dict<std::string, std::string> modelData;
29 for (auto item : modelDict) {
30 modelData.insert(
31 item.first.cast<std::string>(), item.second.cast<std::string>());
32 }
33 return modelData;
34 }
35
36 static auto pre_reg =
37 torch::jit::backend_preprocess_register("coreml", preprocess);
38
39 } // namespace
40