xref: /aosp_15_r20/external/executorch/extension/pytree/aten_util/ivalue_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/pytree/aten_util/ivalue_util.h>
10 
11 #include <executorch/runtime/platform/assert.h>
12 
13 namespace executorch {
14 namespace extension {
15 
16 using namespace c10;
17 using namespace at;
18 using namespace executorch::extension::pytree;
19 
getContainerHandle(const IValue & data)20 ContainerHandle<IValue> getContainerHandle(const IValue& data) {
21   if (data.isList()) {
22     const auto& values = data.toList();
23     auto c = ContainerHandle<IValue>(Kind::List, values.size());
24     for (size_t i = 0; i < values.size(); ++i) {
25       c[i] = getContainerHandle(values[i]);
26     }
27     return c;
28   }
29 
30   if (data.isTuple()) {
31     const auto& values = data.toTupleRef().elements();
32     auto c = ContainerHandle<IValue>(Kind::Tuple, values.size());
33     for (size_t i = 0; i < values.size(); ++i) {
34       c[i] = getContainerHandle(values[i]);
35     }
36     return c;
37   }
38 
39   if (data.isGenericDict()) {
40     const auto& dict = data.toGenericDict();
41     auto c = ContainerHandle<IValue>(Kind::Dict, dict.size());
42 
43     size_t i = 0;
44     for (const auto& entry : dict) {
45       const auto& key = entry.key().toStringRef();
46       const auto& value = entry.value();
47 
48       c.key(i) = Key(key);
49       c[i] = getContainerHandle(value);
50       ++i;
51     }
52     return c;
53   }
54 
55   return const_cast<IValue*>(&data);
56 }
57 
58 template <std::size_t... Is>
create_tuple_impl(std::index_sequence<Is...>,const std::vector<IValue> & arguments)59 auto create_tuple_impl(
60     std::index_sequence<Is...>,
61     const std::vector<IValue>& arguments) {
62   return std::make_tuple(arguments[Is]...);
63 }
64 
65 template <std::size_t N>
create_tuple(const std::vector<IValue> & arguments)66 auto create_tuple(const std::vector<IValue>& arguments) {
67   return create_tuple_impl(std::make_index_sequence<N>{}, arguments);
68 }
69 
constructTuple(const std::vector<IValue> & ivalues)70 IValue constructTuple(const std::vector<IValue>& ivalues) {
71   switch (ivalues.size()) {
72     case 1:
73       return create_tuple<1>(ivalues);
74     case 2:
75       return create_tuple<2>(ivalues);
76     case 3:
77       return create_tuple<3>(ivalues);
78     case 4:
79       return create_tuple<4>(ivalues);
80     case 5:
81       return create_tuple<5>(ivalues);
82     case 6:
83       return create_tuple<6>(ivalues);
84     case 7:
85       return create_tuple<7>(ivalues);
86     case 8:
87       return create_tuple<8>(ivalues);
88     case 9:
89       return create_tuple<9>(ivalues);
90     case 10:
91       return create_tuple<10>(ivalues);
92   }
93   ET_ASSERT_UNREACHABLE_MSG("Supports at most 10 inputs");
94   return {};
95 }
96 
toIValue(const ContainerHandle<IValue> & c)97 IValue toIValue(const ContainerHandle<IValue>& c) {
98   if (c.isList()) {
99     auto ivalues = c10::impl::GenericList(c10::AnyType::get());
100     for (size_t i = 0; i < c.size(); ++i) {
101       ivalues.emplace_back(toIValue(c[i]));
102     }
103     return ivalues;
104   }
105 
106   if (c.isTuple()) {
107     std::vector<IValue> ivalues;
108     for (size_t i = 0; i < c.size(); ++i) {
109       ivalues.emplace_back(toIValue(c[i]));
110     }
111     return constructTuple(ivalues);
112   }
113 
114   if (c.isDict()) {
115     auto dict =
116         c10::impl::GenericDict(c10::StringType::get(), c10::AnyType::get());
117     for (size_t i = 0; i < c.size(); ++i) {
118       dict.insert(std::string(c.key(i)), toIValue(c[i]));
119     }
120     return dict;
121   }
122 
123   ET_CHECK(c.isLeaf());
124   return {*c.leaf_ptr()};
125 }
126 
flatten(const IValue & data)127 std::pair<std::vector<at::Tensor>, std::unique_ptr<TreeSpec<Empty>>> flatten(
128     const IValue& data) {
129   auto c = getContainerHandle(data);
130 
131   auto p = flatten(c);
132 
133   std::vector<at::Tensor> tensors;
134   for (int i = 0; i < p.first.size(); ++i) {
135     tensors.emplace_back(p.first[i]->toTensor());
136   }
137 
138   return {tensors, std::move(p.second)};
139 }
140 
unflatten(const std::vector<at::Tensor> & tensors,const std::unique_ptr<TreeSpec<Empty>> & tree_spec)141 IValue unflatten(
142     const std::vector<at::Tensor>& tensors,
143     const std::unique_ptr<TreeSpec<Empty>>& tree_spec) {
144   std::vector<IValue> ivalues;
145   for (const auto& tensor : tensors) {
146     ivalues.emplace_back(tensor);
147   }
148   ContainerHandle<IValue> c = unflatten(*tree_spec, ivalues.data());
149   return toIValue(c);
150 }
151 
is_same(const std::vector<at::Tensor> & a,const std::vector<at::Tensor> & b)152 bool is_same(
153     const std::vector<at::Tensor>& a,
154     const std::vector<at::Tensor>& b) {
155   for (int i = 0; i < a.size(); ++i) {
156     if (!at::all(a[i] == b[i]).item<bool>()) {
157       return false;
158     }
159   }
160   return true;
161 }
162 
is_same(const IValue & lhs,const IValue & rhs)163 bool is_same(const IValue& lhs, const IValue& rhs) {
164   if (lhs.isList() && rhs.isList()) {
165     const auto& l = lhs.toList();
166     const auto& r = rhs.toList();
167     if (l.size() != r.size()) {
168       return false;
169     }
170     for (size_t i = 0; i < l.size(); ++i) {
171       if (!is_same(l[i], r[i])) {
172         return false;
173       }
174     }
175     return true;
176   }
177 
178   if (lhs.isTuple() && rhs.isTuple()) {
179     const auto& l = lhs.toTupleRef().elements();
180     const auto& r = rhs.toTupleRef().elements();
181     if (l.size() != r.size()) {
182       return false;
183     }
184     for (size_t i = 0; i < l.size(); ++i) {
185       if (!is_same(l[i], r[i])) {
186         return false;
187       }
188     }
189     return true;
190   }
191 
192   if (lhs.isGenericDict() && rhs.isGenericDict()) {
193     const auto& lhs_dict = lhs.toGenericDict();
194     const auto& rhs_dict = rhs.toGenericDict();
195     if (lhs_dict.size() != rhs_dict.size()) {
196       return false;
197     }
198 
199     for (const auto& entry : lhs_dict) {
200       if (!rhs_dict.contains(entry.key())) {
201         return false;
202       }
203       if (!is_same(entry.value(), rhs_dict.at(entry.key()))) {
204         return false;
205       }
206     }
207     return true;
208   }
209 
210   ET_CHECK(lhs.isTensor() && rhs.isTensor());
211   const auto& l = lhs.toTensor();
212   const auto& r = rhs.toTensor();
213   return at::all(l == r).item<bool>();
214 }
215 
216 } // namespace extension
217 } // namespace executorch
218