xref: /aosp_15_r20/external/executorch/exir/verification/bindings.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>
#include <c10/util/C++17.h>
#include <c10/util/Optional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h> // @manual=//caffe2:torch_extension
#include <torch/torch.h> // @manual=//caffe2:torch-cpp-cpu
21 
22 namespace exir {
23 namespace {
24 
25 class DataBuffer {
26  private:
27   void* buffer_ = nullptr;
28 
29  public:
DataBuffer(pybind11::bytes data,int64_t len)30   DataBuffer(pybind11::bytes data, int64_t len) {
31     // allocate buffer
32     buffer_ = malloc(len);
33     // convert data to std::string and copy to buffer
34     std::memcpy(buffer_, (std::string{data}).data(), len);
35   }
~DataBuffer()36   ~DataBuffer() {
37     if (buffer_) {
38       free(buffer_);
39     }
40   }
41   DataBuffer(const DataBuffer&) = delete;
42   DataBuffer& operator=(const DataBuffer&) = delete;
43 
get()44   void* get() {
45     return buffer_;
46   }
47 };
48 } // namespace
49 
PYBIND11_MODULE(bindings, m) {
  pybind11::class_<DataBuffer>(m, "DataBuffer")
      .def(pybind11::init<pybind11::bytes, int64_t>());

  // convert_to_tensor(data_buffer, scalar_type, sizes, strides) -> Tensor
  //
  // Wraps the bytes held by `data_buffer` in a tensor without copying them.
  // torch::from_blob does not take ownership of the memory, so the returned
  // tensor aliases the DataBuffer's storage. keep_alive<0, 1> keeps the
  // DataBuffer (argument 1) alive for as long as the returned tensor object
  // (index 0) is alive. Without it, dropping the DataBuffer in Python would
  // leave the tensor pointing at freed memory.
  // NOTE(review): the buffer's byte length is not visible here, so callers
  // must ensure sizes/strides/dtype describe memory inside the buffer.
  m.def(
      "convert_to_tensor",
      [](DataBuffer& data_buffer,
         const int64_t scalar_type,
         const std::vector<int64_t>& sizes,
         const std::vector<int64_t>& strides) {
        // Reject out-of-range values before they become an invalid enum.
        TORCH_CHECK(
            scalar_type >= 0 &&
                scalar_type < static_cast<int64_t>(at::ScalarType::Undefined),
            "convert_to_tensor: invalid scalar_type ",
            scalar_type);
        TORCH_CHECK(
            sizes.size() == strides.size(),
            "convert_to_tensor: sizes (",
            sizes.size(),
            ") and strides (",
            strides.size(),
            ") must have the same length");

        const auto type_option = static_cast<at::ScalarType>(scalar_type);
        const auto opts = torch::TensorOptions().dtype(type_option);

        // Build the tensor view over the buffer using the given metadata.
        return torch::from_blob(data_buffer.get(), sizes, strides, opts);
      },
      pybind11::keep_alive<0, 1>());
}
68 } // namespace exir
69