1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/aten_util/aten_bridge.h>
10
11 #include <executorch/runtime/platform/assert.h>
12 #include <cstring>
13
14 namespace executorch {
15 namespace extension {
16
17 namespace {
check_tensor_meta(const at::Tensor & a,const exec_aten::Tensor & b)18 void check_tensor_meta(const at::Tensor& a, const exec_aten::Tensor& b) {
19 // Check sizes/strides pointers
20 ET_CHECK_MSG(
21 b.sizes().data() != nullptr, "ETensor must have valid sizes array");
22 ET_CHECK_MSG(
23 b.strides().data() != nullptr, "ETensor must have valid strides array");
24 // Check disabled because in ASR model we get 1 element tensor with different
25 // rank.
26 /*
27 ET_CHECK_MSG(
28 a.dim() == b.dim(),
29 "at::Tensor and ETensor must have same rank."
30 " at::Tensor rank %zd, ETensor rank %zd.",
31 a.dim(),
32 b.dim());
33 */
34 // check sizes
35 for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
36 ET_CHECK_MSG(
37 a.size(i) == b.size(i),
38 "Sizes dont match at index:%zd, a size %zd != b size %zd",
39 i,
40 ssize_t(a.size(i)),
41 ssize_t(b.size(i)));
42 }
43 // check strides
44 for (size_t i = 0, dims = a.dim(); i < dims; ++i) {
45 // Dont match strides if the size is 1.
46 // Why? Because tensor is non-contig only if
47 // strides dont match product(sizes[i:]) when size(i) > 1
48 // Strong assumption that must be tested and validated.
49 ET_CHECK_MSG(
50 (a.size(i) == 1 || (a.stride(i) == b.strides()[i])),
51 "Strides dont match at index:%zd, a stride %zd != b stride %zd",
52 i,
53 ssize_t(a.stride(i)),
54 ssize_t(b.strides()[i]));
55 }
56 // check dtype
57 ET_CHECK_MSG(
58 b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()),
59 "dtypes dont match a %hhd vs. b %hhd",
60 static_cast<int8_t>(torch_to_executorch_scalar_type(a.options().dtype())),
61 static_cast<int8_t>(b.scalar_type()));
62 }
63 } // namespace
64
torch_to_executorch_scalar_type(caffe2::TypeMeta type)65 executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type(
66 caffe2::TypeMeta type) {
67 const auto intermediate =
68 static_cast<std::underlying_type<c10::ScalarType>::type>(
69 c10::typeMetaToScalarType(type));
70
71 ET_CHECK_MSG(
72 intermediate >= 0 &&
73 intermediate <= static_cast<std::underlying_type<
74 executorch::runtime::etensor::ScalarType>::type>(
75 executorch::runtime::etensor::ScalarType::UInt64),
76 "ScalarType %d unsupported in Executorch",
77 intermediate);
78 return static_cast<executorch::runtime::etensor::ScalarType>(intermediate);
79 }
80
executorch_to_torch_scalar_type(torch::executor::ScalarType type)81 c10::ScalarType executorch_to_torch_scalar_type(
82 torch::executor::ScalarType type) {
83 const auto intermediate = static_cast<
84 std::underlying_type<executorch::runtime::etensor::ScalarType>::type>(
85 type);
86
87 ET_CHECK_MSG(
88 intermediate >= 0 &&
89 intermediate <= static_cast<std::underlying_type<
90 executorch::runtime::etensor::ScalarType>::type>(
91 executorch::runtime::etensor::ScalarType::UInt64),
92 "ScalarType %d unsupported in Executorch",
93 intermediate);
94 return static_cast<c10::ScalarType>(intermediate);
95 }
96
/*
 * The following makes two assumptions:
 * 1. aten_tensor's lifetime is longer than the lifetime within which
 * mutable_et is consumed.
 * 2. Memory previously allocated to mutable_et is leaked. However, under the
 * (strong) assumption that such memory is arena-allocated and that its
 * lifetime is tied to the model's lifetime, we assume the memory is not leaked,
 * as it is freed when the arena is freed.
 * @param[in] aten_tensor Input at::Tensor
 * @param[in/out] mutable_et ETensor whose underlying memory will now alias
 * aten_tensor
 */
alias_etensor_to_attensor(at::Tensor & aten_tensor,torch::executor::Tensor & mutable_et)109 void alias_etensor_to_attensor(
110 at::Tensor& aten_tensor,
111 torch::executor::Tensor& mutable_et) {
112 // TODO(kimishpatel): contiguous according to memformat
113 // Right now we assume everything is channels first contiguous
114 // Note that input tensor must be contiguous for us to alias.
115 // Mixing aliasing and copying is dangerous since if we aliased
116 // the instance of mutatble_et to aten_tensor in the previous call,
117 // then in the next call copying will not be the correct behavior.
118 ET_CHECK_MSG(aten_tensor.is_contiguous(), "Input tensor must be contiguous");
119 check_tensor_meta(aten_tensor, mutable_et);
120 mutable_et.unsafeGetTensorImpl()->set_data(aten_tensor.mutable_data_ptr());
121 }
122
alias_attensor_to_etensor(const torch::executor::Tensor & etensor)123 at::Tensor alias_attensor_to_etensor(const torch::executor::Tensor& etensor) {
124 c10::ScalarType dtype =
125 executorch_to_torch_scalar_type(etensor.scalar_type());
126 std::vector<int64_t> at_tensor_sizes(
127 etensor.sizes().begin(), etensor.sizes().end());
128 std::vector<int64_t> at_tensor_strides(
129 etensor.strides().begin(), etensor.strides().end());
130
131 at::Tensor t = at::from_blob(
132 etensor.mutable_data_ptr(),
133 at_tensor_sizes,
134 at_tensor_strides,
135 at::TensorOptions(dtype));
136
137 check_tensor_meta(t, etensor);
138 return t;
139 }
140
alias_tensor_ptr_to_attensor(at::Tensor & t)141 TensorPtr alias_tensor_ptr_to_attensor(at::Tensor& t) {
142 return make_tensor_ptr(
143 {t.sizes().begin(), t.sizes().end()},
144 t.mutable_data_ptr(),
145 torch::executor::ScalarType(t.scalar_type()));
146 }
147
148 } // namespace extension
149 } // namespace executorch
150