xref: /aosp_15_r20/external/executorch/extension/pytree/pytree.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker 
#include <ctype.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <variant>
18*523fa7a6SAndroid Build Coastguard Worker #include <variant>
19*523fa7a6SAndroid Build Coastguard Worker 
20*523fa7a6SAndroid Build Coastguard Worker // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/pytree/function_ref.h>
22*523fa7a6SAndroid Build Coastguard Worker 
23*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
24*523fa7a6SAndroid Build Coastguard Worker namespace extension {
25*523fa7a6SAndroid Build Coastguard Worker namespace pytree {
26*523fa7a6SAndroid Build Coastguard Worker 
/// Debug-only invariant check: `assert(must_be_true)`.
///
/// Compiled out when NDEBUG is defined; [[maybe_unused]] keeps release builds
/// free of unused-parameter warnings in that configuration.
inline void pytree_assert([[maybe_unused]] bool must_be_true) {
  assert(must_be_true);
}
30*523fa7a6SAndroid Build Coastguard Worker 
31*523fa7a6SAndroid Build Coastguard Worker #ifdef _MSC_VER
32*523fa7a6SAndroid Build Coastguard Worker #define EXECUTORCH_ALWAYS_INLINE __forceinline
33*523fa7a6SAndroid Build Coastguard Worker #elif defined(__GNUC__)
34*523fa7a6SAndroid Build Coastguard Worker #define EXECUTORCH_ALWAYS_INLINE inline __attribute__((__always_inline__))
35*523fa7a6SAndroid Build Coastguard Worker #else
36*523fa7a6SAndroid Build Coastguard Worker #define EXECUTORCH_ALWAYS_INLINE inline
37*523fa7a6SAndroid Build Coastguard Worker #endif
38*523fa7a6SAndroid Build Coastguard Worker 
pytree_unreachable()39*523fa7a6SAndroid Build Coastguard Worker [[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
40*523fa7a6SAndroid Build Coastguard Worker   assert(false);
41*523fa7a6SAndroid Build Coastguard Worker #if defined(__GNUC__)
42*523fa7a6SAndroid Build Coastguard Worker   __builtin_unreachable();
43*523fa7a6SAndroid Build Coastguard Worker #elif defined(_MSC_VER)
44*523fa7a6SAndroid Build Coastguard Worker   __assume(0);
45*523fa7a6SAndroid Build Coastguard Worker #else
46*523fa7a6SAndroid Build Coastguard Worker   while (!0)
47*523fa7a6SAndroid Build Coastguard Worker     ;
48*523fa7a6SAndroid Build Coastguard Worker #endif
49*523fa7a6SAndroid Build Coastguard Worker }
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None };
52*523fa7a6SAndroid Build Coastguard Worker 
53*523fa7a6SAndroid Build Coastguard Worker using KeyStr = std::string;
54*523fa7a6SAndroid Build Coastguard Worker using KeyInt = int32_t;
55*523fa7a6SAndroid Build Coastguard Worker 
56*523fa7a6SAndroid Build Coastguard Worker struct Key {
57*523fa7a6SAndroid Build Coastguard Worker   enum class Kind : uint8_t { None, Int, Str } kind_;
58*523fa7a6SAndroid Build Coastguard Worker 
59*523fa7a6SAndroid Build Coastguard Worker  private:
60*523fa7a6SAndroid Build Coastguard Worker   std::variant<std::monostate, KeyInt, KeyStr> repr_;
61*523fa7a6SAndroid Build Coastguard Worker 
62*523fa7a6SAndroid Build Coastguard Worker  public:
KeyKey63*523fa7a6SAndroid Build Coastguard Worker   Key() {}
KeyKey64*523fa7a6SAndroid Build Coastguard Worker   /*implicit*/ Key(KeyInt key) : repr_(key) {}
KeyKey65*523fa7a6SAndroid Build Coastguard Worker   /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {}
66*523fa7a6SAndroid Build Coastguard Worker 
kindKey67*523fa7a6SAndroid Build Coastguard Worker   Kind kind() const {
68*523fa7a6SAndroid Build Coastguard Worker     return static_cast<Kind>(repr_.index());
69*523fa7a6SAndroid Build Coastguard Worker   }
70*523fa7a6SAndroid Build Coastguard Worker 
as_intKey71*523fa7a6SAndroid Build Coastguard Worker   KeyInt as_int() const {
72*523fa7a6SAndroid Build Coastguard Worker     return std::get<KeyInt>(repr_);
73*523fa7a6SAndroid Build Coastguard Worker   }
74*523fa7a6SAndroid Build Coastguard Worker 
KeyIntKey75*523fa7a6SAndroid Build Coastguard Worker   operator KeyInt() const {
76*523fa7a6SAndroid Build Coastguard Worker     return as_int();
77*523fa7a6SAndroid Build Coastguard Worker   }
78*523fa7a6SAndroid Build Coastguard Worker 
as_strKey79*523fa7a6SAndroid Build Coastguard Worker   const KeyStr& as_str() const {
80*523fa7a6SAndroid Build Coastguard Worker     return std::get<KeyStr>(repr_);
81*523fa7a6SAndroid Build Coastguard Worker   }
82*523fa7a6SAndroid Build Coastguard Worker 
83*523fa7a6SAndroid Build Coastguard Worker   operator const KeyStr&() const {
84*523fa7a6SAndroid Build Coastguard Worker     return as_str();
85*523fa7a6SAndroid Build Coastguard Worker   }
86*523fa7a6SAndroid Build Coastguard Worker 
87*523fa7a6SAndroid Build Coastguard Worker   bool operator==(const Key& rhs) const {
88*523fa7a6SAndroid Build Coastguard Worker     return repr_ == rhs.repr_;
89*523fa7a6SAndroid Build Coastguard Worker   }
90*523fa7a6SAndroid Build Coastguard Worker 
91*523fa7a6SAndroid Build Coastguard Worker   bool operator!=(const Key& rhs) const {
92*523fa7a6SAndroid Build Coastguard Worker     return !operator==(rhs);
93*523fa7a6SAndroid Build Coastguard Worker   }
94*523fa7a6SAndroid Build Coastguard Worker };
95*523fa7a6SAndroid Build Coastguard Worker 
96*523fa7a6SAndroid Build Coastguard Worker struct Empty {};
97*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux = Empty>
98*523fa7a6SAndroid Build Coastguard Worker struct ContainerHandle;
99*523fa7a6SAndroid Build Coastguard Worker 
100*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux = Empty>
101*523fa7a6SAndroid Build Coastguard Worker struct Container final : public Aux {
102*523fa7a6SAndroid Build Coastguard Worker   using handle_type = ContainerHandle<T, Aux>;
103*523fa7a6SAndroid Build Coastguard Worker   using leaf_type = T;
104*523fa7a6SAndroid Build Coastguard Worker 
105*523fa7a6SAndroid Build Coastguard Worker   Kind kind = Kind::None;
106*523fa7a6SAndroid Build Coastguard Worker   size_t size = 0;
107*523fa7a6SAndroid Build Coastguard Worker   leaf_type* leaf = nullptr;
108*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<handle_type[]> items;
109*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<Key[]> keys;
110*523fa7a6SAndroid Build Coastguard Worker   std::string custom_type;
111*523fa7a6SAndroid Build Coastguard Worker   // internal only field to keep associated to every node meta info
112*523fa7a6SAndroid Build Coastguard Worker   mutable size_t leaves_num = 0u;
113*523fa7a6SAndroid Build Coastguard Worker 
114*523fa7a6SAndroid Build Coastguard Worker   /*implicit*/ Container(Kind kind, size_t size = 0u)
kindfinal115*523fa7a6SAndroid Build Coastguard Worker       : kind(kind),
116*523fa7a6SAndroid Build Coastguard Worker         size(size),
117*523fa7a6SAndroid Build Coastguard Worker         items(std::unique_ptr<handle_type[]>(new handle_type[size])) {
118*523fa7a6SAndroid Build Coastguard Worker     if (kind == Kind::Dict) {
119*523fa7a6SAndroid Build Coastguard Worker       keys = std::unique_ptr<Key[]>(new Key[size]);
120*523fa7a6SAndroid Build Coastguard Worker     }
121*523fa7a6SAndroid Build Coastguard Worker   }
Containerfinal122*523fa7a6SAndroid Build Coastguard Worker   /*implicit*/ Container(leaf_type* leaf)
123*523fa7a6SAndroid Build Coastguard Worker       : kind(Kind::Leaf), size(0u), leaf(leaf), leaves_num(1u) {}
124*523fa7a6SAndroid Build Coastguard Worker   Container(const Container&) = delete;
125*523fa7a6SAndroid Build Coastguard Worker   Container& operator=(const Container&) = delete;
126*523fa7a6SAndroid Build Coastguard Worker };
127*523fa7a6SAndroid Build Coastguard Worker 
128*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
129*523fa7a6SAndroid Build Coastguard Worker struct ContainerHandle {
130*523fa7a6SAndroid Build Coastguard Worker   using container_type = Container<T, Aux>;
131*523fa7a6SAndroid Build Coastguard Worker   using leaf_type = T;
132*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<container_type> handle;
133*523fa7a6SAndroid Build Coastguard Worker 
ContainerHandleContainerHandle134*523fa7a6SAndroid Build Coastguard Worker   ContainerHandle() {}
135*523fa7a6SAndroid Build Coastguard Worker 
136*523fa7a6SAndroid Build Coastguard Worker   template <typename... Args>
ContainerHandleContainerHandle137*523fa7a6SAndroid Build Coastguard Worker   ContainerHandle(Args... args)
138*523fa7a6SAndroid Build Coastguard Worker       : handle(std::make_unique<container_type>(std::forward<Args>(args)...)) {}
139*523fa7a6SAndroid Build Coastguard Worker 
ContainerHandleContainerHandle140*523fa7a6SAndroid Build Coastguard Worker   /*implicit*/ ContainerHandle(container_type* c) : handle(c) {}
141*523fa7a6SAndroid Build Coastguard Worker 
ContainerHandleContainerHandle142*523fa7a6SAndroid Build Coastguard Worker   /*implicit*/ ContainerHandle(std::unique_ptr<container_type> c)
143*523fa7a6SAndroid Build Coastguard Worker       : handle(std::move(c)) {}
144*523fa7a6SAndroid Build Coastguard Worker 
set_leafContainerHandle145*523fa7a6SAndroid Build Coastguard Worker   void set_leaf(leaf_type* leaf) {
146*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(handle->kind == Kind::Leaf);
147*523fa7a6SAndroid Build Coastguard Worker     handle->leaf = leaf;
148*523fa7a6SAndroid Build Coastguard Worker   }
149*523fa7a6SAndroid Build Coastguard Worker 
leaf_typeContainerHandle150*523fa7a6SAndroid Build Coastguard Worker   operator leaf_type() const {
151*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(handle->kind == Kind::Leaf);
152*523fa7a6SAndroid Build Coastguard Worker     return *handle->leaf;
153*523fa7a6SAndroid Build Coastguard Worker   }
154*523fa7a6SAndroid Build Coastguard Worker 
leafContainerHandle155*523fa7a6SAndroid Build Coastguard Worker   const leaf_type& leaf() const {
156*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(handle->kind == Kind::Leaf);
157*523fa7a6SAndroid Build Coastguard Worker     return *handle->leaf;
158*523fa7a6SAndroid Build Coastguard Worker   }
leafContainerHandle159*523fa7a6SAndroid Build Coastguard Worker   leaf_type& leaf() {
160*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(handle->kind == Kind::Leaf);
161*523fa7a6SAndroid Build Coastguard Worker     return *handle->leaf;
162*523fa7a6SAndroid Build Coastguard Worker   }
163*523fa7a6SAndroid Build Coastguard Worker 
leaf_ptrContainerHandle164*523fa7a6SAndroid Build Coastguard Worker   const leaf_type* leaf_ptr() const {
165*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(handle->kind == Kind::Leaf);
166*523fa7a6SAndroid Build Coastguard Worker     return handle->leaf;
167*523fa7a6SAndroid Build Coastguard Worker   }
leaf_ptrContainerHandle168*523fa7a6SAndroid Build Coastguard Worker   leaf_type* leaf_ptr() {
169*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(handle->kind == Kind::Leaf);
170*523fa7a6SAndroid Build Coastguard Worker     return handle->leaf;
171*523fa7a6SAndroid Build Coastguard Worker   }
172*523fa7a6SAndroid Build Coastguard Worker 
173*523fa7a6SAndroid Build Coastguard Worker   const ContainerHandle& operator[](size_t idx) const {
174*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(idx < handle->size);
175*523fa7a6SAndroid Build Coastguard Worker     return handle->items[idx];
176*523fa7a6SAndroid Build Coastguard Worker   }
177*523fa7a6SAndroid Build Coastguard Worker 
178*523fa7a6SAndroid Build Coastguard Worker   ContainerHandle& operator[](size_t idx) {
179*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(idx < handle->size);
180*523fa7a6SAndroid Build Coastguard Worker     return handle->items[idx];
181*523fa7a6SAndroid Build Coastguard Worker   }
182*523fa7a6SAndroid Build Coastguard Worker 
containsContainerHandle183*523fa7a6SAndroid Build Coastguard Worker   bool contains(const KeyStr& lookup_key) const {
184*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(isDict());
185*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < handle->size; ++i) {
186*523fa7a6SAndroid Build Coastguard Worker       if (handle->keys[i] == lookup_key) {
187*523fa7a6SAndroid Build Coastguard Worker         return true;
188*523fa7a6SAndroid Build Coastguard Worker       }
189*523fa7a6SAndroid Build Coastguard Worker     }
190*523fa7a6SAndroid Build Coastguard Worker     return false;
191*523fa7a6SAndroid Build Coastguard Worker   }
192*523fa7a6SAndroid Build Coastguard Worker 
atContainerHandle193*523fa7a6SAndroid Build Coastguard Worker   const ContainerHandle& at(const Key& lookup_key) const {
194*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(isDict());
195*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < handle->size; ++i) {
196*523fa7a6SAndroid Build Coastguard Worker       if (handle->keys[i] == lookup_key) {
197*523fa7a6SAndroid Build Coastguard Worker         return handle->items[i];
198*523fa7a6SAndroid Build Coastguard Worker       }
199*523fa7a6SAndroid Build Coastguard Worker     }
200*523fa7a6SAndroid Build Coastguard Worker     pytree_unreachable();
201*523fa7a6SAndroid Build Coastguard Worker   }
202*523fa7a6SAndroid Build Coastguard Worker 
atContainerHandle203*523fa7a6SAndroid Build Coastguard Worker   const ContainerHandle& at(const KeyInt& lookup_key) const {
204*523fa7a6SAndroid Build Coastguard Worker     return at(Key(lookup_key));
205*523fa7a6SAndroid Build Coastguard Worker   }
206*523fa7a6SAndroid Build Coastguard Worker 
atContainerHandle207*523fa7a6SAndroid Build Coastguard Worker   const ContainerHandle& at(const KeyStr& lookup_key) const {
208*523fa7a6SAndroid Build Coastguard Worker     return at(Key(lookup_key));
209*523fa7a6SAndroid Build Coastguard Worker   }
210*523fa7a6SAndroid Build Coastguard Worker 
keyContainerHandle211*523fa7a6SAndroid Build Coastguard Worker   const Key& key(size_t idx) const {
212*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(isDict());
213*523fa7a6SAndroid Build Coastguard Worker     return handle->keys[idx];
214*523fa7a6SAndroid Build Coastguard Worker   }
keyContainerHandle215*523fa7a6SAndroid Build Coastguard Worker   Key& key(size_t idx) {
216*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(isDict());
217*523fa7a6SAndroid Build Coastguard Worker     return handle->keys[idx];
218*523fa7a6SAndroid Build Coastguard Worker   }
219*523fa7a6SAndroid Build Coastguard Worker 
sizeContainerHandle220*523fa7a6SAndroid Build Coastguard Worker   size_t size() const {
221*523fa7a6SAndroid Build Coastguard Worker     return handle->size;
222*523fa7a6SAndroid Build Coastguard Worker   }
223*523fa7a6SAndroid Build Coastguard Worker 
leaves_numContainerHandle224*523fa7a6SAndroid Build Coastguard Worker   size_t leaves_num() const {
225*523fa7a6SAndroid Build Coastguard Worker     return handle->leaves_num;
226*523fa7a6SAndroid Build Coastguard Worker   }
227*523fa7a6SAndroid Build Coastguard Worker 
isDictContainerHandle228*523fa7a6SAndroid Build Coastguard Worker   bool isDict() const {
229*523fa7a6SAndroid Build Coastguard Worker     return handle->kind == Kind::Dict;
230*523fa7a6SAndroid Build Coastguard Worker   }
231*523fa7a6SAndroid Build Coastguard Worker 
isListContainerHandle232*523fa7a6SAndroid Build Coastguard Worker   bool isList() const {
233*523fa7a6SAndroid Build Coastguard Worker     return handle->kind == Kind::List;
234*523fa7a6SAndroid Build Coastguard Worker   }
235*523fa7a6SAndroid Build Coastguard Worker 
isNamedTupleContainerHandle236*523fa7a6SAndroid Build Coastguard Worker   bool isNamedTuple() const {
237*523fa7a6SAndroid Build Coastguard Worker     return handle->kind == Kind::NamedTuple;
238*523fa7a6SAndroid Build Coastguard Worker   }
239*523fa7a6SAndroid Build Coastguard Worker 
isTupleContainerHandle240*523fa7a6SAndroid Build Coastguard Worker   bool isTuple() const {
241*523fa7a6SAndroid Build Coastguard Worker     return handle->kind == Kind::Tuple;
242*523fa7a6SAndroid Build Coastguard Worker   }
243*523fa7a6SAndroid Build Coastguard Worker 
isLeafContainerHandle244*523fa7a6SAndroid Build Coastguard Worker   bool isLeaf() const {
245*523fa7a6SAndroid Build Coastguard Worker     return handle->kind == Kind::Leaf;
246*523fa7a6SAndroid Build Coastguard Worker   }
247*523fa7a6SAndroid Build Coastguard Worker 
kindContainerHandle248*523fa7a6SAndroid Build Coastguard Worker   Kind kind() const {
249*523fa7a6SAndroid Build Coastguard Worker     return handle->kind;
250*523fa7a6SAndroid Build Coastguard Worker   }
251*523fa7a6SAndroid Build Coastguard Worker 
252*523fa7a6SAndroid Build Coastguard Worker   // Checks only structure, no leaves comparison
253*523fa7a6SAndroid Build Coastguard Worker   bool operator==(const ContainerHandle& rhs) {
254*523fa7a6SAndroid Build Coastguard Worker     const Kind knd = kind();
255*523fa7a6SAndroid Build Coastguard Worker     if (knd != rhs.kind()) {
256*523fa7a6SAndroid Build Coastguard Worker       return false;
257*523fa7a6SAndroid Build Coastguard Worker     }
258*523fa7a6SAndroid Build Coastguard Worker     if (knd == Kind::Leaf) {
259*523fa7a6SAndroid Build Coastguard Worker       return true;
260*523fa7a6SAndroid Build Coastguard Worker     }
261*523fa7a6SAndroid Build Coastguard Worker     const size_t _size = size();
262*523fa7a6SAndroid Build Coastguard Worker     if (_size != rhs.size()) {
263*523fa7a6SAndroid Build Coastguard Worker       return false;
264*523fa7a6SAndroid Build Coastguard Worker     }
265*523fa7a6SAndroid Build Coastguard Worker 
266*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < _size; ++i) {
267*523fa7a6SAndroid Build Coastguard Worker       if (knd == Kind::Dict && (key(i) != rhs.key(i))) {
268*523fa7a6SAndroid Build Coastguard Worker         return false;
269*523fa7a6SAndroid Build Coastguard Worker       }
270*523fa7a6SAndroid Build Coastguard Worker       if (operator[](i) != rhs[i]) {
271*523fa7a6SAndroid Build Coastguard Worker         return false;
272*523fa7a6SAndroid Build Coastguard Worker       }
273*523fa7a6SAndroid Build Coastguard Worker     }
274*523fa7a6SAndroid Build Coastguard Worker     return true;
275*523fa7a6SAndroid Build Coastguard Worker   }
276*523fa7a6SAndroid Build Coastguard Worker 
277*523fa7a6SAndroid Build Coastguard Worker   bool operator!=(const ContainerHandle& rhs) {
278*523fa7a6SAndroid Build Coastguard Worker     return !operator==(rhs);
279*523fa7a6SAndroid Build Coastguard Worker   }
280*523fa7a6SAndroid Build Coastguard Worker };
281*523fa7a6SAndroid Build Coastguard Worker 
282*523fa7a6SAndroid Build Coastguard Worker struct TreeSpecLeaf {};
283*523fa7a6SAndroid Build Coastguard Worker 
284*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
285*523fa7a6SAndroid Build Coastguard Worker using TreeSpec = ContainerHandle<TreeSpecLeaf, Aux>;
286*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
287*523fa7a6SAndroid Build Coastguard Worker using TreeSpecContainer = Container<TreeSpecLeaf, Aux>;
288*523fa7a6SAndroid Build Coastguard Worker 
289*523fa7a6SAndroid Build Coastguard Worker using StrTreeSpec = std::string;
290*523fa7a6SAndroid Build Coastguard Worker 
291*523fa7a6SAndroid Build Coastguard Worker // Expects refresh_leaves_num() was called after the last modification
292*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename U, typename Aux>
clone(const ContainerHandle<T,Aux> & node,U * leaves)293*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<U, Aux> clone(const ContainerHandle<T, Aux>& node, U* leaves) {
294*523fa7a6SAndroid Build Coastguard Worker   if (node.isLeaf()) {
295*523fa7a6SAndroid Build Coastguard Worker     return ContainerHandle<U, Aux>(leaves);
296*523fa7a6SAndroid Build Coastguard Worker   }
297*523fa7a6SAndroid Build Coastguard Worker 
298*523fa7a6SAndroid Build Coastguard Worker   ContainerHandle<U, Aux> ret(node.kind(), node.size());
299*523fa7a6SAndroid Build Coastguard Worker   size_t leaves_offset = 0;
300*523fa7a6SAndroid Build Coastguard Worker   size_t size = node.size();
301*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < size; ++i) {
302*523fa7a6SAndroid Build Coastguard Worker     ret[i] = clone(node[i], leaves + leaves_offset);
303*523fa7a6SAndroid Build Coastguard Worker     leaves_offset += node[i].leaves_num();
304*523fa7a6SAndroid Build Coastguard Worker   }
305*523fa7a6SAndroid Build Coastguard Worker 
306*523fa7a6SAndroid Build Coastguard Worker   if (node.isDict()) {
307*523fa7a6SAndroid Build Coastguard Worker     ret.handle->keys = std::unique_ptr<Key[]>(new Key[size]);
308*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < size; ++i) {
309*523fa7a6SAndroid Build Coastguard Worker       ret.handle->keys[i] = node.handle->keys[i];
310*523fa7a6SAndroid Build Coastguard Worker     }
311*523fa7a6SAndroid Build Coastguard Worker   }
312*523fa7a6SAndroid Build Coastguard Worker 
313*523fa7a6SAndroid Build Coastguard Worker   return ret;
314*523fa7a6SAndroid Build Coastguard Worker }
315*523fa7a6SAndroid Build Coastguard Worker 
316*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
traverse(ContainerHandle<T,Aux> & node,FunctionRef<void (ContainerHandle<T,Aux> &)> func)317*523fa7a6SAndroid Build Coastguard Worker void traverse(
318*523fa7a6SAndroid Build Coastguard Worker     ContainerHandle<T, Aux>& node,
319*523fa7a6SAndroid Build Coastguard Worker     FunctionRef<void(ContainerHandle<T, Aux>&)> func) {
320*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < node.size(); ++i) {
321*523fa7a6SAndroid Build Coastguard Worker     traverse(node[i], func);
322*523fa7a6SAndroid Build Coastguard Worker   }
323*523fa7a6SAndroid Build Coastguard Worker 
324*523fa7a6SAndroid Build Coastguard Worker   func(node);
325*523fa7a6SAndroid Build Coastguard Worker }
326*523fa7a6SAndroid Build Coastguard Worker 
327*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
traverse(const ContainerHandle<T,Aux> & node,FunctionRef<void (const ContainerHandle<T,Aux> &)> func)328*523fa7a6SAndroid Build Coastguard Worker void traverse(
329*523fa7a6SAndroid Build Coastguard Worker     const ContainerHandle<T, Aux>& node,
330*523fa7a6SAndroid Build Coastguard Worker     FunctionRef<void(const ContainerHandle<T, Aux>&)> func) {
331*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < node.size(); ++i) {
332*523fa7a6SAndroid Build Coastguard Worker     traverse(node[i], func);
333*523fa7a6SAndroid Build Coastguard Worker   }
334*523fa7a6SAndroid Build Coastguard Worker 
335*523fa7a6SAndroid Build Coastguard Worker   func(node);
336*523fa7a6SAndroid Build Coastguard Worker }
337*523fa7a6SAndroid Build Coastguard Worker 
338*523fa7a6SAndroid Build Coastguard Worker struct Config final {
339*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kTuple = 'T';
340*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kNamedTuple = 'N';
341*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kList = 'L';
342*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kDict = 'D';
343*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kCustom = 'C';
344*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kLeaf = '$';
345*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kNodeDataBegin = '(';
346*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kNodeDataEnd = ')';
347*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kDictStrKeyQuote = '\'';
348*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kDictKeyValueSep = ':';
349*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kChildrenSep = ',';
350*523fa7a6SAndroid Build Coastguard Worker   static constexpr char kChildrenDataSep = '#';
351*523fa7a6SAndroid Build Coastguard Worker };
352*523fa7a6SAndroid Build Coastguard Worker 
353*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
to_str_internal(const TreeSpec<Aux> & spec)354*523fa7a6SAndroid Build Coastguard Worker StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
355*523fa7a6SAndroid Build Coastguard Worker   std::string s;
356*523fa7a6SAndroid Build Coastguard Worker   switch (spec.kind()) {
357*523fa7a6SAndroid Build Coastguard Worker     case Kind::List:
358*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kList);
359*523fa7a6SAndroid Build Coastguard Worker       break;
360*523fa7a6SAndroid Build Coastguard Worker     case Kind::NamedTuple:
361*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kNamedTuple);
362*523fa7a6SAndroid Build Coastguard Worker       break;
363*523fa7a6SAndroid Build Coastguard Worker     case Kind::Tuple:
364*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kTuple);
365*523fa7a6SAndroid Build Coastguard Worker       break;
366*523fa7a6SAndroid Build Coastguard Worker     case Kind::Dict:
367*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kDict);
368*523fa7a6SAndroid Build Coastguard Worker       break;
369*523fa7a6SAndroid Build Coastguard Worker     case Kind::Leaf:
370*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kLeaf);
371*523fa7a6SAndroid Build Coastguard Worker       return s;
372*523fa7a6SAndroid Build Coastguard Worker     case Kind::Custom:
373*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kCustom);
374*523fa7a6SAndroid Build Coastguard Worker       s.push_back('(');
375*523fa7a6SAndroid Build Coastguard Worker       s.append(spec.handle->custom_type);
376*523fa7a6SAndroid Build Coastguard Worker       s.push_back(')');
377*523fa7a6SAndroid Build Coastguard Worker       break;
378*523fa7a6SAndroid Build Coastguard Worker     case Kind::None:
379*523fa7a6SAndroid Build Coastguard Worker       return s;
380*523fa7a6SAndroid Build Coastguard Worker   }
381*523fa7a6SAndroid Build Coastguard Worker   const size_t size = spec.size();
382*523fa7a6SAndroid Build Coastguard Worker   s.append(std::to_string(size));
383*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < size; ++i) {
384*523fa7a6SAndroid Build Coastguard Worker     s.push_back(Config::kChildrenDataSep);
385*523fa7a6SAndroid Build Coastguard Worker     s.append(std::to_string(spec[i].leaves_num()));
386*523fa7a6SAndroid Build Coastguard Worker   }
387*523fa7a6SAndroid Build Coastguard Worker   s.push_back(Config::kNodeDataBegin);
388*523fa7a6SAndroid Build Coastguard Worker   if (spec.kind() == Kind::Dict) {
389*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < size; ++i) {
390*523fa7a6SAndroid Build Coastguard Worker       if (i) {
391*523fa7a6SAndroid Build Coastguard Worker         s.push_back(Config::kChildrenSep);
392*523fa7a6SAndroid Build Coastguard Worker       }
393*523fa7a6SAndroid Build Coastguard Worker       const auto& key = spec.key(i);
394*523fa7a6SAndroid Build Coastguard Worker       if (key.kind() == Key::Kind::Int) {
395*523fa7a6SAndroid Build Coastguard Worker         s.append(std::to_string(key.as_int()));
396*523fa7a6SAndroid Build Coastguard Worker       } else if (key.kind() == Key::Kind::Str) {
397*523fa7a6SAndroid Build Coastguard Worker         s.push_back(Config::kDictStrKeyQuote);
398*523fa7a6SAndroid Build Coastguard Worker         s.append(key.as_str());
399*523fa7a6SAndroid Build Coastguard Worker         s.push_back(Config::kDictStrKeyQuote);
400*523fa7a6SAndroid Build Coastguard Worker       } else {
401*523fa7a6SAndroid Build Coastguard Worker         pytree_unreachable();
402*523fa7a6SAndroid Build Coastguard Worker       }
403*523fa7a6SAndroid Build Coastguard Worker       s.push_back(Config::kDictKeyValueSep);
404*523fa7a6SAndroid Build Coastguard Worker       s.append(to_str_internal(spec[i]));
405*523fa7a6SAndroid Build Coastguard Worker     }
406*523fa7a6SAndroid Build Coastguard Worker   } else {
407*523fa7a6SAndroid Build Coastguard Worker     for (size_t i = 0; i < size; ++i) {
408*523fa7a6SAndroid Build Coastguard Worker       if (i) {
409*523fa7a6SAndroid Build Coastguard Worker         s.push_back(Config::kChildrenSep);
410*523fa7a6SAndroid Build Coastguard Worker       }
411*523fa7a6SAndroid Build Coastguard Worker       s.append(to_str_internal(spec[i]));
412*523fa7a6SAndroid Build Coastguard Worker     }
413*523fa7a6SAndroid Build Coastguard Worker   }
414*523fa7a6SAndroid Build Coastguard Worker   s.push_back(Config::kNodeDataEnd);
415*523fa7a6SAndroid Build Coastguard Worker   return s;
416*523fa7a6SAndroid Build Coastguard Worker }
417*523fa7a6SAndroid Build Coastguard Worker 
418*523fa7a6SAndroid Build Coastguard Worker template <typename T>
419*523fa7a6SAndroid Build Coastguard Worker struct arr {
arrarr420*523fa7a6SAndroid Build Coastguard Worker   explicit arr(const size_t n) : data_(std::unique_ptr<T[]>(new T[n])), n_(n) {}
421*523fa7a6SAndroid Build Coastguard Worker 
422*523fa7a6SAndroid Build Coastguard Worker   T& operator[](size_t idx) {
423*523fa7a6SAndroid Build Coastguard Worker     return data_[idx];
424*523fa7a6SAndroid Build Coastguard Worker   }
425*523fa7a6SAndroid Build Coastguard Worker 
426*523fa7a6SAndroid Build Coastguard Worker   const T& operator[](size_t idx) const {
427*523fa7a6SAndroid Build Coastguard Worker     return data_[idx];
428*523fa7a6SAndroid Build Coastguard Worker   }
429*523fa7a6SAndroid Build Coastguard Worker 
dataarr430*523fa7a6SAndroid Build Coastguard Worker   inline T* data() {
431*523fa7a6SAndroid Build Coastguard Worker     return data_.get();
432*523fa7a6SAndroid Build Coastguard Worker   }
433*523fa7a6SAndroid Build Coastguard Worker 
sizearr434*523fa7a6SAndroid Build Coastguard Worker   inline size_t size() const {
435*523fa7a6SAndroid Build Coastguard Worker     return n_;
436*523fa7a6SAndroid Build Coastguard Worker   }
437*523fa7a6SAndroid Build Coastguard Worker 
438*523fa7a6SAndroid Build Coastguard Worker  private:
439*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<T[]> data_;
440*523fa7a6SAndroid Build Coastguard Worker   size_t n_;
441*523fa7a6SAndroid Build Coastguard Worker };
442*523fa7a6SAndroid Build Coastguard Worker 
read_number(const StrTreeSpec & spec,size_t & read_idx)443*523fa7a6SAndroid Build Coastguard Worker inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
444*523fa7a6SAndroid Build Coastguard Worker   size_t num = 0;
445*523fa7a6SAndroid Build Coastguard Worker   while (isdigit(spec[read_idx])) {
446*523fa7a6SAndroid Build Coastguard Worker     num = 10 * num + (spec[read_idx] - '0');
447*523fa7a6SAndroid Build Coastguard Worker     read_idx++;
448*523fa7a6SAndroid Build Coastguard Worker   }
449*523fa7a6SAndroid Build Coastguard Worker   return num;
450*523fa7a6SAndroid Build Coastguard Worker }
451*523fa7a6SAndroid Build Coastguard Worker 
read_node_layout(const StrTreeSpec & spec,size_t & read_idx)452*523fa7a6SAndroid Build Coastguard Worker inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
453*523fa7a6SAndroid Build Coastguard Worker   const size_t child_num = read_number(spec, read_idx);
454*523fa7a6SAndroid Build Coastguard Worker   arr<size_t> ret(child_num);
455*523fa7a6SAndroid Build Coastguard Worker 
456*523fa7a6SAndroid Build Coastguard Worker   size_t child_idx = 0;
457*523fa7a6SAndroid Build Coastguard Worker   while (spec[read_idx] == Config::kChildrenDataSep) {
458*523fa7a6SAndroid Build Coastguard Worker     ++read_idx;
459*523fa7a6SAndroid Build Coastguard Worker     ret[child_idx++] = read_number(spec, read_idx);
460*523fa7a6SAndroid Build Coastguard Worker   }
461*523fa7a6SAndroid Build Coastguard Worker   return ret;
462*523fa7a6SAndroid Build Coastguard Worker }
463*523fa7a6SAndroid Build Coastguard Worker 
464*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
from_str_internal(const StrTreeSpec & spec,size_t read_idx,const arr<size_t> & spec_data)465*523fa7a6SAndroid Build Coastguard Worker TreeSpec<Aux> from_str_internal(
466*523fa7a6SAndroid Build Coastguard Worker     const StrTreeSpec& spec,
467*523fa7a6SAndroid Build Coastguard Worker     size_t read_idx,
468*523fa7a6SAndroid Build Coastguard Worker     const arr<size_t>& spec_data) {
469*523fa7a6SAndroid Build Coastguard Worker   const auto kind_char = spec[read_idx];
470*523fa7a6SAndroid Build Coastguard Worker   switch (kind_char) {
471*523fa7a6SAndroid Build Coastguard Worker     case Config::kTuple:
472*523fa7a6SAndroid Build Coastguard Worker     case Config::kNamedTuple:
473*523fa7a6SAndroid Build Coastguard Worker     case Config::kList: {
474*523fa7a6SAndroid Build Coastguard Worker       Kind kind = Kind::List;
475*523fa7a6SAndroid Build Coastguard Worker       std::string custom_type;
476*523fa7a6SAndroid Build Coastguard Worker       if (Config::kNamedTuple == kind_char) {
477*523fa7a6SAndroid Build Coastguard Worker         kind = Kind::NamedTuple;
478*523fa7a6SAndroid Build Coastguard Worker       } else if (Config::kTuple == kind_char) {
479*523fa7a6SAndroid Build Coastguard Worker         kind = Kind::Tuple;
480*523fa7a6SAndroid Build Coastguard Worker       } else if (Config::kCustom == kind_char) {
481*523fa7a6SAndroid Build Coastguard Worker         kind = Kind::Custom;
482*523fa7a6SAndroid Build Coastguard Worker         read_idx++;
483*523fa7a6SAndroid Build Coastguard Worker         assert(spec[read_idx] == '(');
484*523fa7a6SAndroid Build Coastguard Worker         auto type_str_end = spec_data[read_idx];
485*523fa7a6SAndroid Build Coastguard Worker         read_idx++;
486*523fa7a6SAndroid Build Coastguard Worker         custom_type = spec.substr(read_idx, type_str_end - read_idx);
487*523fa7a6SAndroid Build Coastguard Worker         assert(false);
488*523fa7a6SAndroid Build Coastguard Worker       }
489*523fa7a6SAndroid Build Coastguard Worker       read_idx++;
490*523fa7a6SAndroid Build Coastguard Worker       auto layout = read_node_layout(spec, read_idx);
491*523fa7a6SAndroid Build Coastguard Worker       const auto size = layout.size();
492*523fa7a6SAndroid Build Coastguard Worker       auto c = std::make_unique<TreeSpecContainer<Aux>>(kind, size);
493*523fa7a6SAndroid Build Coastguard Worker 
494*523fa7a6SAndroid Build Coastguard Worker       if (Kind::Custom == kind) {
495*523fa7a6SAndroid Build Coastguard Worker         c->custom_type = std::move(custom_type);
496*523fa7a6SAndroid Build Coastguard Worker       }
497*523fa7a6SAndroid Build Coastguard Worker 
498*523fa7a6SAndroid Build Coastguard Worker       size_t child_idx = 0;
499*523fa7a6SAndroid Build Coastguard Worker       size_t leaves_offset = 0;
500*523fa7a6SAndroid Build Coastguard Worker 
501*523fa7a6SAndroid Build Coastguard Worker       if (size > 0) {
502*523fa7a6SAndroid Build Coastguard Worker         while (spec[read_idx] != Config::kNodeDataEnd) {
503*523fa7a6SAndroid Build Coastguard Worker           // NOLINTNEXTLINE
504*523fa7a6SAndroid Build Coastguard Worker           auto next_delim_idx = spec_data[read_idx];
505*523fa7a6SAndroid Build Coastguard Worker           read_idx++;
506*523fa7a6SAndroid Build Coastguard Worker           c->items[child_idx] =
507*523fa7a6SAndroid Build Coastguard Worker               from_str_internal<Aux>(spec, read_idx, spec_data);
508*523fa7a6SAndroid Build Coastguard Worker           read_idx = next_delim_idx;
509*523fa7a6SAndroid Build Coastguard Worker           leaves_offset += layout[child_idx++];
510*523fa7a6SAndroid Build Coastguard Worker         }
511*523fa7a6SAndroid Build Coastguard Worker       } else {
512*523fa7a6SAndroid Build Coastguard Worker         read_idx++;
513*523fa7a6SAndroid Build Coastguard Worker       }
514*523fa7a6SAndroid Build Coastguard Worker       c->leaves_num = leaves_offset;
515*523fa7a6SAndroid Build Coastguard Worker       return TreeSpec<Aux>(std::move(c));
516*523fa7a6SAndroid Build Coastguard Worker     }
517*523fa7a6SAndroid Build Coastguard Worker 
518*523fa7a6SAndroid Build Coastguard Worker     case Config::kDict: {
519*523fa7a6SAndroid Build Coastguard Worker       read_idx++;
520*523fa7a6SAndroid Build Coastguard Worker       auto layout = read_node_layout(spec, read_idx);
521*523fa7a6SAndroid Build Coastguard Worker       const auto size = layout.size();
522*523fa7a6SAndroid Build Coastguard Worker       auto c = std::make_unique<TreeSpecContainer<Aux>>(Kind::Dict, size);
523*523fa7a6SAndroid Build Coastguard Worker 
524*523fa7a6SAndroid Build Coastguard Worker       size_t child_idx = 0;
525*523fa7a6SAndroid Build Coastguard Worker       size_t leaves_offset = 0;
526*523fa7a6SAndroid Build Coastguard Worker 
527*523fa7a6SAndroid Build Coastguard Worker       if (size > 0) {
528*523fa7a6SAndroid Build Coastguard Worker         while (spec[read_idx] != Config::kNodeDataEnd) {
529*523fa7a6SAndroid Build Coastguard Worker           // NOLINTNEXTLINE
530*523fa7a6SAndroid Build Coastguard Worker           auto next_delim_idx = spec_data[read_idx];
531*523fa7a6SAndroid Build Coastguard Worker           read_idx++;
532*523fa7a6SAndroid Build Coastguard Worker           if (spec[read_idx] == Config::kDictStrKeyQuote) {
533*523fa7a6SAndroid Build Coastguard Worker             auto key_delim_idx = spec_data[read_idx];
534*523fa7a6SAndroid Build Coastguard Worker             read_idx++;
535*523fa7a6SAndroid Build Coastguard Worker             const size_t key_len = key_delim_idx - read_idx;
536*523fa7a6SAndroid Build Coastguard Worker             // NOLINTNEXTLINE
537*523fa7a6SAndroid Build Coastguard Worker             c->keys[child_idx] = spec.substr(read_idx, key_len);
538*523fa7a6SAndroid Build Coastguard Worker             read_idx = key_delim_idx + 2;
539*523fa7a6SAndroid Build Coastguard Worker           } else {
540*523fa7a6SAndroid Build Coastguard Worker             pytree_assert(isdigit(spec[read_idx]));
541*523fa7a6SAndroid Build Coastguard Worker             size_t key = read_number(spec, read_idx);
542*523fa7a6SAndroid Build Coastguard Worker             c->keys[child_idx] = KeyInt(key);
543*523fa7a6SAndroid Build Coastguard Worker             read_idx += 1;
544*523fa7a6SAndroid Build Coastguard Worker           }
545*523fa7a6SAndroid Build Coastguard Worker 
546*523fa7a6SAndroid Build Coastguard Worker           c->items[child_idx] =
547*523fa7a6SAndroid Build Coastguard Worker               from_str_internal<Aux>(spec, read_idx, spec_data);
548*523fa7a6SAndroid Build Coastguard Worker           read_idx = next_delim_idx;
549*523fa7a6SAndroid Build Coastguard Worker           leaves_offset += layout[child_idx++];
550*523fa7a6SAndroid Build Coastguard Worker         }
551*523fa7a6SAndroid Build Coastguard Worker       } else {
552*523fa7a6SAndroid Build Coastguard Worker         read_idx++;
553*523fa7a6SAndroid Build Coastguard Worker       }
554*523fa7a6SAndroid Build Coastguard Worker       c->leaves_num = leaves_offset;
555*523fa7a6SAndroid Build Coastguard Worker       return TreeSpec<Aux>(std::move(c));
556*523fa7a6SAndroid Build Coastguard Worker     }
557*523fa7a6SAndroid Build Coastguard Worker 
558*523fa7a6SAndroid Build Coastguard Worker     case Config::kLeaf:
559*523fa7a6SAndroid Build Coastguard Worker       return new TreeSpecContainer<Aux>(nullptr);
560*523fa7a6SAndroid Build Coastguard Worker   }
561*523fa7a6SAndroid Build Coastguard Worker   pytree_unreachable();
562*523fa7a6SAndroid Build Coastguard Worker   return new TreeSpecContainer<Aux>(Kind::None);
563*523fa7a6SAndroid Build Coastguard Worker }
564*523fa7a6SAndroid Build Coastguard Worker 
565*523fa7a6SAndroid Build Coastguard Worker template <typename T>
566*523fa7a6SAndroid Build Coastguard Worker struct stack final {
567*523fa7a6SAndroid Build Coastguard Worker   constexpr static const size_t SIZE = 8;
568*523fa7a6SAndroid Build Coastguard Worker 
569*523fa7a6SAndroid Build Coastguard Worker   size_t size_ = 0;
570*523fa7a6SAndroid Build Coastguard Worker   T data[SIZE];
571*523fa7a6SAndroid Build Coastguard Worker 
pushfinal572*523fa7a6SAndroid Build Coastguard Worker   void push(T&& item) {
573*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(size_ < SIZE);
574*523fa7a6SAndroid Build Coastguard Worker     data[size_++] = std::move(item);
575*523fa7a6SAndroid Build Coastguard Worker   }
576*523fa7a6SAndroid Build Coastguard Worker 
popfinal577*523fa7a6SAndroid Build Coastguard Worker   T pop() {
578*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(size_ > 0);
579*523fa7a6SAndroid Build Coastguard Worker     return data[--size_];
580*523fa7a6SAndroid Build Coastguard Worker   }
581*523fa7a6SAndroid Build Coastguard Worker 
topfinal582*523fa7a6SAndroid Build Coastguard Worker   T& top() {
583*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(size_ > 0);
584*523fa7a6SAndroid Build Coastguard Worker     return data[size_ - 1];
585*523fa7a6SAndroid Build Coastguard Worker   }
586*523fa7a6SAndroid Build Coastguard Worker 
sizefinal587*523fa7a6SAndroid Build Coastguard Worker   size_t size() {
588*523fa7a6SAndroid Build Coastguard Worker     return size_;
589*523fa7a6SAndroid Build Coastguard Worker   }
590*523fa7a6SAndroid Build Coastguard Worker };
591*523fa7a6SAndroid Build Coastguard Worker 
pre_parse(const StrTreeSpec & spec)592*523fa7a6SAndroid Build Coastguard Worker inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
593*523fa7a6SAndroid Build Coastguard Worker   stack<std::pair<size_t, size_t>> stack;
594*523fa7a6SAndroid Build Coastguard Worker   size_t i = 0;
595*523fa7a6SAndroid Build Coastguard Worker   const size_t size = spec.size();
596*523fa7a6SAndroid Build Coastguard Worker   arr<size_t> ret(size);
597*523fa7a6SAndroid Build Coastguard Worker   while (i < size) {
598*523fa7a6SAndroid Build Coastguard Worker     const auto c = spec[i];
599*523fa7a6SAndroid Build Coastguard Worker     switch (c) {
600*523fa7a6SAndroid Build Coastguard Worker       case Config::kNodeDataBegin: {
601*523fa7a6SAndroid Build Coastguard Worker         stack.push({i, i});
602*523fa7a6SAndroid Build Coastguard Worker         break;
603*523fa7a6SAndroid Build Coastguard Worker       }
604*523fa7a6SAndroid Build Coastguard Worker       case Config::kNodeDataEnd: {
605*523fa7a6SAndroid Build Coastguard Worker         auto& item = stack.top();
606*523fa7a6SAndroid Build Coastguard Worker         size_t last_sep_idx = item.second;
607*523fa7a6SAndroid Build Coastguard Worker         ret[last_sep_idx] = i;
608*523fa7a6SAndroid Build Coastguard Worker         stack.pop();
609*523fa7a6SAndroid Build Coastguard Worker         break;
610*523fa7a6SAndroid Build Coastguard Worker       }
611*523fa7a6SAndroid Build Coastguard Worker       case Config::kDictStrKeyQuote: {
612*523fa7a6SAndroid Build Coastguard Worker         size_t idx = i;
613*523fa7a6SAndroid Build Coastguard Worker         i++;
614*523fa7a6SAndroid Build Coastguard Worker         while (spec[i] != Config::kDictStrKeyQuote) {
615*523fa7a6SAndroid Build Coastguard Worker           i++;
616*523fa7a6SAndroid Build Coastguard Worker         }
617*523fa7a6SAndroid Build Coastguard Worker         ret[idx] = i;
618*523fa7a6SAndroid Build Coastguard Worker         ret[i] = idx;
619*523fa7a6SAndroid Build Coastguard Worker         break;
620*523fa7a6SAndroid Build Coastguard Worker       }
621*523fa7a6SAndroid Build Coastguard Worker       case Config::kChildrenSep: {
622*523fa7a6SAndroid Build Coastguard Worker         auto& item = stack.top();
623*523fa7a6SAndroid Build Coastguard Worker         size_t last_sep_idx = item.second;
624*523fa7a6SAndroid Build Coastguard Worker         ret[last_sep_idx] = i;
625*523fa7a6SAndroid Build Coastguard Worker         item.second = i;
626*523fa7a6SAndroid Build Coastguard Worker         break;
627*523fa7a6SAndroid Build Coastguard Worker       }
628*523fa7a6SAndroid Build Coastguard Worker     }
629*523fa7a6SAndroid Build Coastguard Worker     i++;
630*523fa7a6SAndroid Build Coastguard Worker   }
631*523fa7a6SAndroid Build Coastguard Worker   return ret;
632*523fa7a6SAndroid Build Coastguard Worker }
633*523fa7a6SAndroid Build Coastguard Worker 
634*523fa7a6SAndroid Build Coastguard Worker template <typename Aux = Empty>
from_str(const StrTreeSpec & spec)635*523fa7a6SAndroid Build Coastguard Worker TreeSpec<Aux> from_str(const StrTreeSpec& spec) {
636*523fa7a6SAndroid Build Coastguard Worker   return from_str_internal<Aux>(spec, 0u, pre_parse(spec));
637*523fa7a6SAndroid Build Coastguard Worker }
638*523fa7a6SAndroid Build Coastguard Worker 
639*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
to_str(const TreeSpec<Aux> & spec)640*523fa7a6SAndroid Build Coastguard Worker StrTreeSpec to_str(const TreeSpec<Aux>& spec) {
641*523fa7a6SAndroid Build Coastguard Worker   if (spec.leaves_num() == 0) {
642*523fa7a6SAndroid Build Coastguard Worker     refresh_leaves_num(spec);
643*523fa7a6SAndroid Build Coastguard Worker   }
644*523fa7a6SAndroid Build Coastguard Worker   return to_str_internal(spec);
645*523fa7a6SAndroid Build Coastguard Worker }
646*523fa7a6SAndroid Build Coastguard Worker 
647*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
648*523fa7a6SAndroid Build Coastguard Worker StrTreeSpec to_str(const TreeSpec<Aux>& spec);
649*523fa7a6SAndroid Build Coastguard Worker 
650*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
unflatten(const TreeSpec<Aux> & spec,T * leaves)651*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<T, Aux> unflatten(const TreeSpec<Aux>& spec, T* leaves) {
652*523fa7a6SAndroid Build Coastguard Worker   if (spec.leaves_num() == 0) {
653*523fa7a6SAndroid Build Coastguard Worker     refresh_leaves_num(spec);
654*523fa7a6SAndroid Build Coastguard Worker   }
655*523fa7a6SAndroid Build Coastguard Worker   return clone(spec, leaves);
656*523fa7a6SAndroid Build Coastguard Worker }
657*523fa7a6SAndroid Build Coastguard Worker 
658*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux = Empty>
unflatten(const StrTreeSpec & spec,T * leaves)659*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<T, Aux> unflatten(const StrTreeSpec& spec, T* leaves) {
660*523fa7a6SAndroid Build Coastguard Worker   return unflatten(from_str<Aux>(spec), leaves);
661*523fa7a6SAndroid Build Coastguard Worker }
662*523fa7a6SAndroid Build Coastguard Worker 
663*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten_internal(const ContainerHandle<T,Aux> & tree,const T ** leaves)664*523fa7a6SAndroid Build Coastguard Worker void flatten_internal(const ContainerHandle<T, Aux>& tree, const T** leaves) {
665*523fa7a6SAndroid Build Coastguard Worker   using tree_t = decltype(tree);
666*523fa7a6SAndroid Build Coastguard Worker   size_t leaves_idx = 0;
667*523fa7a6SAndroid Build Coastguard Worker   auto func = [&](tree_t node) {
668*523fa7a6SAndroid Build Coastguard Worker     if (node.isLeaf()) {
669*523fa7a6SAndroid Build Coastguard Worker       leaves[leaves_idx++] = node.leaf_ptr();
670*523fa7a6SAndroid Build Coastguard Worker     }
671*523fa7a6SAndroid Build Coastguard Worker   };
672*523fa7a6SAndroid Build Coastguard Worker   traverse(tree, FunctionRef<void(tree_t&)>{func});
673*523fa7a6SAndroid Build Coastguard Worker }
674*523fa7a6SAndroid Build Coastguard Worker 
675*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten_internal(ContainerHandle<T,Aux> & tree,T ** leaves)676*523fa7a6SAndroid Build Coastguard Worker void flatten_internal(ContainerHandle<T, Aux>& tree, T** leaves) {
677*523fa7a6SAndroid Build Coastguard Worker   using tree_t = decltype(tree);
678*523fa7a6SAndroid Build Coastguard Worker   size_t leaves_idx = 0;
679*523fa7a6SAndroid Build Coastguard Worker   auto func = [&](tree_t node) {
680*523fa7a6SAndroid Build Coastguard Worker     if (node.isLeaf()) {
681*523fa7a6SAndroid Build Coastguard Worker       leaves[leaves_idx++] = node.leaf_ptr();
682*523fa7a6SAndroid Build Coastguard Worker     }
683*523fa7a6SAndroid Build Coastguard Worker   };
684*523fa7a6SAndroid Build Coastguard Worker   traverse(tree, FunctionRef<void(tree_t&)>{func});
685*523fa7a6SAndroid Build Coastguard Worker }
686*523fa7a6SAndroid Build Coastguard Worker 
687*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
refresh_leaves_num(const ContainerHandle<T,Aux> & node)688*523fa7a6SAndroid Build Coastguard Worker size_t refresh_leaves_num(const ContainerHandle<T, Aux>& node) {
689*523fa7a6SAndroid Build Coastguard Worker   if (node.isLeaf()) {
690*523fa7a6SAndroid Build Coastguard Worker     node.handle->leaves_num = 1;
691*523fa7a6SAndroid Build Coastguard Worker     return 1;
692*523fa7a6SAndroid Build Coastguard Worker   }
693*523fa7a6SAndroid Build Coastguard Worker 
694*523fa7a6SAndroid Build Coastguard Worker   size_t n = 0;
695*523fa7a6SAndroid Build Coastguard Worker   for (size_t i = 0; i < node.size(); ++i) {
696*523fa7a6SAndroid Build Coastguard Worker     n += refresh_leaves_num(node[i]);
697*523fa7a6SAndroid Build Coastguard Worker   }
698*523fa7a6SAndroid Build Coastguard Worker 
699*523fa7a6SAndroid Build Coastguard Worker   node.handle->leaves_num = n;
700*523fa7a6SAndroid Build Coastguard Worker   return n;
701*523fa7a6SAndroid Build Coastguard Worker }
702*523fa7a6SAndroid Build Coastguard Worker 
703*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten(const ContainerHandle<T,Aux> & tree)704*523fa7a6SAndroid Build Coastguard Worker std::pair<arr<const T*>, std::unique_ptr<TreeSpec<Aux>>> flatten(
705*523fa7a6SAndroid Build Coastguard Worker     const ContainerHandle<T, Aux>& tree) {
706*523fa7a6SAndroid Build Coastguard Worker   refresh_leaves_num(tree);
707*523fa7a6SAndroid Build Coastguard Worker   const size_t n = tree.leaves_num();
708*523fa7a6SAndroid Build Coastguard Worker   arr<T*> leaves(n);
709*523fa7a6SAndroid Build Coastguard Worker   flatten_internal(tree, leaves.data());
710*523fa7a6SAndroid Build Coastguard Worker   auto spec_leaves = std::make_unique<TreeSpecLeaf[]>(n);
711*523fa7a6SAndroid Build Coastguard Worker   return {
712*523fa7a6SAndroid Build Coastguard Worker       std::move(leaves),
713*523fa7a6SAndroid Build Coastguard Worker       std::make_unique<TreeSpec<Aux>>(clone(tree, spec_leaves.get()))};
714*523fa7a6SAndroid Build Coastguard Worker }
715*523fa7a6SAndroid Build Coastguard Worker 
716*523fa7a6SAndroid Build Coastguard Worker // Duplication of logic for non const ContainerHandle
717*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten(ContainerHandle<T,Aux> & tree)718*523fa7a6SAndroid Build Coastguard Worker std::pair<arr<T*>, std::unique_ptr<TreeSpec<Aux>>> flatten(
719*523fa7a6SAndroid Build Coastguard Worker     ContainerHandle<T, Aux>& tree) {
720*523fa7a6SAndroid Build Coastguard Worker   refresh_leaves_num(tree);
721*523fa7a6SAndroid Build Coastguard Worker   const size_t n = tree.leaves_num();
722*523fa7a6SAndroid Build Coastguard Worker   arr<T*> leaves(n);
723*523fa7a6SAndroid Build Coastguard Worker   flatten_internal(tree, leaves.data());
724*523fa7a6SAndroid Build Coastguard Worker   auto spec_leaves = std::make_unique<TreeSpecLeaf[]>(n);
725*523fa7a6SAndroid Build Coastguard Worker   return {
726*523fa7a6SAndroid Build Coastguard Worker       std::move(leaves),
727*523fa7a6SAndroid Build Coastguard Worker       std::make_unique<TreeSpec<Aux>>(clone(tree, spec_leaves.get()))};
728*523fa7a6SAndroid Build Coastguard Worker }
729*523fa7a6SAndroid Build Coastguard Worker 
730*523fa7a6SAndroid Build Coastguard Worker } // namespace pytree
731*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
732*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
733*523fa7a6SAndroid Build Coastguard Worker 
734*523fa7a6SAndroid Build Coastguard Worker namespace torch {
735*523fa7a6SAndroid Build Coastguard Worker namespace executor {
736*523fa7a6SAndroid Build Coastguard Worker namespace pytree {
737*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved
738*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces.
739*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::Empty;
740*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::from_str;
741*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::TreeSpec;
742*523fa7a6SAndroid Build Coastguard Worker } // namespace pytree
743*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
744*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
745