1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #pragma once
10*523fa7a6SAndroid Build Coastguard Worker
#include <ctype.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
21*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/pytree/function_ref.h>
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
24*523fa7a6SAndroid Build Coastguard Worker namespace extension {
25*523fa7a6SAndroid Build Coastguard Worker namespace pytree {
26*523fa7a6SAndroid Build Coastguard Worker
// Debug-only invariant check for the pytree helpers. It simply forwards to
// assert(), so it compiles to nothing when NDEBUG is defined (hence the
// parameter may legitimately go unused).
inline auto pytree_assert([[maybe_unused]] bool must_be_true) -> void {
  assert(must_be_true);
}
30*523fa7a6SAndroid Build Coastguard Worker
// Compiler-specific spelling of "always inline this function". MSVC is checked
// first, then GCC-compatible compilers; anything else only gets the plain
// `inline` hint.
#if defined(_MSC_VER)
#define EXECUTORCH_ALWAYS_INLINE __forceinline
#elif defined(__GNUC__)
#define EXECUTORCH_ALWAYS_INLINE inline __attribute__((__always_inline__))
#else
#define EXECUTORCH_ALWAYS_INLINE inline
#endif

// Marks a code path that must never execute. Debug builds fail assert(false).
// With GCC-compatible compilers and MSVC, reaching it in a release build is
// undefined behavior: the optimizer is told the path is unreachable. Other
// compilers spin forever instead.
[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
  assert(false);
#if defined(__GNUC__)
  __builtin_unreachable();
#elif defined(_MSC_VER)
  __assume(0);
#else
  for (;;) {
  }
#endif
}
50*523fa7a6SAndroid Build Coastguard Worker
// Kind of a pytree node. `None` is the kind of a Container that was not given
// one explicitly; `Leaf` nodes point at a leaf value instead of holding
// children.
enum class Kind : uint8_t {
  List,
  Tuple,
  NamedTuple,
  Dict,
  Leaf,
  Custom,
  None,
};
52*523fa7a6SAndroid Build Coastguard Worker
using KeyStr = std::string;
using KeyInt = int32_t;

// A dictionary key in a tree spec. A Key holds nothing (when
// default-constructed), an integer, or a string. Integer and string keys
// convert implicitly from/to KeyInt and KeyStr.
struct Key {
  // Kind of value held by a Key. The enumerator order must match the
  // alternative order of `repr_`, because kind() converts repr_.index()
  // directly into this enum (enforced by the static_asserts below).
  enum class Kind : uint8_t { None, Int, Str };

  // Public mirror of kind(). Every constructor initializes it, so it never
  // holds an indeterminate value; copies and assignments carry it along with
  // the stored value. kind() remains the authoritative accessor.
  Kind kind_ = Kind::None;

 private:
  using Repr = std::variant<std::monostate, KeyInt, KeyStr>;

  static_assert(
      std::is_same_v<
          std::variant_alternative_t<static_cast<size_t>(Kind::None), Repr>,
          std::monostate>,
      "Key::Kind::None must map to std::monostate in Key::repr_");
  static_assert(
      std::is_same_v<
          std::variant_alternative_t<static_cast<size_t>(Kind::Int), Repr>,
          KeyInt>,
      "Key::Kind::Int must map to KeyInt in Key::repr_");
  static_assert(
      std::is_same_v<
          std::variant_alternative_t<static_cast<size_t>(Kind::Str), Repr>,
          KeyStr>,
      "Key::Kind::Str must map to KeyStr in Key::repr_");

  Repr repr_;

 public:
  // Empty key (kind() == Kind::None).
  Key() = default;
  /*implicit*/ Key(KeyInt key) : kind_(Kind::Int), repr_(key) {}
  /*implicit*/ Key(KeyStr key) : kind_(Kind::Str), repr_(std::move(key)) {}

  Kind kind() const {
    return static_cast<Kind>(repr_.index());
  }

  // Returns the integer value; throws std::bad_variant_access if this key
  // does not hold an integer.
  KeyInt as_int() const {
    return std::get<KeyInt>(repr_);
  }

  operator KeyInt() const {
    return as_int();
  }

  // Returns the string value; throws std::bad_variant_access if this key
  // does not hold a string.
  const KeyStr& as_str() const {
    return std::get<KeyStr>(repr_);
  }

  operator const KeyStr&() const {
    return as_str();
  }

  // Keys are equal when they hold the same kind of value and the values
  // compare equal.
  bool operator==(const Key& rhs) const {
    return repr_ == rhs.repr_;
  }

  bool operator!=(const Key& rhs) const {
    return !operator==(rhs);
  }
};
95*523fa7a6SAndroid Build Coastguard Worker
96*523fa7a6SAndroid Build Coastguard Worker struct Empty {};
97*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux = Empty>
98*523fa7a6SAndroid Build Coastguard Worker struct ContainerHandle;
99*523fa7a6SAndroid Build Coastguard Worker
100*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux = Empty>
101*523fa7a6SAndroid Build Coastguard Worker struct Container final : public Aux {
102*523fa7a6SAndroid Build Coastguard Worker using handle_type = ContainerHandle<T, Aux>;
103*523fa7a6SAndroid Build Coastguard Worker using leaf_type = T;
104*523fa7a6SAndroid Build Coastguard Worker
105*523fa7a6SAndroid Build Coastguard Worker Kind kind = Kind::None;
106*523fa7a6SAndroid Build Coastguard Worker size_t size = 0;
107*523fa7a6SAndroid Build Coastguard Worker leaf_type* leaf = nullptr;
108*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<handle_type[]> items;
109*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<Key[]> keys;
110*523fa7a6SAndroid Build Coastguard Worker std::string custom_type;
111*523fa7a6SAndroid Build Coastguard Worker // internal only field to keep associated to every node meta info
112*523fa7a6SAndroid Build Coastguard Worker mutable size_t leaves_num = 0u;
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Worker /*implicit*/ Container(Kind kind, size_t size = 0u)
kindfinal115*523fa7a6SAndroid Build Coastguard Worker : kind(kind),
116*523fa7a6SAndroid Build Coastguard Worker size(size),
117*523fa7a6SAndroid Build Coastguard Worker items(std::unique_ptr<handle_type[]>(new handle_type[size])) {
118*523fa7a6SAndroid Build Coastguard Worker if (kind == Kind::Dict) {
119*523fa7a6SAndroid Build Coastguard Worker keys = std::unique_ptr<Key[]>(new Key[size]);
120*523fa7a6SAndroid Build Coastguard Worker }
121*523fa7a6SAndroid Build Coastguard Worker }
Containerfinal122*523fa7a6SAndroid Build Coastguard Worker /*implicit*/ Container(leaf_type* leaf)
123*523fa7a6SAndroid Build Coastguard Worker : kind(Kind::Leaf), size(0u), leaf(leaf), leaves_num(1u) {}
124*523fa7a6SAndroid Build Coastguard Worker Container(const Container&) = delete;
125*523fa7a6SAndroid Build Coastguard Worker Container& operator=(const Container&) = delete;
126*523fa7a6SAndroid Build Coastguard Worker };
127*523fa7a6SAndroid Build Coastguard Worker
128*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
129*523fa7a6SAndroid Build Coastguard Worker struct ContainerHandle {
130*523fa7a6SAndroid Build Coastguard Worker using container_type = Container<T, Aux>;
131*523fa7a6SAndroid Build Coastguard Worker using leaf_type = T;
132*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<container_type> handle;
133*523fa7a6SAndroid Build Coastguard Worker
ContainerHandleContainerHandle134*523fa7a6SAndroid Build Coastguard Worker ContainerHandle() {}
135*523fa7a6SAndroid Build Coastguard Worker
136*523fa7a6SAndroid Build Coastguard Worker template <typename... Args>
ContainerHandleContainerHandle137*523fa7a6SAndroid Build Coastguard Worker ContainerHandle(Args... args)
138*523fa7a6SAndroid Build Coastguard Worker : handle(std::make_unique<container_type>(std::forward<Args>(args)...)) {}
139*523fa7a6SAndroid Build Coastguard Worker
ContainerHandleContainerHandle140*523fa7a6SAndroid Build Coastguard Worker /*implicit*/ ContainerHandle(container_type* c) : handle(c) {}
141*523fa7a6SAndroid Build Coastguard Worker
ContainerHandleContainerHandle142*523fa7a6SAndroid Build Coastguard Worker /*implicit*/ ContainerHandle(std::unique_ptr<container_type> c)
143*523fa7a6SAndroid Build Coastguard Worker : handle(std::move(c)) {}
144*523fa7a6SAndroid Build Coastguard Worker
set_leafContainerHandle145*523fa7a6SAndroid Build Coastguard Worker void set_leaf(leaf_type* leaf) {
146*523fa7a6SAndroid Build Coastguard Worker pytree_assert(handle->kind == Kind::Leaf);
147*523fa7a6SAndroid Build Coastguard Worker handle->leaf = leaf;
148*523fa7a6SAndroid Build Coastguard Worker }
149*523fa7a6SAndroid Build Coastguard Worker
leaf_typeContainerHandle150*523fa7a6SAndroid Build Coastguard Worker operator leaf_type() const {
151*523fa7a6SAndroid Build Coastguard Worker pytree_assert(handle->kind == Kind::Leaf);
152*523fa7a6SAndroid Build Coastguard Worker return *handle->leaf;
153*523fa7a6SAndroid Build Coastguard Worker }
154*523fa7a6SAndroid Build Coastguard Worker
leafContainerHandle155*523fa7a6SAndroid Build Coastguard Worker const leaf_type& leaf() const {
156*523fa7a6SAndroid Build Coastguard Worker pytree_assert(handle->kind == Kind::Leaf);
157*523fa7a6SAndroid Build Coastguard Worker return *handle->leaf;
158*523fa7a6SAndroid Build Coastguard Worker }
leafContainerHandle159*523fa7a6SAndroid Build Coastguard Worker leaf_type& leaf() {
160*523fa7a6SAndroid Build Coastguard Worker pytree_assert(handle->kind == Kind::Leaf);
161*523fa7a6SAndroid Build Coastguard Worker return *handle->leaf;
162*523fa7a6SAndroid Build Coastguard Worker }
163*523fa7a6SAndroid Build Coastguard Worker
leaf_ptrContainerHandle164*523fa7a6SAndroid Build Coastguard Worker const leaf_type* leaf_ptr() const {
165*523fa7a6SAndroid Build Coastguard Worker pytree_assert(handle->kind == Kind::Leaf);
166*523fa7a6SAndroid Build Coastguard Worker return handle->leaf;
167*523fa7a6SAndroid Build Coastguard Worker }
leaf_ptrContainerHandle168*523fa7a6SAndroid Build Coastguard Worker leaf_type* leaf_ptr() {
169*523fa7a6SAndroid Build Coastguard Worker pytree_assert(handle->kind == Kind::Leaf);
170*523fa7a6SAndroid Build Coastguard Worker return handle->leaf;
171*523fa7a6SAndroid Build Coastguard Worker }
172*523fa7a6SAndroid Build Coastguard Worker
173*523fa7a6SAndroid Build Coastguard Worker const ContainerHandle& operator[](size_t idx) const {
174*523fa7a6SAndroid Build Coastguard Worker pytree_assert(idx < handle->size);
175*523fa7a6SAndroid Build Coastguard Worker return handle->items[idx];
176*523fa7a6SAndroid Build Coastguard Worker }
177*523fa7a6SAndroid Build Coastguard Worker
178*523fa7a6SAndroid Build Coastguard Worker ContainerHandle& operator[](size_t idx) {
179*523fa7a6SAndroid Build Coastguard Worker pytree_assert(idx < handle->size);
180*523fa7a6SAndroid Build Coastguard Worker return handle->items[idx];
181*523fa7a6SAndroid Build Coastguard Worker }
182*523fa7a6SAndroid Build Coastguard Worker
containsContainerHandle183*523fa7a6SAndroid Build Coastguard Worker bool contains(const KeyStr& lookup_key) const {
184*523fa7a6SAndroid Build Coastguard Worker pytree_assert(isDict());
185*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < handle->size; ++i) {
186*523fa7a6SAndroid Build Coastguard Worker if (handle->keys[i] == lookup_key) {
187*523fa7a6SAndroid Build Coastguard Worker return true;
188*523fa7a6SAndroid Build Coastguard Worker }
189*523fa7a6SAndroid Build Coastguard Worker }
190*523fa7a6SAndroid Build Coastguard Worker return false;
191*523fa7a6SAndroid Build Coastguard Worker }
192*523fa7a6SAndroid Build Coastguard Worker
atContainerHandle193*523fa7a6SAndroid Build Coastguard Worker const ContainerHandle& at(const Key& lookup_key) const {
194*523fa7a6SAndroid Build Coastguard Worker pytree_assert(isDict());
195*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < handle->size; ++i) {
196*523fa7a6SAndroid Build Coastguard Worker if (handle->keys[i] == lookup_key) {
197*523fa7a6SAndroid Build Coastguard Worker return handle->items[i];
198*523fa7a6SAndroid Build Coastguard Worker }
199*523fa7a6SAndroid Build Coastguard Worker }
200*523fa7a6SAndroid Build Coastguard Worker pytree_unreachable();
201*523fa7a6SAndroid Build Coastguard Worker }
202*523fa7a6SAndroid Build Coastguard Worker
atContainerHandle203*523fa7a6SAndroid Build Coastguard Worker const ContainerHandle& at(const KeyInt& lookup_key) const {
204*523fa7a6SAndroid Build Coastguard Worker return at(Key(lookup_key));
205*523fa7a6SAndroid Build Coastguard Worker }
206*523fa7a6SAndroid Build Coastguard Worker
atContainerHandle207*523fa7a6SAndroid Build Coastguard Worker const ContainerHandle& at(const KeyStr& lookup_key) const {
208*523fa7a6SAndroid Build Coastguard Worker return at(Key(lookup_key));
209*523fa7a6SAndroid Build Coastguard Worker }
210*523fa7a6SAndroid Build Coastguard Worker
keyContainerHandle211*523fa7a6SAndroid Build Coastguard Worker const Key& key(size_t idx) const {
212*523fa7a6SAndroid Build Coastguard Worker pytree_assert(isDict());
213*523fa7a6SAndroid Build Coastguard Worker return handle->keys[idx];
214*523fa7a6SAndroid Build Coastguard Worker }
keyContainerHandle215*523fa7a6SAndroid Build Coastguard Worker Key& key(size_t idx) {
216*523fa7a6SAndroid Build Coastguard Worker pytree_assert(isDict());
217*523fa7a6SAndroid Build Coastguard Worker return handle->keys[idx];
218*523fa7a6SAndroid Build Coastguard Worker }
219*523fa7a6SAndroid Build Coastguard Worker
sizeContainerHandle220*523fa7a6SAndroid Build Coastguard Worker size_t size() const {
221*523fa7a6SAndroid Build Coastguard Worker return handle->size;
222*523fa7a6SAndroid Build Coastguard Worker }
223*523fa7a6SAndroid Build Coastguard Worker
leaves_numContainerHandle224*523fa7a6SAndroid Build Coastguard Worker size_t leaves_num() const {
225*523fa7a6SAndroid Build Coastguard Worker return handle->leaves_num;
226*523fa7a6SAndroid Build Coastguard Worker }
227*523fa7a6SAndroid Build Coastguard Worker
isDictContainerHandle228*523fa7a6SAndroid Build Coastguard Worker bool isDict() const {
229*523fa7a6SAndroid Build Coastguard Worker return handle->kind == Kind::Dict;
230*523fa7a6SAndroid Build Coastguard Worker }
231*523fa7a6SAndroid Build Coastguard Worker
isListContainerHandle232*523fa7a6SAndroid Build Coastguard Worker bool isList() const {
233*523fa7a6SAndroid Build Coastguard Worker return handle->kind == Kind::List;
234*523fa7a6SAndroid Build Coastguard Worker }
235*523fa7a6SAndroid Build Coastguard Worker
isNamedTupleContainerHandle236*523fa7a6SAndroid Build Coastguard Worker bool isNamedTuple() const {
237*523fa7a6SAndroid Build Coastguard Worker return handle->kind == Kind::NamedTuple;
238*523fa7a6SAndroid Build Coastguard Worker }
239*523fa7a6SAndroid Build Coastguard Worker
isTupleContainerHandle240*523fa7a6SAndroid Build Coastguard Worker bool isTuple() const {
241*523fa7a6SAndroid Build Coastguard Worker return handle->kind == Kind::Tuple;
242*523fa7a6SAndroid Build Coastguard Worker }
243*523fa7a6SAndroid Build Coastguard Worker
isLeafContainerHandle244*523fa7a6SAndroid Build Coastguard Worker bool isLeaf() const {
245*523fa7a6SAndroid Build Coastguard Worker return handle->kind == Kind::Leaf;
246*523fa7a6SAndroid Build Coastguard Worker }
247*523fa7a6SAndroid Build Coastguard Worker
kindContainerHandle248*523fa7a6SAndroid Build Coastguard Worker Kind kind() const {
249*523fa7a6SAndroid Build Coastguard Worker return handle->kind;
250*523fa7a6SAndroid Build Coastguard Worker }
251*523fa7a6SAndroid Build Coastguard Worker
252*523fa7a6SAndroid Build Coastguard Worker // Checks only structure, no leaves comparison
253*523fa7a6SAndroid Build Coastguard Worker bool operator==(const ContainerHandle& rhs) {
254*523fa7a6SAndroid Build Coastguard Worker const Kind knd = kind();
255*523fa7a6SAndroid Build Coastguard Worker if (knd != rhs.kind()) {
256*523fa7a6SAndroid Build Coastguard Worker return false;
257*523fa7a6SAndroid Build Coastguard Worker }
258*523fa7a6SAndroid Build Coastguard Worker if (knd == Kind::Leaf) {
259*523fa7a6SAndroid Build Coastguard Worker return true;
260*523fa7a6SAndroid Build Coastguard Worker }
261*523fa7a6SAndroid Build Coastguard Worker const size_t _size = size();
262*523fa7a6SAndroid Build Coastguard Worker if (_size != rhs.size()) {
263*523fa7a6SAndroid Build Coastguard Worker return false;
264*523fa7a6SAndroid Build Coastguard Worker }
265*523fa7a6SAndroid Build Coastguard Worker
266*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < _size; ++i) {
267*523fa7a6SAndroid Build Coastguard Worker if (knd == Kind::Dict && (key(i) != rhs.key(i))) {
268*523fa7a6SAndroid Build Coastguard Worker return false;
269*523fa7a6SAndroid Build Coastguard Worker }
270*523fa7a6SAndroid Build Coastguard Worker if (operator[](i) != rhs[i]) {
271*523fa7a6SAndroid Build Coastguard Worker return false;
272*523fa7a6SAndroid Build Coastguard Worker }
273*523fa7a6SAndroid Build Coastguard Worker }
274*523fa7a6SAndroid Build Coastguard Worker return true;
275*523fa7a6SAndroid Build Coastguard Worker }
276*523fa7a6SAndroid Build Coastguard Worker
277*523fa7a6SAndroid Build Coastguard Worker bool operator!=(const ContainerHandle& rhs) {
278*523fa7a6SAndroid Build Coastguard Worker return !operator==(rhs);
279*523fa7a6SAndroid Build Coastguard Worker }
280*523fa7a6SAndroid Build Coastguard Worker };
281*523fa7a6SAndroid Build Coastguard Worker
282*523fa7a6SAndroid Build Coastguard Worker struct TreeSpecLeaf {};
283*523fa7a6SAndroid Build Coastguard Worker
284*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
285*523fa7a6SAndroid Build Coastguard Worker using TreeSpec = ContainerHandle<TreeSpecLeaf, Aux>;
286*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
287*523fa7a6SAndroid Build Coastguard Worker using TreeSpecContainer = Container<TreeSpecLeaf, Aux>;
288*523fa7a6SAndroid Build Coastguard Worker
289*523fa7a6SAndroid Build Coastguard Worker using StrTreeSpec = std::string;
290*523fa7a6SAndroid Build Coastguard Worker
291*523fa7a6SAndroid Build Coastguard Worker // Expects refresh_leaves_num() was called after the last modification
292*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename U, typename Aux>
clone(const ContainerHandle<T,Aux> & node,U * leaves)293*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<U, Aux> clone(const ContainerHandle<T, Aux>& node, U* leaves) {
294*523fa7a6SAndroid Build Coastguard Worker if (node.isLeaf()) {
295*523fa7a6SAndroid Build Coastguard Worker return ContainerHandle<U, Aux>(leaves);
296*523fa7a6SAndroid Build Coastguard Worker }
297*523fa7a6SAndroid Build Coastguard Worker
298*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<U, Aux> ret(node.kind(), node.size());
299*523fa7a6SAndroid Build Coastguard Worker size_t leaves_offset = 0;
300*523fa7a6SAndroid Build Coastguard Worker size_t size = node.size();
301*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < size; ++i) {
302*523fa7a6SAndroid Build Coastguard Worker ret[i] = clone(node[i], leaves + leaves_offset);
303*523fa7a6SAndroid Build Coastguard Worker leaves_offset += node[i].leaves_num();
304*523fa7a6SAndroid Build Coastguard Worker }
305*523fa7a6SAndroid Build Coastguard Worker
306*523fa7a6SAndroid Build Coastguard Worker if (node.isDict()) {
307*523fa7a6SAndroid Build Coastguard Worker ret.handle->keys = std::unique_ptr<Key[]>(new Key[size]);
308*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < size; ++i) {
309*523fa7a6SAndroid Build Coastguard Worker ret.handle->keys[i] = node.handle->keys[i];
310*523fa7a6SAndroid Build Coastguard Worker }
311*523fa7a6SAndroid Build Coastguard Worker }
312*523fa7a6SAndroid Build Coastguard Worker
313*523fa7a6SAndroid Build Coastguard Worker return ret;
314*523fa7a6SAndroid Build Coastguard Worker }
315*523fa7a6SAndroid Build Coastguard Worker
316*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
traverse(ContainerHandle<T,Aux> & node,FunctionRef<void (ContainerHandle<T,Aux> &)> func)317*523fa7a6SAndroid Build Coastguard Worker void traverse(
318*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<T, Aux>& node,
319*523fa7a6SAndroid Build Coastguard Worker FunctionRef<void(ContainerHandle<T, Aux>&)> func) {
320*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < node.size(); ++i) {
321*523fa7a6SAndroid Build Coastguard Worker traverse(node[i], func);
322*523fa7a6SAndroid Build Coastguard Worker }
323*523fa7a6SAndroid Build Coastguard Worker
324*523fa7a6SAndroid Build Coastguard Worker func(node);
325*523fa7a6SAndroid Build Coastguard Worker }
326*523fa7a6SAndroid Build Coastguard Worker
327*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
traverse(const ContainerHandle<T,Aux> & node,FunctionRef<void (const ContainerHandle<T,Aux> &)> func)328*523fa7a6SAndroid Build Coastguard Worker void traverse(
329*523fa7a6SAndroid Build Coastguard Worker const ContainerHandle<T, Aux>& node,
330*523fa7a6SAndroid Build Coastguard Worker FunctionRef<void(const ContainerHandle<T, Aux>&)> func) {
331*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < node.size(); ++i) {
332*523fa7a6SAndroid Build Coastguard Worker traverse(node[i], func);
333*523fa7a6SAndroid Build Coastguard Worker }
334*523fa7a6SAndroid Build Coastguard Worker
335*523fa7a6SAndroid Build Coastguard Worker func(node);
336*523fa7a6SAndroid Build Coastguard Worker }
337*523fa7a6SAndroid Build Coastguard Worker
// Characters used by the textual tree-spec format.
struct Config final {
  // Node kind markers.
  static constexpr char kList = 'L';
  static constexpr char kTuple = 'T';
  static constexpr char kNamedTuple = 'N';
  static constexpr char kDict = 'D';
  static constexpr char kCustom = 'C';
  static constexpr char kLeaf = '$';

  // Structural delimiters.
  static constexpr char kNodeDataBegin = '(';
  static constexpr char kNodeDataEnd = ')';
  static constexpr char kChildrenSep = ',';
  static constexpr char kChildrenDataSep = '#';

  // Dict key formatting.
  static constexpr char kDictStrKeyQuote = '\'';
  static constexpr char kDictKeyValueSep = ':';
};
352*523fa7a6SAndroid Build Coastguard Worker
353*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
to_str_internal(const TreeSpec<Aux> & spec)354*523fa7a6SAndroid Build Coastguard Worker StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
355*523fa7a6SAndroid Build Coastguard Worker std::string s;
356*523fa7a6SAndroid Build Coastguard Worker switch (spec.kind()) {
357*523fa7a6SAndroid Build Coastguard Worker case Kind::List:
358*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kList);
359*523fa7a6SAndroid Build Coastguard Worker break;
360*523fa7a6SAndroid Build Coastguard Worker case Kind::NamedTuple:
361*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kNamedTuple);
362*523fa7a6SAndroid Build Coastguard Worker break;
363*523fa7a6SAndroid Build Coastguard Worker case Kind::Tuple:
364*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kTuple);
365*523fa7a6SAndroid Build Coastguard Worker break;
366*523fa7a6SAndroid Build Coastguard Worker case Kind::Dict:
367*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kDict);
368*523fa7a6SAndroid Build Coastguard Worker break;
369*523fa7a6SAndroid Build Coastguard Worker case Kind::Leaf:
370*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kLeaf);
371*523fa7a6SAndroid Build Coastguard Worker return s;
372*523fa7a6SAndroid Build Coastguard Worker case Kind::Custom:
373*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kCustom);
374*523fa7a6SAndroid Build Coastguard Worker s.push_back('(');
375*523fa7a6SAndroid Build Coastguard Worker s.append(spec.handle->custom_type);
376*523fa7a6SAndroid Build Coastguard Worker s.push_back(')');
377*523fa7a6SAndroid Build Coastguard Worker break;
378*523fa7a6SAndroid Build Coastguard Worker case Kind::None:
379*523fa7a6SAndroid Build Coastguard Worker return s;
380*523fa7a6SAndroid Build Coastguard Worker }
381*523fa7a6SAndroid Build Coastguard Worker const size_t size = spec.size();
382*523fa7a6SAndroid Build Coastguard Worker s.append(std::to_string(size));
383*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < size; ++i) {
384*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kChildrenDataSep);
385*523fa7a6SAndroid Build Coastguard Worker s.append(std::to_string(spec[i].leaves_num()));
386*523fa7a6SAndroid Build Coastguard Worker }
387*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kNodeDataBegin);
388*523fa7a6SAndroid Build Coastguard Worker if (spec.kind() == Kind::Dict) {
389*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < size; ++i) {
390*523fa7a6SAndroid Build Coastguard Worker if (i) {
391*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kChildrenSep);
392*523fa7a6SAndroid Build Coastguard Worker }
393*523fa7a6SAndroid Build Coastguard Worker const auto& key = spec.key(i);
394*523fa7a6SAndroid Build Coastguard Worker if (key.kind() == Key::Kind::Int) {
395*523fa7a6SAndroid Build Coastguard Worker s.append(std::to_string(key.as_int()));
396*523fa7a6SAndroid Build Coastguard Worker } else if (key.kind() == Key::Kind::Str) {
397*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kDictStrKeyQuote);
398*523fa7a6SAndroid Build Coastguard Worker s.append(key.as_str());
399*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kDictStrKeyQuote);
400*523fa7a6SAndroid Build Coastguard Worker } else {
401*523fa7a6SAndroid Build Coastguard Worker pytree_unreachable();
402*523fa7a6SAndroid Build Coastguard Worker }
403*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kDictKeyValueSep);
404*523fa7a6SAndroid Build Coastguard Worker s.append(to_str_internal(spec[i]));
405*523fa7a6SAndroid Build Coastguard Worker }
406*523fa7a6SAndroid Build Coastguard Worker } else {
407*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < size; ++i) {
408*523fa7a6SAndroid Build Coastguard Worker if (i) {
409*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kChildrenSep);
410*523fa7a6SAndroid Build Coastguard Worker }
411*523fa7a6SAndroid Build Coastguard Worker s.append(to_str_internal(spec[i]));
412*523fa7a6SAndroid Build Coastguard Worker }
413*523fa7a6SAndroid Build Coastguard Worker }
414*523fa7a6SAndroid Build Coastguard Worker s.push_back(Config::kNodeDataEnd);
415*523fa7a6SAndroid Build Coastguard Worker return s;
416*523fa7a6SAndroid Build Coastguard Worker }
417*523fa7a6SAndroid Build Coastguard Worker
// Minimal fixed-size, heap-allocated array that owns its elements.
// Elements are value-initialized (zero for arithmetic T), so slots that are
// never written do not hold indeterminate values. Movable, not copyable.
template <typename T>
struct arr {
  explicit arr(const size_t n) : data_(std::make_unique<T[]>(n)), n_(n) {}

  // Unchecked element access.
  T& operator[](size_t idx) {
    return data_[idx];
  }

  const T& operator[](size_t idx) const {
    return data_[idx];
  }

  inline T* data() {
    return data_.get();
  }

  inline const T* data() const {
    return data_.get();
  }

  inline size_t size() const {
    return n_;
  }

 private:
  std::unique_ptr<T[]> data_;
  size_t n_;
};
442*523fa7a6SAndroid Build Coastguard Worker
read_number(const StrTreeSpec & spec,size_t & read_idx)443*523fa7a6SAndroid Build Coastguard Worker inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
444*523fa7a6SAndroid Build Coastguard Worker size_t num = 0;
445*523fa7a6SAndroid Build Coastguard Worker while (isdigit(spec[read_idx])) {
446*523fa7a6SAndroid Build Coastguard Worker num = 10 * num + (spec[read_idx] - '0');
447*523fa7a6SAndroid Build Coastguard Worker read_idx++;
448*523fa7a6SAndroid Build Coastguard Worker }
449*523fa7a6SAndroid Build Coastguard Worker return num;
450*523fa7a6SAndroid Build Coastguard Worker }
451*523fa7a6SAndroid Build Coastguard Worker
read_node_layout(const StrTreeSpec & spec,size_t & read_idx)452*523fa7a6SAndroid Build Coastguard Worker inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
453*523fa7a6SAndroid Build Coastguard Worker const size_t child_num = read_number(spec, read_idx);
454*523fa7a6SAndroid Build Coastguard Worker arr<size_t> ret(child_num);
455*523fa7a6SAndroid Build Coastguard Worker
456*523fa7a6SAndroid Build Coastguard Worker size_t child_idx = 0;
457*523fa7a6SAndroid Build Coastguard Worker while (spec[read_idx] == Config::kChildrenDataSep) {
458*523fa7a6SAndroid Build Coastguard Worker ++read_idx;
459*523fa7a6SAndroid Build Coastguard Worker ret[child_idx++] = read_number(spec, read_idx);
460*523fa7a6SAndroid Build Coastguard Worker }
461*523fa7a6SAndroid Build Coastguard Worker return ret;
462*523fa7a6SAndroid Build Coastguard Worker }
463*523fa7a6SAndroid Build Coastguard Worker
464*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
from_str_internal(const StrTreeSpec & spec,size_t read_idx,const arr<size_t> & spec_data)465*523fa7a6SAndroid Build Coastguard Worker TreeSpec<Aux> from_str_internal(
466*523fa7a6SAndroid Build Coastguard Worker const StrTreeSpec& spec,
467*523fa7a6SAndroid Build Coastguard Worker size_t read_idx,
468*523fa7a6SAndroid Build Coastguard Worker const arr<size_t>& spec_data) {
469*523fa7a6SAndroid Build Coastguard Worker const auto kind_char = spec[read_idx];
470*523fa7a6SAndroid Build Coastguard Worker switch (kind_char) {
471*523fa7a6SAndroid Build Coastguard Worker case Config::kTuple:
472*523fa7a6SAndroid Build Coastguard Worker case Config::kNamedTuple:
473*523fa7a6SAndroid Build Coastguard Worker case Config::kList: {
474*523fa7a6SAndroid Build Coastguard Worker Kind kind = Kind::List;
475*523fa7a6SAndroid Build Coastguard Worker std::string custom_type;
476*523fa7a6SAndroid Build Coastguard Worker if (Config::kNamedTuple == kind_char) {
477*523fa7a6SAndroid Build Coastguard Worker kind = Kind::NamedTuple;
478*523fa7a6SAndroid Build Coastguard Worker } else if (Config::kTuple == kind_char) {
479*523fa7a6SAndroid Build Coastguard Worker kind = Kind::Tuple;
480*523fa7a6SAndroid Build Coastguard Worker } else if (Config::kCustom == kind_char) {
481*523fa7a6SAndroid Build Coastguard Worker kind = Kind::Custom;
482*523fa7a6SAndroid Build Coastguard Worker read_idx++;
483*523fa7a6SAndroid Build Coastguard Worker assert(spec[read_idx] == '(');
484*523fa7a6SAndroid Build Coastguard Worker auto type_str_end = spec_data[read_idx];
485*523fa7a6SAndroid Build Coastguard Worker read_idx++;
486*523fa7a6SAndroid Build Coastguard Worker custom_type = spec.substr(read_idx, type_str_end - read_idx);
487*523fa7a6SAndroid Build Coastguard Worker assert(false);
488*523fa7a6SAndroid Build Coastguard Worker }
489*523fa7a6SAndroid Build Coastguard Worker read_idx++;
490*523fa7a6SAndroid Build Coastguard Worker auto layout = read_node_layout(spec, read_idx);
491*523fa7a6SAndroid Build Coastguard Worker const auto size = layout.size();
492*523fa7a6SAndroid Build Coastguard Worker auto c = std::make_unique<TreeSpecContainer<Aux>>(kind, size);
493*523fa7a6SAndroid Build Coastguard Worker
494*523fa7a6SAndroid Build Coastguard Worker if (Kind::Custom == kind) {
495*523fa7a6SAndroid Build Coastguard Worker c->custom_type = std::move(custom_type);
496*523fa7a6SAndroid Build Coastguard Worker }
497*523fa7a6SAndroid Build Coastguard Worker
498*523fa7a6SAndroid Build Coastguard Worker size_t child_idx = 0;
499*523fa7a6SAndroid Build Coastguard Worker size_t leaves_offset = 0;
500*523fa7a6SAndroid Build Coastguard Worker
501*523fa7a6SAndroid Build Coastguard Worker if (size > 0) {
502*523fa7a6SAndroid Build Coastguard Worker while (spec[read_idx] != Config::kNodeDataEnd) {
503*523fa7a6SAndroid Build Coastguard Worker // NOLINTNEXTLINE
504*523fa7a6SAndroid Build Coastguard Worker auto next_delim_idx = spec_data[read_idx];
505*523fa7a6SAndroid Build Coastguard Worker read_idx++;
506*523fa7a6SAndroid Build Coastguard Worker c->items[child_idx] =
507*523fa7a6SAndroid Build Coastguard Worker from_str_internal<Aux>(spec, read_idx, spec_data);
508*523fa7a6SAndroid Build Coastguard Worker read_idx = next_delim_idx;
509*523fa7a6SAndroid Build Coastguard Worker leaves_offset += layout[child_idx++];
510*523fa7a6SAndroid Build Coastguard Worker }
511*523fa7a6SAndroid Build Coastguard Worker } else {
512*523fa7a6SAndroid Build Coastguard Worker read_idx++;
513*523fa7a6SAndroid Build Coastguard Worker }
514*523fa7a6SAndroid Build Coastguard Worker c->leaves_num = leaves_offset;
515*523fa7a6SAndroid Build Coastguard Worker return TreeSpec<Aux>(std::move(c));
516*523fa7a6SAndroid Build Coastguard Worker }
517*523fa7a6SAndroid Build Coastguard Worker
518*523fa7a6SAndroid Build Coastguard Worker case Config::kDict: {
519*523fa7a6SAndroid Build Coastguard Worker read_idx++;
520*523fa7a6SAndroid Build Coastguard Worker auto layout = read_node_layout(spec, read_idx);
521*523fa7a6SAndroid Build Coastguard Worker const auto size = layout.size();
522*523fa7a6SAndroid Build Coastguard Worker auto c = std::make_unique<TreeSpecContainer<Aux>>(Kind::Dict, size);
523*523fa7a6SAndroid Build Coastguard Worker
524*523fa7a6SAndroid Build Coastguard Worker size_t child_idx = 0;
525*523fa7a6SAndroid Build Coastguard Worker size_t leaves_offset = 0;
526*523fa7a6SAndroid Build Coastguard Worker
527*523fa7a6SAndroid Build Coastguard Worker if (size > 0) {
528*523fa7a6SAndroid Build Coastguard Worker while (spec[read_idx] != Config::kNodeDataEnd) {
529*523fa7a6SAndroid Build Coastguard Worker // NOLINTNEXTLINE
530*523fa7a6SAndroid Build Coastguard Worker auto next_delim_idx = spec_data[read_idx];
531*523fa7a6SAndroid Build Coastguard Worker read_idx++;
532*523fa7a6SAndroid Build Coastguard Worker if (spec[read_idx] == Config::kDictStrKeyQuote) {
533*523fa7a6SAndroid Build Coastguard Worker auto key_delim_idx = spec_data[read_idx];
534*523fa7a6SAndroid Build Coastguard Worker read_idx++;
535*523fa7a6SAndroid Build Coastguard Worker const size_t key_len = key_delim_idx - read_idx;
536*523fa7a6SAndroid Build Coastguard Worker // NOLINTNEXTLINE
537*523fa7a6SAndroid Build Coastguard Worker c->keys[child_idx] = spec.substr(read_idx, key_len);
538*523fa7a6SAndroid Build Coastguard Worker read_idx = key_delim_idx + 2;
539*523fa7a6SAndroid Build Coastguard Worker } else {
540*523fa7a6SAndroid Build Coastguard Worker pytree_assert(isdigit(spec[read_idx]));
541*523fa7a6SAndroid Build Coastguard Worker size_t key = read_number(spec, read_idx);
542*523fa7a6SAndroid Build Coastguard Worker c->keys[child_idx] = KeyInt(key);
543*523fa7a6SAndroid Build Coastguard Worker read_idx += 1;
544*523fa7a6SAndroid Build Coastguard Worker }
545*523fa7a6SAndroid Build Coastguard Worker
546*523fa7a6SAndroid Build Coastguard Worker c->items[child_idx] =
547*523fa7a6SAndroid Build Coastguard Worker from_str_internal<Aux>(spec, read_idx, spec_data);
548*523fa7a6SAndroid Build Coastguard Worker read_idx = next_delim_idx;
549*523fa7a6SAndroid Build Coastguard Worker leaves_offset += layout[child_idx++];
550*523fa7a6SAndroid Build Coastguard Worker }
551*523fa7a6SAndroid Build Coastguard Worker } else {
552*523fa7a6SAndroid Build Coastguard Worker read_idx++;
553*523fa7a6SAndroid Build Coastguard Worker }
554*523fa7a6SAndroid Build Coastguard Worker c->leaves_num = leaves_offset;
555*523fa7a6SAndroid Build Coastguard Worker return TreeSpec<Aux>(std::move(c));
556*523fa7a6SAndroid Build Coastguard Worker }
557*523fa7a6SAndroid Build Coastguard Worker
558*523fa7a6SAndroid Build Coastguard Worker case Config::kLeaf:
559*523fa7a6SAndroid Build Coastguard Worker return new TreeSpecContainer<Aux>(nullptr);
560*523fa7a6SAndroid Build Coastguard Worker }
561*523fa7a6SAndroid Build Coastguard Worker pytree_unreachable();
562*523fa7a6SAndroid Build Coastguard Worker return new TreeSpecContainer<Aux>(Kind::None);
563*523fa7a6SAndroid Build Coastguard Worker }
564*523fa7a6SAndroid Build Coastguard Worker
565*523fa7a6SAndroid Build Coastguard Worker template <typename T>
566*523fa7a6SAndroid Build Coastguard Worker struct stack final {
567*523fa7a6SAndroid Build Coastguard Worker constexpr static const size_t SIZE = 8;
568*523fa7a6SAndroid Build Coastguard Worker
569*523fa7a6SAndroid Build Coastguard Worker size_t size_ = 0;
570*523fa7a6SAndroid Build Coastguard Worker T data[SIZE];
571*523fa7a6SAndroid Build Coastguard Worker
pushfinal572*523fa7a6SAndroid Build Coastguard Worker void push(T&& item) {
573*523fa7a6SAndroid Build Coastguard Worker pytree_assert(size_ < SIZE);
574*523fa7a6SAndroid Build Coastguard Worker data[size_++] = std::move(item);
575*523fa7a6SAndroid Build Coastguard Worker }
576*523fa7a6SAndroid Build Coastguard Worker
popfinal577*523fa7a6SAndroid Build Coastguard Worker T pop() {
578*523fa7a6SAndroid Build Coastguard Worker pytree_assert(size_ > 0);
579*523fa7a6SAndroid Build Coastguard Worker return data[--size_];
580*523fa7a6SAndroid Build Coastguard Worker }
581*523fa7a6SAndroid Build Coastguard Worker
topfinal582*523fa7a6SAndroid Build Coastguard Worker T& top() {
583*523fa7a6SAndroid Build Coastguard Worker pytree_assert(size_ > 0);
584*523fa7a6SAndroid Build Coastguard Worker return data[size_ - 1];
585*523fa7a6SAndroid Build Coastguard Worker }
586*523fa7a6SAndroid Build Coastguard Worker
sizefinal587*523fa7a6SAndroid Build Coastguard Worker size_t size() {
588*523fa7a6SAndroid Build Coastguard Worker return size_;
589*523fa7a6SAndroid Build Coastguard Worker }
590*523fa7a6SAndroid Build Coastguard Worker };
591*523fa7a6SAndroid Build Coastguard Worker
pre_parse(const StrTreeSpec & spec)592*523fa7a6SAndroid Build Coastguard Worker inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
593*523fa7a6SAndroid Build Coastguard Worker stack<std::pair<size_t, size_t>> stack;
594*523fa7a6SAndroid Build Coastguard Worker size_t i = 0;
595*523fa7a6SAndroid Build Coastguard Worker const size_t size = spec.size();
596*523fa7a6SAndroid Build Coastguard Worker arr<size_t> ret(size);
597*523fa7a6SAndroid Build Coastguard Worker while (i < size) {
598*523fa7a6SAndroid Build Coastguard Worker const auto c = spec[i];
599*523fa7a6SAndroid Build Coastguard Worker switch (c) {
600*523fa7a6SAndroid Build Coastguard Worker case Config::kNodeDataBegin: {
601*523fa7a6SAndroid Build Coastguard Worker stack.push({i, i});
602*523fa7a6SAndroid Build Coastguard Worker break;
603*523fa7a6SAndroid Build Coastguard Worker }
604*523fa7a6SAndroid Build Coastguard Worker case Config::kNodeDataEnd: {
605*523fa7a6SAndroid Build Coastguard Worker auto& item = stack.top();
606*523fa7a6SAndroid Build Coastguard Worker size_t last_sep_idx = item.second;
607*523fa7a6SAndroid Build Coastguard Worker ret[last_sep_idx] = i;
608*523fa7a6SAndroid Build Coastguard Worker stack.pop();
609*523fa7a6SAndroid Build Coastguard Worker break;
610*523fa7a6SAndroid Build Coastguard Worker }
611*523fa7a6SAndroid Build Coastguard Worker case Config::kDictStrKeyQuote: {
612*523fa7a6SAndroid Build Coastguard Worker size_t idx = i;
613*523fa7a6SAndroid Build Coastguard Worker i++;
614*523fa7a6SAndroid Build Coastguard Worker while (spec[i] != Config::kDictStrKeyQuote) {
615*523fa7a6SAndroid Build Coastguard Worker i++;
616*523fa7a6SAndroid Build Coastguard Worker }
617*523fa7a6SAndroid Build Coastguard Worker ret[idx] = i;
618*523fa7a6SAndroid Build Coastguard Worker ret[i] = idx;
619*523fa7a6SAndroid Build Coastguard Worker break;
620*523fa7a6SAndroid Build Coastguard Worker }
621*523fa7a6SAndroid Build Coastguard Worker case Config::kChildrenSep: {
622*523fa7a6SAndroid Build Coastguard Worker auto& item = stack.top();
623*523fa7a6SAndroid Build Coastguard Worker size_t last_sep_idx = item.second;
624*523fa7a6SAndroid Build Coastguard Worker ret[last_sep_idx] = i;
625*523fa7a6SAndroid Build Coastguard Worker item.second = i;
626*523fa7a6SAndroid Build Coastguard Worker break;
627*523fa7a6SAndroid Build Coastguard Worker }
628*523fa7a6SAndroid Build Coastguard Worker }
629*523fa7a6SAndroid Build Coastguard Worker i++;
630*523fa7a6SAndroid Build Coastguard Worker }
631*523fa7a6SAndroid Build Coastguard Worker return ret;
632*523fa7a6SAndroid Build Coastguard Worker }
633*523fa7a6SAndroid Build Coastguard Worker
634*523fa7a6SAndroid Build Coastguard Worker template <typename Aux = Empty>
from_str(const StrTreeSpec & spec)635*523fa7a6SAndroid Build Coastguard Worker TreeSpec<Aux> from_str(const StrTreeSpec& spec) {
636*523fa7a6SAndroid Build Coastguard Worker return from_str_internal<Aux>(spec, 0u, pre_parse(spec));
637*523fa7a6SAndroid Build Coastguard Worker }
638*523fa7a6SAndroid Build Coastguard Worker
639*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
to_str(const TreeSpec<Aux> & spec)640*523fa7a6SAndroid Build Coastguard Worker StrTreeSpec to_str(const TreeSpec<Aux>& spec) {
641*523fa7a6SAndroid Build Coastguard Worker if (spec.leaves_num() == 0) {
642*523fa7a6SAndroid Build Coastguard Worker refresh_leaves_num(spec);
643*523fa7a6SAndroid Build Coastguard Worker }
644*523fa7a6SAndroid Build Coastguard Worker return to_str_internal(spec);
645*523fa7a6SAndroid Build Coastguard Worker }
646*523fa7a6SAndroid Build Coastguard Worker
647*523fa7a6SAndroid Build Coastguard Worker template <typename Aux>
648*523fa7a6SAndroid Build Coastguard Worker StrTreeSpec to_str(const TreeSpec<Aux>& spec);
649*523fa7a6SAndroid Build Coastguard Worker
650*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
unflatten(const TreeSpec<Aux> & spec,T * leaves)651*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<T, Aux> unflatten(const TreeSpec<Aux>& spec, T* leaves) {
652*523fa7a6SAndroid Build Coastguard Worker if (spec.leaves_num() == 0) {
653*523fa7a6SAndroid Build Coastguard Worker refresh_leaves_num(spec);
654*523fa7a6SAndroid Build Coastguard Worker }
655*523fa7a6SAndroid Build Coastguard Worker return clone(spec, leaves);
656*523fa7a6SAndroid Build Coastguard Worker }
657*523fa7a6SAndroid Build Coastguard Worker
658*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux = Empty>
unflatten(const StrTreeSpec & spec,T * leaves)659*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<T, Aux> unflatten(const StrTreeSpec& spec, T* leaves) {
660*523fa7a6SAndroid Build Coastguard Worker return unflatten(from_str<Aux>(spec), leaves);
661*523fa7a6SAndroid Build Coastguard Worker }
662*523fa7a6SAndroid Build Coastguard Worker
663*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten_internal(const ContainerHandle<T,Aux> & tree,const T ** leaves)664*523fa7a6SAndroid Build Coastguard Worker void flatten_internal(const ContainerHandle<T, Aux>& tree, const T** leaves) {
665*523fa7a6SAndroid Build Coastguard Worker using tree_t = decltype(tree);
666*523fa7a6SAndroid Build Coastguard Worker size_t leaves_idx = 0;
667*523fa7a6SAndroid Build Coastguard Worker auto func = [&](tree_t node) {
668*523fa7a6SAndroid Build Coastguard Worker if (node.isLeaf()) {
669*523fa7a6SAndroid Build Coastguard Worker leaves[leaves_idx++] = node.leaf_ptr();
670*523fa7a6SAndroid Build Coastguard Worker }
671*523fa7a6SAndroid Build Coastguard Worker };
672*523fa7a6SAndroid Build Coastguard Worker traverse(tree, FunctionRef<void(tree_t&)>{func});
673*523fa7a6SAndroid Build Coastguard Worker }
674*523fa7a6SAndroid Build Coastguard Worker
675*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten_internal(ContainerHandle<T,Aux> & tree,T ** leaves)676*523fa7a6SAndroid Build Coastguard Worker void flatten_internal(ContainerHandle<T, Aux>& tree, T** leaves) {
677*523fa7a6SAndroid Build Coastguard Worker using tree_t = decltype(tree);
678*523fa7a6SAndroid Build Coastguard Worker size_t leaves_idx = 0;
679*523fa7a6SAndroid Build Coastguard Worker auto func = [&](tree_t node) {
680*523fa7a6SAndroid Build Coastguard Worker if (node.isLeaf()) {
681*523fa7a6SAndroid Build Coastguard Worker leaves[leaves_idx++] = node.leaf_ptr();
682*523fa7a6SAndroid Build Coastguard Worker }
683*523fa7a6SAndroid Build Coastguard Worker };
684*523fa7a6SAndroid Build Coastguard Worker traverse(tree, FunctionRef<void(tree_t&)>{func});
685*523fa7a6SAndroid Build Coastguard Worker }
686*523fa7a6SAndroid Build Coastguard Worker
687*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
refresh_leaves_num(const ContainerHandle<T,Aux> & node)688*523fa7a6SAndroid Build Coastguard Worker size_t refresh_leaves_num(const ContainerHandle<T, Aux>& node) {
689*523fa7a6SAndroid Build Coastguard Worker if (node.isLeaf()) {
690*523fa7a6SAndroid Build Coastguard Worker node.handle->leaves_num = 1;
691*523fa7a6SAndroid Build Coastguard Worker return 1;
692*523fa7a6SAndroid Build Coastguard Worker }
693*523fa7a6SAndroid Build Coastguard Worker
694*523fa7a6SAndroid Build Coastguard Worker size_t n = 0;
695*523fa7a6SAndroid Build Coastguard Worker for (size_t i = 0; i < node.size(); ++i) {
696*523fa7a6SAndroid Build Coastguard Worker n += refresh_leaves_num(node[i]);
697*523fa7a6SAndroid Build Coastguard Worker }
698*523fa7a6SAndroid Build Coastguard Worker
699*523fa7a6SAndroid Build Coastguard Worker node.handle->leaves_num = n;
700*523fa7a6SAndroid Build Coastguard Worker return n;
701*523fa7a6SAndroid Build Coastguard Worker }
702*523fa7a6SAndroid Build Coastguard Worker
703*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten(const ContainerHandle<T,Aux> & tree)704*523fa7a6SAndroid Build Coastguard Worker std::pair<arr<const T*>, std::unique_ptr<TreeSpec<Aux>>> flatten(
705*523fa7a6SAndroid Build Coastguard Worker const ContainerHandle<T, Aux>& tree) {
706*523fa7a6SAndroid Build Coastguard Worker refresh_leaves_num(tree);
707*523fa7a6SAndroid Build Coastguard Worker const size_t n = tree.leaves_num();
708*523fa7a6SAndroid Build Coastguard Worker arr<T*> leaves(n);
709*523fa7a6SAndroid Build Coastguard Worker flatten_internal(tree, leaves.data());
710*523fa7a6SAndroid Build Coastguard Worker auto spec_leaves = std::make_unique<TreeSpecLeaf[]>(n);
711*523fa7a6SAndroid Build Coastguard Worker return {
712*523fa7a6SAndroid Build Coastguard Worker std::move(leaves),
713*523fa7a6SAndroid Build Coastguard Worker std::make_unique<TreeSpec<Aux>>(clone(tree, spec_leaves.get()))};
714*523fa7a6SAndroid Build Coastguard Worker }
715*523fa7a6SAndroid Build Coastguard Worker
716*523fa7a6SAndroid Build Coastguard Worker // Duplication of logic for non const ContainerHandle
717*523fa7a6SAndroid Build Coastguard Worker template <typename T, typename Aux>
flatten(ContainerHandle<T,Aux> & tree)718*523fa7a6SAndroid Build Coastguard Worker std::pair<arr<T*>, std::unique_ptr<TreeSpec<Aux>>> flatten(
719*523fa7a6SAndroid Build Coastguard Worker ContainerHandle<T, Aux>& tree) {
720*523fa7a6SAndroid Build Coastguard Worker refresh_leaves_num(tree);
721*523fa7a6SAndroid Build Coastguard Worker const size_t n = tree.leaves_num();
722*523fa7a6SAndroid Build Coastguard Worker arr<T*> leaves(n);
723*523fa7a6SAndroid Build Coastguard Worker flatten_internal(tree, leaves.data());
724*523fa7a6SAndroid Build Coastguard Worker auto spec_leaves = std::make_unique<TreeSpecLeaf[]>(n);
725*523fa7a6SAndroid Build Coastguard Worker return {
726*523fa7a6SAndroid Build Coastguard Worker std::move(leaves),
727*523fa7a6SAndroid Build Coastguard Worker std::make_unique<TreeSpec<Aux>>(clone(tree, spec_leaves.get()))};
728*523fa7a6SAndroid Build Coastguard Worker }
729*523fa7a6SAndroid Build Coastguard Worker
730*523fa7a6SAndroid Build Coastguard Worker } // namespace pytree
731*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
732*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
733*523fa7a6SAndroid Build Coastguard Worker
734*523fa7a6SAndroid Build Coastguard Worker namespace torch {
735*523fa7a6SAndroid Build Coastguard Worker namespace executor {
736*523fa7a6SAndroid Build Coastguard Worker namespace pytree {
737*523fa7a6SAndroid Build Coastguard Worker // TODO(T197294990): Remove these deprecated aliases once all users have moved
738*523fa7a6SAndroid Build Coastguard Worker // to the new `::executorch` namespaces.
739*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::Empty;
740*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::from_str;
741*523fa7a6SAndroid Build Coastguard Worker using ::executorch::extension::pytree::TreeSpec;
742*523fa7a6SAndroid Build Coastguard Worker } // namespace pytree
743*523fa7a6SAndroid Build Coastguard Worker } // namespace executor
744*523fa7a6SAndroid Build Coastguard Worker } // namespace torch
745