/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include #include #include #include #include // NB: This is a local, pytree FunctionRef and not from the ExecuTorch runtime. #include namespace executorch { namespace extension { namespace pytree { inline void pytree_assert(bool must_be_true) { assert(must_be_true); } #ifdef _MSC_VER #define EXECUTORCH_ALWAYS_INLINE __forceinline #elif defined(__GNUC__) #define EXECUTORCH_ALWAYS_INLINE inline __attribute__((__always_inline__)) #else #define EXECUTORCH_ALWAYS_INLINE inline #endif [[noreturn]] EXECUTORCH_ALWAYS_INLINE void pytree_unreachable() { assert(false); #if defined(__GNUC__) __builtin_unreachable(); #elif defined(_MSC_VER) __assume(0); #else while (!0) ; #endif } enum class Kind : uint8_t { List, Tuple, NamedTuple, Dict, Leaf, Custom, None }; using KeyStr = std::string; using KeyInt = int32_t; struct Key { enum class Kind : uint8_t { None, Int, Str } kind_; private: std::variant repr_; public: Key() {} /*implicit*/ Key(KeyInt key) : repr_(key) {} /*implicit*/ Key(KeyStr key) : repr_(std::move(key)) {} Kind kind() const { return static_cast(repr_.index()); } KeyInt as_int() const { return std::get(repr_); } operator KeyInt() const { return as_int(); } const KeyStr& as_str() const { return std::get(repr_); } operator const KeyStr&() const { return as_str(); } bool operator==(const Key& rhs) const { return repr_ == rhs.repr_; } bool operator!=(const Key& rhs) const { return !operator==(rhs); } }; struct Empty {}; template struct ContainerHandle; template struct Container final : public Aux { using handle_type = ContainerHandle; using leaf_type = T; Kind kind = Kind::None; size_t size = 0; leaf_type* leaf = nullptr; std::unique_ptr items; std::unique_ptr keys; std::string custom_type; // internal only field to keep associated to every node meta info mutable size_t leaves_num = 0u; /*implicit*/ Container(Kind kind, size_t size = 0u) : kind(kind), size(size), items(std::unique_ptr(new handle_type[size])) { if (kind == Kind::Dict) { keys = std::unique_ptr(new Key[size]); } } /*implicit*/ Container(leaf_type* leaf) : kind(Kind::Leaf), size(0u), leaf(leaf), leaves_num(1u) {} Container(const Container&) = delete; Container& operator=(const Container&) = delete; }; template struct ContainerHandle { using container_type = Container; using leaf_type = T; std::unique_ptr handle; ContainerHandle() {} template ContainerHandle(Args... args) : handle(std::make_unique(std::forward(args)...)) {} /*implicit*/ ContainerHandle(container_type* c) : handle(c) {} /*implicit*/ ContainerHandle(std::unique_ptr c) : handle(std::move(c)) {} void set_leaf(leaf_type* leaf) { pytree_assert(handle->kind == Kind::Leaf); handle->leaf = leaf; } operator leaf_type() const { pytree_assert(handle->kind == Kind::Leaf); return *handle->leaf; } const leaf_type& leaf() const { pytree_assert(handle->kind == Kind::Leaf); return *handle->leaf; } leaf_type& leaf() { pytree_assert(handle->kind == Kind::Leaf); return *handle->leaf; } const leaf_type* leaf_ptr() const { pytree_assert(handle->kind == Kind::Leaf); return handle->leaf; } leaf_type* leaf_ptr() { pytree_assert(handle->kind == Kind::Leaf); return handle->leaf; } const ContainerHandle& operator[](size_t idx) const { pytree_assert(idx < handle->size); return handle->items[idx]; } ContainerHandle& operator[](size_t idx) { pytree_assert(idx < handle->size); return handle->items[idx]; } bool contains(const KeyStr& lookup_key) const { pytree_assert(isDict()); for (size_t i = 0; i < handle->size; ++i) { if (handle->keys[i] == lookup_key) { return true; } } return false; } const ContainerHandle& at(const Key& lookup_key) const { pytree_assert(isDict()); for (size_t i = 0; i < handle->size; ++i) { if (handle->keys[i] == lookup_key) { return handle->items[i]; } } pytree_unreachable(); } const ContainerHandle& at(const KeyInt& lookup_key) const { return at(Key(lookup_key)); } const ContainerHandle& at(const KeyStr& lookup_key) const { return at(Key(lookup_key)); } const Key& key(size_t idx) const { pytree_assert(isDict()); return handle->keys[idx]; } Key& key(size_t idx) { pytree_assert(isDict()); return handle->keys[idx]; } size_t size() const { return handle->size; } size_t leaves_num() const { return handle->leaves_num; } bool isDict() const { return handle->kind == Kind::Dict; } bool isList() const { return handle->kind == Kind::List; } bool isNamedTuple() const { return handle->kind == Kind::NamedTuple; } bool isTuple() const { return handle->kind == Kind::Tuple; } bool isLeaf() const { return handle->kind == Kind::Leaf; } Kind kind() const { return handle->kind; } // Checks only structure, no leaves comparison bool operator==(const ContainerHandle& rhs) { const Kind knd = kind(); if (knd != rhs.kind()) { return false; } if (knd == Kind::Leaf) { return true; } const size_t _size = size(); if (_size != rhs.size()) { return false; } for (size_t i = 0; i < _size; ++i) { if (knd == Kind::Dict && (key(i) != rhs.key(i))) { return false; } if (operator[](i) != rhs[i]) { return false; } } return true; } bool operator!=(const ContainerHandle& rhs) { return !operator==(rhs); } }; struct TreeSpecLeaf {}; template using TreeSpec = ContainerHandle; template using TreeSpecContainer = Container; using StrTreeSpec = std::string; // Expects refresh_leaves_num() was called after the last modification template ContainerHandle clone(const ContainerHandle& node, U* leaves) { if (node.isLeaf()) { return ContainerHandle(leaves); } ContainerHandle ret(node.kind(), node.size()); size_t leaves_offset = 0; size_t size = node.size(); for (size_t i = 0; i < size; ++i) { ret[i] = clone(node[i], leaves + leaves_offset); leaves_offset += node[i].leaves_num(); } if (node.isDict()) { ret.handle->keys = std::unique_ptr(new Key[size]); for (size_t i = 0; i < size; ++i) { ret.handle->keys[i] = node.handle->keys[i]; } } return ret; } template void traverse( ContainerHandle& node, FunctionRef&)> func) { for (size_t i = 0; i < node.size(); ++i) { traverse(node[i], func); } func(node); } template void traverse( const ContainerHandle& node, FunctionRef&)> func) { for (size_t i = 0; i < node.size(); ++i) { traverse(node[i], func); } func(node); } struct Config final { static constexpr char kTuple = 'T'; static constexpr char kNamedTuple = 'N'; static constexpr char kList = 'L'; static constexpr char kDict = 'D'; static constexpr char kCustom = 'C'; static constexpr char kLeaf = '$'; static constexpr char kNodeDataBegin = '('; static constexpr char kNodeDataEnd = ')'; static constexpr char kDictStrKeyQuote = '\''; static constexpr char kDictKeyValueSep = ':'; static constexpr char kChildrenSep = ','; static constexpr char kChildrenDataSep = '#'; }; template StrTreeSpec to_str_internal(const TreeSpec& spec) { std::string s; switch (spec.kind()) { case Kind::List: s.push_back(Config::kList); break; case Kind::NamedTuple: s.push_back(Config::kNamedTuple); break; case Kind::Tuple: s.push_back(Config::kTuple); break; case Kind::Dict: s.push_back(Config::kDict); break; case Kind::Leaf: s.push_back(Config::kLeaf); return s; case Kind::Custom: s.push_back(Config::kCustom); s.push_back('('); s.append(spec.handle->custom_type); s.push_back(')'); break; case Kind::None: return s; } const size_t size = spec.size(); s.append(std::to_string(size)); for (size_t i = 0; i < size; ++i) { s.push_back(Config::kChildrenDataSep); s.append(std::to_string(spec[i].leaves_num())); } s.push_back(Config::kNodeDataBegin); if (spec.kind() == Kind::Dict) { for (size_t i = 0; i < size; ++i) { if (i) { s.push_back(Config::kChildrenSep); } const auto& key = spec.key(i); if (key.kind() == Key::Kind::Int) { s.append(std::to_string(key.as_int())); } else if (key.kind() == Key::Kind::Str) { s.push_back(Config::kDictStrKeyQuote); s.append(key.as_str()); s.push_back(Config::kDictStrKeyQuote); } else { pytree_unreachable(); } s.push_back(Config::kDictKeyValueSep); s.append(to_str_internal(spec[i])); } } else { for (size_t i = 0; i < size; ++i) { if (i) { s.push_back(Config::kChildrenSep); } s.append(to_str_internal(spec[i])); } } s.push_back(Config::kNodeDataEnd); return s; } template struct arr { explicit arr(const size_t n) : data_(std::unique_ptr(new T[n])), n_(n) {} T& operator[](size_t idx) { return data_[idx]; } const T& operator[](size_t idx) const { return data_[idx]; } inline T* data() { return data_.get(); } inline size_t size() const { return n_; } private: std::unique_ptr data_; size_t n_; }; inline size_t read_number(const StrTreeSpec& spec, size_t& read_idx) { size_t num = 0; while (isdigit(spec[read_idx])) { num = 10 * num + (spec[read_idx] - '0'); read_idx++; } return num; } inline arr read_node_layout(const StrTreeSpec& spec, size_t& read_idx) { const size_t child_num = read_number(spec, read_idx); arr ret(child_num); size_t child_idx = 0; while (spec[read_idx] == Config::kChildrenDataSep) { ++read_idx; ret[child_idx++] = read_number(spec, read_idx); } return ret; } template TreeSpec from_str_internal( const StrTreeSpec& spec, size_t read_idx, const arr& spec_data) { const auto kind_char = spec[read_idx]; switch (kind_char) { case Config::kTuple: case Config::kNamedTuple: case Config::kList: { Kind kind = Kind::List; std::string custom_type; if (Config::kNamedTuple == kind_char) { kind = Kind::NamedTuple; } else if (Config::kTuple == kind_char) { kind = Kind::Tuple; } else if (Config::kCustom == kind_char) { kind = Kind::Custom; read_idx++; assert(spec[read_idx] == '('); auto type_str_end = spec_data[read_idx]; read_idx++; custom_type = spec.substr(read_idx, type_str_end - read_idx); assert(false); } read_idx++; auto layout = read_node_layout(spec, read_idx); const auto size = layout.size(); auto c = std::make_unique>(kind, size); if (Kind::Custom == kind) { c->custom_type = std::move(custom_type); } size_t child_idx = 0; size_t leaves_offset = 0; if (size > 0) { while (spec[read_idx] != Config::kNodeDataEnd) { // NOLINTNEXTLINE auto next_delim_idx = spec_data[read_idx]; read_idx++; c->items[child_idx] = from_str_internal(spec, read_idx, spec_data); read_idx = next_delim_idx; leaves_offset += layout[child_idx++]; } } else { read_idx++; } c->leaves_num = leaves_offset; return TreeSpec(std::move(c)); } case Config::kDict: { read_idx++; auto layout = read_node_layout(spec, read_idx); const auto size = layout.size(); auto c = std::make_unique>(Kind::Dict, size); size_t child_idx = 0; size_t leaves_offset = 0; if (size > 0) { while (spec[read_idx] != Config::kNodeDataEnd) { // NOLINTNEXTLINE auto next_delim_idx = spec_data[read_idx]; read_idx++; if (spec[read_idx] == Config::kDictStrKeyQuote) { auto key_delim_idx = spec_data[read_idx]; read_idx++; const size_t key_len = key_delim_idx - read_idx; // NOLINTNEXTLINE c->keys[child_idx] = spec.substr(read_idx, key_len); read_idx = key_delim_idx + 2; } else { pytree_assert(isdigit(spec[read_idx])); size_t key = read_number(spec, read_idx); c->keys[child_idx] = KeyInt(key); read_idx += 1; } c->items[child_idx] = from_str_internal(spec, read_idx, spec_data); read_idx = next_delim_idx; leaves_offset += layout[child_idx++]; } } else { read_idx++; } c->leaves_num = leaves_offset; return TreeSpec(std::move(c)); } case Config::kLeaf: return new TreeSpecContainer(nullptr); } pytree_unreachable(); return new TreeSpecContainer(Kind::None); } template struct stack final { constexpr static const size_t SIZE = 8; size_t size_ = 0; T data[SIZE]; void push(T&& item) { pytree_assert(size_ < SIZE); data[size_++] = std::move(item); } T pop() { pytree_assert(size_ > 0); return data[--size_]; } T& top() { pytree_assert(size_ > 0); return data[size_ - 1]; } size_t size() { return size_; } }; inline arr pre_parse(const StrTreeSpec& spec) { stack> stack; size_t i = 0; const size_t size = spec.size(); arr ret(size); while (i < size) { const auto c = spec[i]; switch (c) { case Config::kNodeDataBegin: { stack.push({i, i}); break; } case Config::kNodeDataEnd: { auto& item = stack.top(); size_t last_sep_idx = item.second; ret[last_sep_idx] = i; stack.pop(); break; } case Config::kDictStrKeyQuote: { size_t idx = i; i++; while (spec[i] != Config::kDictStrKeyQuote) { i++; } ret[idx] = i; ret[i] = idx; break; } case Config::kChildrenSep: { auto& item = stack.top(); size_t last_sep_idx = item.second; ret[last_sep_idx] = i; item.second = i; break; } } i++; } return ret; } template TreeSpec from_str(const StrTreeSpec& spec) { return from_str_internal(spec, 0u, pre_parse(spec)); } template StrTreeSpec to_str(const TreeSpec& spec) { if (spec.leaves_num() == 0) { refresh_leaves_num(spec); } return to_str_internal(spec); } template StrTreeSpec to_str(const TreeSpec& spec); template ContainerHandle unflatten(const TreeSpec& spec, T* leaves) { if (spec.leaves_num() == 0) { refresh_leaves_num(spec); } return clone(spec, leaves); } template ContainerHandle unflatten(const StrTreeSpec& spec, T* leaves) { return unflatten(from_str(spec), leaves); } template void flatten_internal(const ContainerHandle& tree, const T** leaves) { using tree_t = decltype(tree); size_t leaves_idx = 0; auto func = [&](tree_t node) { if (node.isLeaf()) { leaves[leaves_idx++] = node.leaf_ptr(); } }; traverse(tree, FunctionRef{func}); } template void flatten_internal(ContainerHandle& tree, T** leaves) { using tree_t = decltype(tree); size_t leaves_idx = 0; auto func = [&](tree_t node) { if (node.isLeaf()) { leaves[leaves_idx++] = node.leaf_ptr(); } }; traverse(tree, FunctionRef{func}); } template size_t refresh_leaves_num(const ContainerHandle& node) { if (node.isLeaf()) { node.handle->leaves_num = 1; return 1; } size_t n = 0; for (size_t i = 0; i < node.size(); ++i) { n += refresh_leaves_num(node[i]); } node.handle->leaves_num = n; return n; } template std::pair, std::unique_ptr>> flatten( const ContainerHandle& tree) { refresh_leaves_num(tree); const size_t n = tree.leaves_num(); arr leaves(n); flatten_internal(tree, leaves.data()); auto spec_leaves = std::make_unique(n); return { std::move(leaves), std::make_unique>(clone(tree, spec_leaves.get()))}; } // Duplication of logic for non const ContainerHandle template std::pair, std::unique_ptr>> flatten( ContainerHandle& tree) { refresh_leaves_num(tree); const size_t n = tree.leaves_num(); arr leaves(n); flatten_internal(tree, leaves.data()); auto spec_leaves = std::make_unique(n); return { std::move(leaves), std::make_unique>(clone(tree, spec_leaves.get()))}; } } // namespace pytree } // namespace extension } // namespace executorch namespace torch { namespace executor { namespace pytree { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. using ::executorch::extension::pytree::Empty; using ::executorch::extension::pytree::from_str; using ::executorch::extension::pytree::TreeSpec; } // namespace pytree } // namespace executor } // namespace torch