1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/pytree/aten_util/ivalue_util.h>
10
11 #include <executorch/runtime/platform/assert.h>
12
13 namespace executorch {
14 namespace extension {
15
16 using namespace c10;
17 using namespace at;
18 using namespace executorch::extension::pytree;
19
getContainerHandle(const IValue & data)20 ContainerHandle<IValue> getContainerHandle(const IValue& data) {
21 if (data.isList()) {
22 const auto& values = data.toList();
23 auto c = ContainerHandle<IValue>(Kind::List, values.size());
24 for (size_t i = 0; i < values.size(); ++i) {
25 c[i] = getContainerHandle(values[i]);
26 }
27 return c;
28 }
29
30 if (data.isTuple()) {
31 const auto& values = data.toTupleRef().elements();
32 auto c = ContainerHandle<IValue>(Kind::Tuple, values.size());
33 for (size_t i = 0; i < values.size(); ++i) {
34 c[i] = getContainerHandle(values[i]);
35 }
36 return c;
37 }
38
39 if (data.isGenericDict()) {
40 const auto& dict = data.toGenericDict();
41 auto c = ContainerHandle<IValue>(Kind::Dict, dict.size());
42
43 size_t i = 0;
44 for (const auto& entry : dict) {
45 const auto& key = entry.key().toStringRef();
46 const auto& value = entry.value();
47
48 c.key(i) = Key(key);
49 c[i] = getContainerHandle(value);
50 ++i;
51 }
52 return c;
53 }
54
55 return const_cast<IValue*>(&data);
56 }
57
58 template <std::size_t... Is>
create_tuple_impl(std::index_sequence<Is...>,const std::vector<IValue> & arguments)59 auto create_tuple_impl(
60 std::index_sequence<Is...>,
61 const std::vector<IValue>& arguments) {
62 return std::make_tuple(arguments[Is]...);
63 }
64
65 template <std::size_t N>
create_tuple(const std::vector<IValue> & arguments)66 auto create_tuple(const std::vector<IValue>& arguments) {
67 return create_tuple_impl(std::make_index_sequence<N>{}, arguments);
68 }
69
constructTuple(const std::vector<IValue> & ivalues)70 IValue constructTuple(const std::vector<IValue>& ivalues) {
71 switch (ivalues.size()) {
72 case 1:
73 return create_tuple<1>(ivalues);
74 case 2:
75 return create_tuple<2>(ivalues);
76 case 3:
77 return create_tuple<3>(ivalues);
78 case 4:
79 return create_tuple<4>(ivalues);
80 case 5:
81 return create_tuple<5>(ivalues);
82 case 6:
83 return create_tuple<6>(ivalues);
84 case 7:
85 return create_tuple<7>(ivalues);
86 case 8:
87 return create_tuple<8>(ivalues);
88 case 9:
89 return create_tuple<9>(ivalues);
90 case 10:
91 return create_tuple<10>(ivalues);
92 }
93 ET_ASSERT_UNREACHABLE_MSG("Supports at most 10 inputs");
94 return {};
95 }
96
toIValue(const ContainerHandle<IValue> & c)97 IValue toIValue(const ContainerHandle<IValue>& c) {
98 if (c.isList()) {
99 auto ivalues = c10::impl::GenericList(c10::AnyType::get());
100 for (size_t i = 0; i < c.size(); ++i) {
101 ivalues.emplace_back(toIValue(c[i]));
102 }
103 return ivalues;
104 }
105
106 if (c.isTuple()) {
107 std::vector<IValue> ivalues;
108 for (size_t i = 0; i < c.size(); ++i) {
109 ivalues.emplace_back(toIValue(c[i]));
110 }
111 return constructTuple(ivalues);
112 }
113
114 if (c.isDict()) {
115 auto dict =
116 c10::impl::GenericDict(c10::StringType::get(), c10::AnyType::get());
117 for (size_t i = 0; i < c.size(); ++i) {
118 dict.insert(std::string(c.key(i)), toIValue(c[i]));
119 }
120 return dict;
121 }
122
123 ET_CHECK(c.isLeaf());
124 return {*c.leaf_ptr()};
125 }
126
flatten(const IValue & data)127 std::pair<std::vector<at::Tensor>, std::unique_ptr<TreeSpec<Empty>>> flatten(
128 const IValue& data) {
129 auto c = getContainerHandle(data);
130
131 auto p = flatten(c);
132
133 std::vector<at::Tensor> tensors;
134 for (int i = 0; i < p.first.size(); ++i) {
135 tensors.emplace_back(p.first[i]->toTensor());
136 }
137
138 return {tensors, std::move(p.second)};
139 }
140
unflatten(const std::vector<at::Tensor> & tensors,const std::unique_ptr<TreeSpec<Empty>> & tree_spec)141 IValue unflatten(
142 const std::vector<at::Tensor>& tensors,
143 const std::unique_ptr<TreeSpec<Empty>>& tree_spec) {
144 std::vector<IValue> ivalues;
145 for (const auto& tensor : tensors) {
146 ivalues.emplace_back(tensor);
147 }
148 ContainerHandle<IValue> c = unflatten(*tree_spec, ivalues.data());
149 return toIValue(c);
150 }
151
is_same(const std::vector<at::Tensor> & a,const std::vector<at::Tensor> & b)152 bool is_same(
153 const std::vector<at::Tensor>& a,
154 const std::vector<at::Tensor>& b) {
155 for (int i = 0; i < a.size(); ++i) {
156 if (!at::all(a[i] == b[i]).item<bool>()) {
157 return false;
158 }
159 }
160 return true;
161 }
162
is_same(const IValue & lhs,const IValue & rhs)163 bool is_same(const IValue& lhs, const IValue& rhs) {
164 if (lhs.isList() && rhs.isList()) {
165 const auto& l = lhs.toList();
166 const auto& r = rhs.toList();
167 if (l.size() != r.size()) {
168 return false;
169 }
170 for (size_t i = 0; i < l.size(); ++i) {
171 if (!is_same(l[i], r[i])) {
172 return false;
173 }
174 }
175 return true;
176 }
177
178 if (lhs.isTuple() && rhs.isTuple()) {
179 const auto& l = lhs.toTupleRef().elements();
180 const auto& r = rhs.toTupleRef().elements();
181 if (l.size() != r.size()) {
182 return false;
183 }
184 for (size_t i = 0; i < l.size(); ++i) {
185 if (!is_same(l[i], r[i])) {
186 return false;
187 }
188 }
189 return true;
190 }
191
192 if (lhs.isGenericDict() && rhs.isGenericDict()) {
193 const auto& lhs_dict = lhs.toGenericDict();
194 const auto& rhs_dict = rhs.toGenericDict();
195 if (lhs_dict.size() != rhs_dict.size()) {
196 return false;
197 }
198
199 for (const auto& entry : lhs_dict) {
200 if (!rhs_dict.contains(entry.key())) {
201 return false;
202 }
203 if (!is_same(entry.value(), rhs_dict.at(entry.key()))) {
204 return false;
205 }
206 }
207 return true;
208 }
209
210 ET_CHECK(lhs.isTensor() && rhs.isTensor());
211 const auto& l = lhs.toTensor();
212 const auto& r = rhs.toTensor();
213 return at::all(l == r).item<bool>();
214 }
215
216 } // namespace extension
217 } // namespace executorch
218