xref: /aosp_15_r20/external/pytorch/torch/_inductor/codegen/aoti_runtime/implementation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // NOTE: Like interface.cpp, this file will be copied into AOTInductor
2 // generated output. This file is intended to keep implementation
3 // details separate from the implementation of the AOTI public
4 // interface. Note also that #includes should go into interface.cpp
5 // for simplicity of maintenance.
6 
7 namespace torch {
8 namespace aot_inductor {
9 template <typename T>
convert_output_to_handle(const ArrayRefTensor<T> & output,AtenTensorHandle & handle)10 void convert_output_to_handle(
11     const ArrayRefTensor<T>& output,
12     AtenTensorHandle& handle) {
13   handle = output.expensiveCopyToTensor();
14 }
15 
16 template <typename... Ts, std::size_t... Is>
convert_outputs_to_handles_helper(const std::tuple<ArrayRefTensor<Ts>...> & outputs,AtenTensorHandle * output_handles,std::index_sequence<Is...>)17 void convert_outputs_to_handles_helper(
18     const std::tuple<ArrayRefTensor<Ts>...>& outputs,
19     AtenTensorHandle* output_handles,
20     std::index_sequence<Is...>) {
21   (convert_output_to_handle(std::get<Is>(outputs), output_handles[Is]), ...);
22 }
23 template <typename... Ts>
convert_outputs_to_handles(const std::tuple<ArrayRefTensor<Ts>...> & outputs,AtenTensorHandle * output_handles)24 void convert_outputs_to_handles(
25     const std::tuple<ArrayRefTensor<Ts>...>& outputs,
26     AtenTensorHandle* output_handles) {
27   convert_outputs_to_handles_helper(
28       outputs, output_handles, std::make_index_sequence<sizeof...(Ts)>());
29 }
30 
31 template <typename T>
convert_handle_to_arrayref_tensor(AtenTensorHandle handle,ArrayRefTensor<T> & input)32 void convert_handle_to_arrayref_tensor(
33     AtenTensorHandle handle,
34     ArrayRefTensor<T>& input) {
35   void* data_ptr;
36   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr));
37   int64_t dim;
38   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim));
39   int64_t numel;
40   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel));
41   int64_t* sizes;
42   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes));
43   int64_t* strides;
44   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides));
45   int32_t dtype;
46   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype));
47   int32_t device_type;
48   AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type));
49   int32_t device_index;
50   AOTI_TORCH_ERROR_CODE_CHECK(
51       aoti_torch_get_device_index(handle, &device_index));
52 
53   input = ArrayRefTensor<T>(
54       MiniArrayRef<T>(reinterpret_cast<T*>(data_ptr), numel),
55       MiniArrayRef<const int64_t>(sizes, dim),
56       MiniArrayRef<const int64_t>(strides, dim),
57       device_type,
58       device_index);
59 }
60 
61 template <typename... Ts, std::size_t... Is>
convert_handles_to_inputs_helper(AtenTensorHandle * input_handles,std::tuple<ArrayRefTensor<Ts>...> & inputs,std::index_sequence<Is...>)62 void convert_handles_to_inputs_helper(
63     AtenTensorHandle* input_handles,
64     std::tuple<ArrayRefTensor<Ts>...>& inputs,
65     std::index_sequence<Is...>) {
66   (convert_handle_to_arrayref_tensor(input_handles[Is], std::get<Is>(inputs)),
67    ...);
68 }
69 
70 template <typename... Ts>
convert_handles_to_inputs(AtenTensorHandle * input_handles,std::tuple<ArrayRefTensor<Ts>...> & inputs)71 void convert_handles_to_inputs(
72     AtenTensorHandle* input_handles,
73     std::tuple<ArrayRefTensor<Ts>...>& inputs) {
74   convert_handles_to_inputs_helper(
75       input_handles, inputs, std::make_index_sequence<sizeof...(Ts)>());
76 }
77 
78 template <typename T>
assert_numel(const ArrayRefTensor<T> & tensor,uint64_t numel)79 void assert_numel(const ArrayRefTensor<T>& tensor, uint64_t numel) {
80   if (tensor.numel() != numel) {
81     std::stringstream err;
82     err << "incorrect numel for input tensor. expected " << numel << ", got " << tensor.numel();
83     throw std::runtime_error(err.str());
84   }
85 }
86 } // namespace aot_inductor
87 } // namespace torch
88