xref: /aosp_15_r20/external/executorch/extension/pytree/pytree.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <ctype.h>
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <variant>
#include <vector>
19 
20 // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime.
21 #include <executorch/extension/pytree/function_ref.h>
22 
23 namespace executorch {
24 namespace extension {
25 namespace pytree {
26 
// Checks an internal invariant. This is a thin wrapper over assert(), so the
// check is only active in builds where NDEBUG is not defined.
inline void pytree_assert(bool must_be_true) {
  (void)must_be_true; // otherwise unused once assert() compiles away
  assert(must_be_true);
}
30 
// Forces inlining on compilers that support it; plain `inline` elsewhere.
#ifdef _MSC_VER
#define EXECUTORCH_ALWAYS_INLINE __forceinline
#elif defined(__GNUC__)
#define EXECUTORCH_ALWAYS_INLINE inline __attribute__((__always_inline__))
#else
#define EXECUTORCH_ALWAYS_INLINE inline
#endif

// Marks a code path that must never execute.
//
// In builds with assert() enabled this fails immediately. Otherwise GCC/Clang
// and MSVC are told the path is unreachable. Any other compiler aborts the
// process.
[[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() {
  assert(false);
#if defined(__GNUC__)
  __builtin_unreachable();
#elif defined(_MSC_VER)
  __assume(0);
#else
  // Use std::abort() rather than an empty infinite loop. Such a loop has no
  // observable side effects, so C++17 lets the compiler assume it terminates,
  // which would allow this [[noreturn]] function to return.
  std::abort();
#endif
}
50 
// Kind of a pytree node.
enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None };

// Dictionary key representations.
using KeyStr = std::string;
using KeyInt = int32_t;

// A dictionary key: empty (None), an integer, or a string.
struct Key {
  // Enumerator order mirrors the alternative order of `repr_`, which is what
  // lets kind() be computed from repr_.index().
  enum class Kind : uint8_t { None, Int, Str };

  // Public and redundant with kind(). Every constructor sets it, so it always
  // agrees with kind() and copying a Key never reads an indeterminate value.
  // Prefer kind().
  Kind kind_ = Kind::None;

 private:
  std::variant<std::monostate, KeyInt, KeyStr> repr_;

 public:
  Key() = default;
  /*implicit*/ Key(KeyInt key) : kind_(Kind::Int), repr_(key) {}
  /*implicit*/ Key(KeyStr key) : kind_(Kind::Str), repr_(std::move(key)) {}

  Kind kind() const {
    return static_cast<Kind>(repr_.index());
  }

  // Integer value; std::get semantics if the key is not an Int.
  KeyInt as_int() const {
    return std::get<KeyInt>(repr_);
  }

  operator KeyInt() const {
    return as_int();
  }

  // String value; std::get semantics if the key is not a Str.
  const KeyStr& as_str() const {
    return std::get<KeyStr>(repr_);
  }

  operator const KeyStr&() const {
    return as_str();
  }

  // Keys are equal when they hold the same alternative with an equal value.
  bool operator==(const Key& rhs) const {
    return repr_ == rhs.repr_;
  }

  bool operator!=(const Key& rhs) const {
    return !operator==(rhs);
  }
};
95 
96 struct Empty {};
97 template <typename T, typename Aux = Empty>
98 struct ContainerHandle;
99 
100 template <typename T, typename Aux = Empty>
101 struct Container final : public Aux {
102   using handle_type = ContainerHandle<T, Aux>;
103   using leaf_type = T;
104 
105   Kind kind = Kind::None;
106   size_t size = 0;
107   leaf_type* leaf = nullptr;
108   std::unique_ptr<handle_type[]> items;
109   std::unique_ptr<Key[]> keys;
110   std::string custom_type;
111   // internal only field to keep associated to every node meta info
112   mutable size_t leaves_num = 0u;
113 
114   /*implicit*/ Container(Kind kind, size_t size = 0u)
kindfinal115       : kind(kind),
116         size(size),
117         items(std::unique_ptr<handle_type[]>(new handle_type[size])) {
118     if (kind == Kind::Dict) {
119       keys = std::unique_ptr<Key[]>(new Key[size]);
120     }
121   }
Containerfinal122   /*implicit*/ Container(leaf_type* leaf)
123       : kind(Kind::Leaf), size(0u), leaf(leaf), leaves_num(1u) {}
124   Container(const Container&) = delete;
125   Container& operator=(const Container&) = delete;
126 };
127 
128 template <typename T, typename Aux>
129 struct ContainerHandle {
130   using container_type = Container<T, Aux>;
131   using leaf_type = T;
132   std::unique_ptr<container_type> handle;
133 
ContainerHandleContainerHandle134   ContainerHandle() {}
135 
136   template <typename... Args>
ContainerHandleContainerHandle137   ContainerHandle(Args... args)
138       : handle(std::make_unique<container_type>(std::forward<Args>(args)...)) {}
139 
ContainerHandleContainerHandle140   /*implicit*/ ContainerHandle(container_type* c) : handle(c) {}
141 
ContainerHandleContainerHandle142   /*implicit*/ ContainerHandle(std::unique_ptr<container_type> c)
143       : handle(std::move(c)) {}
144 
set_leafContainerHandle145   void set_leaf(leaf_type* leaf) {
146     pytree_assert(handle->kind == Kind::Leaf);
147     handle->leaf = leaf;
148   }
149 
leaf_typeContainerHandle150   operator leaf_type() const {
151     pytree_assert(handle->kind == Kind::Leaf);
152     return *handle->leaf;
153   }
154 
leafContainerHandle155   const leaf_type& leaf() const {
156     pytree_assert(handle->kind == Kind::Leaf);
157     return *handle->leaf;
158   }
leafContainerHandle159   leaf_type& leaf() {
160     pytree_assert(handle->kind == Kind::Leaf);
161     return *handle->leaf;
162   }
163 
leaf_ptrContainerHandle164   const leaf_type* leaf_ptr() const {
165     pytree_assert(handle->kind == Kind::Leaf);
166     return handle->leaf;
167   }
leaf_ptrContainerHandle168   leaf_type* leaf_ptr() {
169     pytree_assert(handle->kind == Kind::Leaf);
170     return handle->leaf;
171   }
172 
173   const ContainerHandle& operator[](size_t idx) const {
174     pytree_assert(idx < handle->size);
175     return handle->items[idx];
176   }
177 
178   ContainerHandle& operator[](size_t idx) {
179     pytree_assert(idx < handle->size);
180     return handle->items[idx];
181   }
182 
containsContainerHandle183   bool contains(const KeyStr& lookup_key) const {
184     pytree_assert(isDict());
185     for (size_t i = 0; i < handle->size; ++i) {
186       if (handle->keys[i] == lookup_key) {
187         return true;
188       }
189     }
190     return false;
191   }
192 
atContainerHandle193   const ContainerHandle& at(const Key& lookup_key) const {
194     pytree_assert(isDict());
195     for (size_t i = 0; i < handle->size; ++i) {
196       if (handle->keys[i] == lookup_key) {
197         return handle->items[i];
198       }
199     }
200     pytree_unreachable();
201   }
202 
atContainerHandle203   const ContainerHandle& at(const KeyInt& lookup_key) const {
204     return at(Key(lookup_key));
205   }
206 
atContainerHandle207   const ContainerHandle& at(const KeyStr& lookup_key) const {
208     return at(Key(lookup_key));
209   }
210 
keyContainerHandle211   const Key& key(size_t idx) const {
212     pytree_assert(isDict());
213     return handle->keys[idx];
214   }
keyContainerHandle215   Key& key(size_t idx) {
216     pytree_assert(isDict());
217     return handle->keys[idx];
218   }
219 
sizeContainerHandle220   size_t size() const {
221     return handle->size;
222   }
223 
leaves_numContainerHandle224   size_t leaves_num() const {
225     return handle->leaves_num;
226   }
227 
isDictContainerHandle228   bool isDict() const {
229     return handle->kind == Kind::Dict;
230   }
231 
isListContainerHandle232   bool isList() const {
233     return handle->kind == Kind::List;
234   }
235 
isNamedTupleContainerHandle236   bool isNamedTuple() const {
237     return handle->kind == Kind::NamedTuple;
238   }
239 
isTupleContainerHandle240   bool isTuple() const {
241     return handle->kind == Kind::Tuple;
242   }
243 
isLeafContainerHandle244   bool isLeaf() const {
245     return handle->kind == Kind::Leaf;
246   }
247 
kindContainerHandle248   Kind kind() const {
249     return handle->kind;
250   }
251 
252   // Checks only structure, no leaves comparison
253   bool operator==(const ContainerHandle& rhs) {
254     const Kind knd = kind();
255     if (knd != rhs.kind()) {
256       return false;
257     }
258     if (knd == Kind::Leaf) {
259       return true;
260     }
261     const size_t _size = size();
262     if (_size != rhs.size()) {
263       return false;
264     }
265 
266     for (size_t i = 0; i < _size; ++i) {
267       if (knd == Kind::Dict && (key(i) != rhs.key(i))) {
268         return false;
269       }
270       if (operator[](i) != rhs[i]) {
271         return false;
272       }
273     }
274     return true;
275   }
276 
277   bool operator!=(const ContainerHandle& rhs) {
278     return !operator==(rhs);
279   }
280 };
281 
282 struct TreeSpecLeaf {};
283 
284 template <typename Aux>
285 using TreeSpec = ContainerHandle<TreeSpecLeaf, Aux>;
286 template <typename Aux>
287 using TreeSpecContainer = Container<TreeSpecLeaf, Aux>;
288 
289 using StrTreeSpec = std::string;
290 
291 // Expects refresh_leaves_num() was called after the last modification
292 template <typename T, typename U, typename Aux>
clone(const ContainerHandle<T,Aux> & node,U * leaves)293 ContainerHandle<U, Aux> clone(const ContainerHandle<T, Aux>& node, U* leaves) {
294   if (node.isLeaf()) {
295     return ContainerHandle<U, Aux>(leaves);
296   }
297 
298   ContainerHandle<U, Aux> ret(node.kind(), node.size());
299   size_t leaves_offset = 0;
300   size_t size = node.size();
301   for (size_t i = 0; i < size; ++i) {
302     ret[i] = clone(node[i], leaves + leaves_offset);
303     leaves_offset += node[i].leaves_num();
304   }
305 
306   if (node.isDict()) {
307     ret.handle->keys = std::unique_ptr<Key[]>(new Key[size]);
308     for (size_t i = 0; i < size; ++i) {
309       ret.handle->keys[i] = node.handle->keys[i];
310     }
311   }
312 
313   return ret;
314 }
315 
316 template <typename T, typename Aux>
traverse(ContainerHandle<T,Aux> & node,FunctionRef<void (ContainerHandle<T,Aux> &)> func)317 void traverse(
318     ContainerHandle<T, Aux>& node,
319     FunctionRef<void(ContainerHandle<T, Aux>&)> func) {
320   for (size_t i = 0; i < node.size(); ++i) {
321     traverse(node[i], func);
322   }
323 
324   func(node);
325 }
326 
327 template <typename T, typename Aux>
traverse(const ContainerHandle<T,Aux> & node,FunctionRef<void (const ContainerHandle<T,Aux> &)> func)328 void traverse(
329     const ContainerHandle<T, Aux>& node,
330     FunctionRef<void(const ContainerHandle<T, Aux>&)> func) {
331   for (size_t i = 0; i < node.size(); ++i) {
332     traverse(node[i], func);
333   }
334 
335   func(node);
336 }
337 
// Single-character tokens of the textual tree-spec format written by
// to_str_internal() and read back by from_str_internal()/pre_parse().
struct Config final {
  // Node tags.
  static constexpr char kTuple{'T'};
  static constexpr char kNamedTuple{'N'};
  static constexpr char kList{'L'};
  static constexpr char kDict{'D'};
  static constexpr char kCustom{'C'};
  static constexpr char kLeaf{'$'};
  // Brackets around a node's children.
  static constexpr char kNodeDataBegin{'('};
  static constexpr char kNodeDataEnd{')'};
  // Dict key syntax: 'str_key':child or 123:child.
  static constexpr char kDictStrKeyQuote{'\''};
  static constexpr char kDictKeyValueSep{':'};
  // Separator between children, and prefix of each per-child leaf count.
  static constexpr char kChildrenSep{','};
  static constexpr char kChildrenDataSep{'#'};
};
352 
353 template <typename Aux>
to_str_internal(const TreeSpec<Aux> & spec)354 StrTreeSpec to_str_internal(const TreeSpec<Aux>& spec) {
355   std::string s;
356   switch (spec.kind()) {
357     case Kind::List:
358       s.push_back(Config::kList);
359       break;
360     case Kind::NamedTuple:
361       s.push_back(Config::kNamedTuple);
362       break;
363     case Kind::Tuple:
364       s.push_back(Config::kTuple);
365       break;
366     case Kind::Dict:
367       s.push_back(Config::kDict);
368       break;
369     case Kind::Leaf:
370       s.push_back(Config::kLeaf);
371       return s;
372     case Kind::Custom:
373       s.push_back(Config::kCustom);
374       s.push_back('(');
375       s.append(spec.handle->custom_type);
376       s.push_back(')');
377       break;
378     case Kind::None:
379       return s;
380   }
381   const size_t size = spec.size();
382   s.append(std::to_string(size));
383   for (size_t i = 0; i < size; ++i) {
384     s.push_back(Config::kChildrenDataSep);
385     s.append(std::to_string(spec[i].leaves_num()));
386   }
387   s.push_back(Config::kNodeDataBegin);
388   if (spec.kind() == Kind::Dict) {
389     for (size_t i = 0; i < size; ++i) {
390       if (i) {
391         s.push_back(Config::kChildrenSep);
392       }
393       const auto& key = spec.key(i);
394       if (key.kind() == Key::Kind::Int) {
395         s.append(std::to_string(key.as_int()));
396       } else if (key.kind() == Key::Kind::Str) {
397         s.push_back(Config::kDictStrKeyQuote);
398         s.append(key.as_str());
399         s.push_back(Config::kDictStrKeyQuote);
400       } else {
401         pytree_unreachable();
402       }
403       s.push_back(Config::kDictKeyValueSep);
404       s.append(to_str_internal(spec[i]));
405     }
406   } else {
407     for (size_t i = 0; i < size; ++i) {
408       if (i) {
409         s.push_back(Config::kChildrenSep);
410       }
411       s.append(to_str_internal(spec[i]));
412     }
413   }
414   s.push_back(Config::kNodeDataEnd);
415   return s;
416 }
417 
// Minimal owning array whose length is fixed at construction.
//
// Elements are value-initialized (zero for arithmetic and pointer types), so
// a slot that is never written reads back as zero rather than as an
// indeterminate value.
template <typename T>
struct arr {
  explicit arr(const size_t n) : data_(std::make_unique<T[]>(n)), n_(n) {}

  T& operator[](size_t idx) {
    return data_[idx];
  }

  const T& operator[](size_t idx) const {
    return data_[idx];
  }

  inline T* data() {
    return data_.get();
  }

  // Read-only access for const arrays.
  inline const T* data() const {
    return data_.get();
  }

  inline size_t size() const {
    return n_;
  }

 private:
  std::unique_ptr<T[]> data_;
  size_t n_;
};
442 
read_number(const StrTreeSpec & spec,size_t & read_idx)443 inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) {
444   size_t num = 0;
445   while (isdigit(spec[read_idx])) {
446     num = 10 * num + (spec[read_idx] - '0');
447     read_idx++;
448   }
449   return num;
450 }
451 
read_node_layout(const StrTreeSpec & spec,size_t & read_idx)452 inline arr<size_t> read_node_layout(const StrTreeSpec& spec, size_t& read_idx) {
453   const size_t child_num = read_number(spec, read_idx);
454   arr<size_t> ret(child_num);
455 
456   size_t child_idx = 0;
457   while (spec[read_idx] == Config::kChildrenDataSep) {
458     ++read_idx;
459     ret[child_idx++] = read_number(spec, read_idx);
460   }
461   return ret;
462 }
463 
464 template <typename Aux>
from_str_internal(const StrTreeSpec & spec,size_t read_idx,const arr<size_t> & spec_data)465 TreeSpec<Aux> from_str_internal(
466     const StrTreeSpec& spec,
467     size_t read_idx,
468     const arr<size_t>& spec_data) {
469   const auto kind_char = spec[read_idx];
470   switch (kind_char) {
471     case Config::kTuple:
472     case Config::kNamedTuple:
473     case Config::kList: {
474       Kind kind = Kind::List;
475       std::string custom_type;
476       if (Config::kNamedTuple == kind_char) {
477         kind = Kind::NamedTuple;
478       } else if (Config::kTuple == kind_char) {
479         kind = Kind::Tuple;
480       } else if (Config::kCustom == kind_char) {
481         kind = Kind::Custom;
482         read_idx++;
483         assert(spec[read_idx] == '(');
484         auto type_str_end = spec_data[read_idx];
485         read_idx++;
486         custom_type = spec.substr(read_idx, type_str_end - read_idx);
487         assert(false);
488       }
489       read_idx++;
490       auto layout = read_node_layout(spec, read_idx);
491       const auto size = layout.size();
492       auto c = std::make_unique<TreeSpecContainer<Aux>>(kind, size);
493 
494       if (Kind::Custom == kind) {
495         c->custom_type = std::move(custom_type);
496       }
497 
498       size_t child_idx = 0;
499       size_t leaves_offset = 0;
500 
501       if (size > 0) {
502         while (spec[read_idx] != Config::kNodeDataEnd) {
503           // NOLINTNEXTLINE
504           auto next_delim_idx = spec_data[read_idx];
505           read_idx++;
506           c->items[child_idx] =
507               from_str_internal<Aux>(spec, read_idx, spec_data);
508           read_idx = next_delim_idx;
509           leaves_offset += layout[child_idx++];
510         }
511       } else {
512         read_idx++;
513       }
514       c->leaves_num = leaves_offset;
515       return TreeSpec<Aux>(std::move(c));
516     }
517 
518     case Config::kDict: {
519       read_idx++;
520       auto layout = read_node_layout(spec, read_idx);
521       const auto size = layout.size();
522       auto c = std::make_unique<TreeSpecContainer<Aux>>(Kind::Dict, size);
523 
524       size_t child_idx = 0;
525       size_t leaves_offset = 0;
526 
527       if (size > 0) {
528         while (spec[read_idx] != Config::kNodeDataEnd) {
529           // NOLINTNEXTLINE
530           auto next_delim_idx = spec_data[read_idx];
531           read_idx++;
532           if (spec[read_idx] == Config::kDictStrKeyQuote) {
533             auto key_delim_idx = spec_data[read_idx];
534             read_idx++;
535             const size_t key_len = key_delim_idx - read_idx;
536             // NOLINTNEXTLINE
537             c->keys[child_idx] = spec.substr(read_idx, key_len);
538             read_idx = key_delim_idx + 2;
539           } else {
540             pytree_assert(isdigit(spec[read_idx]));
541             size_t key = read_number(spec, read_idx);
542             c->keys[child_idx] = KeyInt(key);
543             read_idx += 1;
544           }
545 
546           c->items[child_idx] =
547               from_str_internal<Aux>(spec, read_idx, spec_data);
548           read_idx = next_delim_idx;
549           leaves_offset += layout[child_idx++];
550         }
551       } else {
552         read_idx++;
553       }
554       c->leaves_num = leaves_offset;
555       return TreeSpec<Aux>(std::move(c));
556     }
557 
558     case Config::kLeaf:
559       return new TreeSpecContainer<Aux>(nullptr);
560   }
561   pytree_unreachable();
562   return new TreeSpecContainer<Aux>(Kind::None);
563 }
564 
565 template <typename T>
566 struct stack final {
567   constexpr static const size_t SIZE = 8;
568 
569   size_t size_ = 0;
570   T data[SIZE];
571 
pushfinal572   void push(T&& item) {
573     pytree_assert(size_ < SIZE);
574     data[size_++] = std::move(item);
575   }
576 
popfinal577   T pop() {
578     pytree_assert(size_ > 0);
579     return data[--size_];
580   }
581 
topfinal582   T& top() {
583     pytree_assert(size_ > 0);
584     return data[size_ - 1];
585   }
586 
sizefinal587   size_t size() {
588     return size_;
589   }
590 };
591 
pre_parse(const StrTreeSpec & spec)592 inline arr<size_t> pre_parse(const StrTreeSpec& spec) {
593   stack<std::pair<size_t, size_t>> stack;
594   size_t i = 0;
595   const size_t size = spec.size();
596   arr<size_t> ret(size);
597   while (i < size) {
598     const auto c = spec[i];
599     switch (c) {
600       case Config::kNodeDataBegin: {
601         stack.push({i, i});
602         break;
603       }
604       case Config::kNodeDataEnd: {
605         auto& item = stack.top();
606         size_t last_sep_idx = item.second;
607         ret[last_sep_idx] = i;
608         stack.pop();
609         break;
610       }
611       case Config::kDictStrKeyQuote: {
612         size_t idx = i;
613         i++;
614         while (spec[i] != Config::kDictStrKeyQuote) {
615           i++;
616         }
617         ret[idx] = i;
618         ret[i] = idx;
619         break;
620       }
621       case Config::kChildrenSep: {
622         auto& item = stack.top();
623         size_t last_sep_idx = item.second;
624         ret[last_sep_idx] = i;
625         item.second = i;
626         break;
627       }
628     }
629     i++;
630   }
631   return ret;
632 }
633 
634 template <typename Aux = Empty>
from_str(const StrTreeSpec & spec)635 TreeSpec<Aux> from_str(const StrTreeSpec& spec) {
636   return from_str_internal<Aux>(spec, 0u, pre_parse(spec));
637 }
638 
639 template <typename Aux>
to_str(const TreeSpec<Aux> & spec)640 StrTreeSpec to_str(const TreeSpec<Aux>& spec) {
641   if (spec.leaves_num() == 0) {
642     refresh_leaves_num(spec);
643   }
644   return to_str_internal(spec);
645 }
646 
647 template <typename Aux>
648 StrTreeSpec to_str(const TreeSpec<Aux>& spec);
649 
650 template <typename T, typename Aux>
unflatten(const TreeSpec<Aux> & spec,T * leaves)651 ContainerHandle<T, Aux> unflatten(const TreeSpec<Aux>& spec, T* leaves) {
652   if (spec.leaves_num() == 0) {
653     refresh_leaves_num(spec);
654   }
655   return clone(spec, leaves);
656 }
657 
658 template <typename T, typename Aux = Empty>
unflatten(const StrTreeSpec & spec,T * leaves)659 ContainerHandle<T, Aux> unflatten(const StrTreeSpec& spec, T* leaves) {
660   return unflatten(from_str<Aux>(spec), leaves);
661 }
662 
663 template <typename T, typename Aux>
flatten_internal(const ContainerHandle<T,Aux> & tree,const T ** leaves)664 void flatten_internal(const ContainerHandle<T, Aux>& tree, const T** leaves) {
665   using tree_t = decltype(tree);
666   size_t leaves_idx = 0;
667   auto func = [&](tree_t node) {
668     if (node.isLeaf()) {
669       leaves[leaves_idx++] = node.leaf_ptr();
670     }
671   };
672   traverse(tree, FunctionRef<void(tree_t&)>{func});
673 }
674 
675 template <typename T, typename Aux>
flatten_internal(ContainerHandle<T,Aux> & tree,T ** leaves)676 void flatten_internal(ContainerHandle<T, Aux>& tree, T** leaves) {
677   using tree_t = decltype(tree);
678   size_t leaves_idx = 0;
679   auto func = [&](tree_t node) {
680     if (node.isLeaf()) {
681       leaves[leaves_idx++] = node.leaf_ptr();
682     }
683   };
684   traverse(tree, FunctionRef<void(tree_t&)>{func});
685 }
686 
687 template <typename T, typename Aux>
refresh_leaves_num(const ContainerHandle<T,Aux> & node)688 size_t refresh_leaves_num(const ContainerHandle<T, Aux>& node) {
689   if (node.isLeaf()) {
690     node.handle->leaves_num = 1;
691     return 1;
692   }
693 
694   size_t n = 0;
695   for (size_t i = 0; i < node.size(); ++i) {
696     n += refresh_leaves_num(node[i]);
697   }
698 
699   node.handle->leaves_num = n;
700   return n;
701 }
702 
703 template <typename T, typename Aux>
flatten(const ContainerHandle<T,Aux> & tree)704 std::pair<arr<const T*>, std::unique_ptr<TreeSpec<Aux>>> flatten(
705     const ContainerHandle<T, Aux>& tree) {
706   refresh_leaves_num(tree);
707   const size_t n = tree.leaves_num();
708   arr<T*> leaves(n);
709   flatten_internal(tree, leaves.data());
710   auto spec_leaves = std::make_unique<TreeSpecLeaf[]>(n);
711   return {
712       std::move(leaves),
713       std::make_unique<TreeSpec<Aux>>(clone(tree, spec_leaves.get()))};
714 }
715 
716 // Duplication of logic for non const ContainerHandle
717 template <typename T, typename Aux>
flatten(ContainerHandle<T,Aux> & tree)718 std::pair<arr<T*>, std::unique_ptr<TreeSpec<Aux>>> flatten(
719     ContainerHandle<T, Aux>& tree) {
720   refresh_leaves_num(tree);
721   const size_t n = tree.leaves_num();
722   arr<T*> leaves(n);
723   flatten_internal(tree, leaves.data());
724   auto spec_leaves = std::make_unique<TreeSpecLeaf[]>(n);
725   return {
726       std::move(leaves),
727       std::make_unique<TreeSpec<Aux>>(clone(tree, spec_leaves.get()))};
728 }
729 
730 } // namespace pytree
731 } // namespace extension
732 } // namespace executorch
733 
734 namespace torch {
735 namespace executor {
736 namespace pytree {
737 // TODO(T197294990): Remove these deprecated aliases once all users have moved
738 // to the new `::executorch` namespaces.
739 using ::executorch::extension::pytree::Empty;
740 using ::executorch::extension::pytree::from_str;
741 using ::executorch::extension::pytree::TreeSpec;
742 } // namespace pytree
743 } // namespace executor
744 } // namespace torch
745