xref: /aosp_15_r20/external/executorch/extension/pytree/pybindings.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "executorch/extension/pytree/pytree.h"
15*523fa7a6SAndroid Build Coastguard Worker 
16*523fa7a6SAndroid Build Coastguard Worker namespace py = pybind11;
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker namespace executorch {
19*523fa7a6SAndroid Build Coastguard Worker namespace extension {
20*523fa7a6SAndroid Build Coastguard Worker namespace pytree {
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker namespace {
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker struct PyAux {
25*523fa7a6SAndroid Build Coastguard Worker   py::object custom_type_context;
26*523fa7a6SAndroid Build Coastguard Worker };
27*523fa7a6SAndroid Build Coastguard Worker using PyTreeSpec = TreeSpec<PyAux>;
28*523fa7a6SAndroid Build Coastguard Worker 
29*523fa7a6SAndroid Build Coastguard Worker class PyTypeRegistry {
30*523fa7a6SAndroid Build Coastguard Worker  public:
31*523fa7a6SAndroid Build Coastguard Worker   struct PyTypeReg {
PyTypeRegexecutorch::extension::pytree::__anon59fa55ac0111::PyTypeRegistry::PyTypeReg32*523fa7a6SAndroid Build Coastguard Worker     explicit PyTypeReg(Kind k) : kind(k) {}
33*523fa7a6SAndroid Build Coastguard Worker 
34*523fa7a6SAndroid Build Coastguard Worker     Kind kind;
35*523fa7a6SAndroid Build Coastguard Worker 
36*523fa7a6SAndroid Build Coastguard Worker     // for custom types
37*523fa7a6SAndroid Build Coastguard Worker     py::object type;
38*523fa7a6SAndroid Build Coastguard Worker     // function type: object -> (children, spec_data)
39*523fa7a6SAndroid Build Coastguard Worker     py::function flatten;
40*523fa7a6SAndroid Build Coastguard Worker     // function type: (children, spec_data) -> object
41*523fa7a6SAndroid Build Coastguard Worker     py::function unflatten;
42*523fa7a6SAndroid Build Coastguard Worker   };
43*523fa7a6SAndroid Build Coastguard Worker 
get_by_str(const std::string & pytype)44*523fa7a6SAndroid Build Coastguard Worker   static const PyTypeReg* get_by_str(const std::string& pytype) {
45*523fa7a6SAndroid Build Coastguard Worker     auto* registry = instance();
46*523fa7a6SAndroid Build Coastguard Worker     auto it = registry->regs_.find(pytype);
47*523fa7a6SAndroid Build Coastguard Worker     return it == registry->regs_.end() ? nullptr : it->second.get();
48*523fa7a6SAndroid Build Coastguard Worker   }
49*523fa7a6SAndroid Build Coastguard Worker 
get_by_type(py::handle pytype)50*523fa7a6SAndroid Build Coastguard Worker   static const PyTypeReg* get_by_type(py::handle pytype) {
51*523fa7a6SAndroid Build Coastguard Worker     return get_by_str(py::str(pytype));
52*523fa7a6SAndroid Build Coastguard Worker   }
53*523fa7a6SAndroid Build Coastguard Worker 
register_custom_type(py::object type,py::function flatten,py::function unflatten)54*523fa7a6SAndroid Build Coastguard Worker   static void register_custom_type(
55*523fa7a6SAndroid Build Coastguard Worker       py::object type,
56*523fa7a6SAndroid Build Coastguard Worker       py::function flatten,
57*523fa7a6SAndroid Build Coastguard Worker       py::function unflatten) {
58*523fa7a6SAndroid Build Coastguard Worker     auto* registry = instance();
59*523fa7a6SAndroid Build Coastguard Worker     auto reg = std::make_unique<PyTypeReg>(Kind::Custom);
60*523fa7a6SAndroid Build Coastguard Worker     reg->type = type;
61*523fa7a6SAndroid Build Coastguard Worker     reg->flatten = std::move(flatten);
62*523fa7a6SAndroid Build Coastguard Worker     reg->unflatten = std::move(unflatten);
63*523fa7a6SAndroid Build Coastguard Worker     std::string pytype_str = py::str(type);
64*523fa7a6SAndroid Build Coastguard Worker     auto it = registry->regs_.emplace(pytype_str, std::move(reg));
65*523fa7a6SAndroid Build Coastguard Worker     if (!it.second) {
66*523fa7a6SAndroid Build Coastguard Worker       assert(false);
67*523fa7a6SAndroid Build Coastguard Worker     }
68*523fa7a6SAndroid Build Coastguard Worker   }
69*523fa7a6SAndroid Build Coastguard Worker 
70*523fa7a6SAndroid Build Coastguard Worker  private:
instance()71*523fa7a6SAndroid Build Coastguard Worker   static PyTypeRegistry* instance() {
72*523fa7a6SAndroid Build Coastguard Worker     static auto* registry_instance = []() -> PyTypeRegistry* {
73*523fa7a6SAndroid Build Coastguard Worker       auto* registry = new PyTypeRegistry;
74*523fa7a6SAndroid Build Coastguard Worker 
75*523fa7a6SAndroid Build Coastguard Worker       auto add_pytype_reg = [&](const std::string& pytype, Kind kind) {
76*523fa7a6SAndroid Build Coastguard Worker         registry->regs_.emplace(pytype, std::make_unique<PyTypeReg>(kind));
77*523fa7a6SAndroid Build Coastguard Worker       };
78*523fa7a6SAndroid Build Coastguard Worker 
79*523fa7a6SAndroid Build Coastguard Worker       add_pytype_reg("<class 'tuple'>", Kind::Tuple);
80*523fa7a6SAndroid Build Coastguard Worker       add_pytype_reg("<class 'list'>", Kind::List);
81*523fa7a6SAndroid Build Coastguard Worker       add_pytype_reg("<class 'dict'>", Kind::Dict);
82*523fa7a6SAndroid Build Coastguard Worker 
83*523fa7a6SAndroid Build Coastguard Worker       return registry;
84*523fa7a6SAndroid Build Coastguard Worker     }();
85*523fa7a6SAndroid Build Coastguard Worker 
86*523fa7a6SAndroid Build Coastguard Worker     return registry_instance;
87*523fa7a6SAndroid Build Coastguard Worker   }
88*523fa7a6SAndroid Build Coastguard Worker   std::unordered_map<std::string, std::unique_ptr<PyTypeReg>> regs_;
89*523fa7a6SAndroid Build Coastguard Worker };
90*523fa7a6SAndroid Build Coastguard Worker 
91*523fa7a6SAndroid Build Coastguard Worker class PyTree {
92*523fa7a6SAndroid Build Coastguard Worker   PyTreeSpec spec_;
93*523fa7a6SAndroid Build Coastguard Worker 
flatten_internal(py::handle x,std::vector<py::object> & leaves,PyTreeSpec & s)94*523fa7a6SAndroid Build Coastguard Worker   static void flatten_internal(
95*523fa7a6SAndroid Build Coastguard Worker       py::handle x,
96*523fa7a6SAndroid Build Coastguard Worker       std::vector<py::object>& leaves,
97*523fa7a6SAndroid Build Coastguard Worker       PyTreeSpec& s) {
98*523fa7a6SAndroid Build Coastguard Worker     const auto* reg = PyTypeRegistry::get_by_type(x.get_type());
99*523fa7a6SAndroid Build Coastguard Worker     const auto kind = [&reg, &x]() {
100*523fa7a6SAndroid Build Coastguard Worker       if (reg) {
101*523fa7a6SAndroid Build Coastguard Worker         return reg->kind;
102*523fa7a6SAndroid Build Coastguard Worker       }
103*523fa7a6SAndroid Build Coastguard Worker       if (py::isinstance<py::tuple>(x) && py::hasattr(x, "_fields")) {
104*523fa7a6SAndroid Build Coastguard Worker         return Kind::NamedTuple;
105*523fa7a6SAndroid Build Coastguard Worker       }
106*523fa7a6SAndroid Build Coastguard Worker       return Kind::Leaf;
107*523fa7a6SAndroid Build Coastguard Worker     }();
108*523fa7a6SAndroid Build Coastguard Worker     switch (kind) {
109*523fa7a6SAndroid Build Coastguard Worker       case Kind::List: {
110*523fa7a6SAndroid Build Coastguard Worker         const size_t n = PyList_GET_SIZE(x.ptr());
111*523fa7a6SAndroid Build Coastguard Worker         s = PyTreeSpec(Kind::List, n);
112*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = 0; i < n; ++i) {
113*523fa7a6SAndroid Build Coastguard Worker           flatten_internal(PyList_GET_ITEM(x.ptr(), i), leaves, s[i]);
114*523fa7a6SAndroid Build Coastguard Worker         }
115*523fa7a6SAndroid Build Coastguard Worker         break;
116*523fa7a6SAndroid Build Coastguard Worker       }
117*523fa7a6SAndroid Build Coastguard Worker       case Kind::Tuple: {
118*523fa7a6SAndroid Build Coastguard Worker         const size_t n = PyTuple_GET_SIZE(x.ptr());
119*523fa7a6SAndroid Build Coastguard Worker         s = PyTreeSpec(Kind::Tuple, n);
120*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = 0; i < n; ++i) {
121*523fa7a6SAndroid Build Coastguard Worker           flatten_internal(PyTuple_GET_ITEM(x.ptr(), i), leaves, s[i]);
122*523fa7a6SAndroid Build Coastguard Worker         }
123*523fa7a6SAndroid Build Coastguard Worker         break;
124*523fa7a6SAndroid Build Coastguard Worker       }
125*523fa7a6SAndroid Build Coastguard Worker       case Kind::NamedTuple: {
126*523fa7a6SAndroid Build Coastguard Worker         py::tuple tuple = py::reinterpret_borrow<py::tuple>(x);
127*523fa7a6SAndroid Build Coastguard Worker         const size_t n = tuple.size();
128*523fa7a6SAndroid Build Coastguard Worker         s = PyTreeSpec(Kind::NamedTuple, n);
129*523fa7a6SAndroid Build Coastguard Worker         size_t i = 0;
130*523fa7a6SAndroid Build Coastguard Worker         for (py::handle entry : tuple) {
131*523fa7a6SAndroid Build Coastguard Worker           flatten_internal(entry, leaves, s[i++]);
132*523fa7a6SAndroid Build Coastguard Worker         }
133*523fa7a6SAndroid Build Coastguard Worker         break;
134*523fa7a6SAndroid Build Coastguard Worker       }
135*523fa7a6SAndroid Build Coastguard Worker       case Kind::Dict: {
136*523fa7a6SAndroid Build Coastguard Worker         py::dict dict = py::reinterpret_borrow<py::dict>(x);
137*523fa7a6SAndroid Build Coastguard Worker         py::list keys =
138*523fa7a6SAndroid Build Coastguard Worker             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
139*523fa7a6SAndroid Build Coastguard Worker         const auto n = PyList_GET_SIZE(keys.ptr());
140*523fa7a6SAndroid Build Coastguard Worker         s = PyTreeSpec(Kind::Dict, n);
141*523fa7a6SAndroid Build Coastguard Worker         size_t i = 0;
142*523fa7a6SAndroid Build Coastguard Worker         for (py::handle key : keys) {
143*523fa7a6SAndroid Build Coastguard Worker           if (py::isinstance<py::str>(key)) {
144*523fa7a6SAndroid Build Coastguard Worker             s.key(i) = py::cast<std::string>(key);
145*523fa7a6SAndroid Build Coastguard Worker           } else if (py::isinstance<py::int_>(key)) {
146*523fa7a6SAndroid Build Coastguard Worker             s.key(i) = py::cast<int32_t>(key);
147*523fa7a6SAndroid Build Coastguard Worker           } else {
148*523fa7a6SAndroid Build Coastguard Worker             pytree_assert(false);
149*523fa7a6SAndroid Build Coastguard Worker           }
150*523fa7a6SAndroid Build Coastguard Worker 
151*523fa7a6SAndroid Build Coastguard Worker           flatten_internal(dict[key], leaves, s[i]);
152*523fa7a6SAndroid Build Coastguard Worker           i++;
153*523fa7a6SAndroid Build Coastguard Worker         }
154*523fa7a6SAndroid Build Coastguard Worker         break;
155*523fa7a6SAndroid Build Coastguard Worker       }
156*523fa7a6SAndroid Build Coastguard Worker       case Kind::Custom: {
157*523fa7a6SAndroid Build Coastguard Worker         py::tuple out = py::cast<py::tuple>(reg->flatten(x));
158*523fa7a6SAndroid Build Coastguard Worker         if (out.size() != 2) {
159*523fa7a6SAndroid Build Coastguard Worker           assert(false);
160*523fa7a6SAndroid Build Coastguard Worker         }
161*523fa7a6SAndroid Build Coastguard Worker         py::list children = py::cast<py::list>(out[0]);
162*523fa7a6SAndroid Build Coastguard Worker         const size_t n = children.size();
163*523fa7a6SAndroid Build Coastguard Worker         s = PyTreeSpec(Kind::Custom, n);
164*523fa7a6SAndroid Build Coastguard Worker         s.handle->custom_type = py::str(x.get_type());
165*523fa7a6SAndroid Build Coastguard Worker         s.handle->custom_type_context = out[1];
166*523fa7a6SAndroid Build Coastguard Worker         size_t i = 0;
167*523fa7a6SAndroid Build Coastguard Worker         for (py::handle pychild : children) {
168*523fa7a6SAndroid Build Coastguard Worker           flatten_internal(pychild, leaves, s[i++]);
169*523fa7a6SAndroid Build Coastguard Worker         }
170*523fa7a6SAndroid Build Coastguard Worker         break;
171*523fa7a6SAndroid Build Coastguard Worker       }
172*523fa7a6SAndroid Build Coastguard Worker       case Kind::Leaf: {
173*523fa7a6SAndroid Build Coastguard Worker         s = PyTreeSpec(Kind::Leaf);
174*523fa7a6SAndroid Build Coastguard Worker         leaves.push_back(py::reinterpret_borrow<py::object>(x));
175*523fa7a6SAndroid Build Coastguard Worker         break;
176*523fa7a6SAndroid Build Coastguard Worker       }
177*523fa7a6SAndroid Build Coastguard Worker       case Kind::None:
178*523fa7a6SAndroid Build Coastguard Worker         pytree_assert(false);
179*523fa7a6SAndroid Build Coastguard Worker     }
180*523fa7a6SAndroid Build Coastguard Worker   }
181*523fa7a6SAndroid Build Coastguard Worker 
182*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
unflatten_internal(const PyTreeSpec & spec,T && leaves_it) const183*523fa7a6SAndroid Build Coastguard Worker   py::object unflatten_internal(const PyTreeSpec& spec, T&& leaves_it) const {
184*523fa7a6SAndroid Build Coastguard Worker     switch (spec.kind()) {
185*523fa7a6SAndroid Build Coastguard Worker       case Kind::NamedTuple:
186*523fa7a6SAndroid Build Coastguard Worker       case Kind::Tuple: {
187*523fa7a6SAndroid Build Coastguard Worker         const size_t size = spec.size();
188*523fa7a6SAndroid Build Coastguard Worker         py::tuple tuple(size);
189*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = 0; i < size; ++i) {
190*523fa7a6SAndroid Build Coastguard Worker           tuple[i] = unflatten_internal(spec[i], leaves_it);
191*523fa7a6SAndroid Build Coastguard Worker         }
192*523fa7a6SAndroid Build Coastguard Worker         return std::move(tuple);
193*523fa7a6SAndroid Build Coastguard Worker       }
194*523fa7a6SAndroid Build Coastguard Worker       case Kind::List: {
195*523fa7a6SAndroid Build Coastguard Worker         const size_t size = spec.size();
196*523fa7a6SAndroid Build Coastguard Worker         py::list list(size);
197*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = 0; i < size; ++i) {
198*523fa7a6SAndroid Build Coastguard Worker           list[i] = unflatten_internal(spec[i], leaves_it);
199*523fa7a6SAndroid Build Coastguard Worker         }
200*523fa7a6SAndroid Build Coastguard Worker         return std::move(list);
201*523fa7a6SAndroid Build Coastguard Worker       }
202*523fa7a6SAndroid Build Coastguard Worker       case Kind::Custom: {
203*523fa7a6SAndroid Build Coastguard Worker         const auto& pytype_str = spec.handle->custom_type;
204*523fa7a6SAndroid Build Coastguard Worker         const auto* reg = PyTypeRegistry::get_by_str(pytype_str);
205*523fa7a6SAndroid Build Coastguard Worker         const size_t size = spec.size();
206*523fa7a6SAndroid Build Coastguard Worker         py::list list(size);
207*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = 0; i < size; ++i) {
208*523fa7a6SAndroid Build Coastguard Worker           list[i] = unflatten_internal(spec[i], leaves_it);
209*523fa7a6SAndroid Build Coastguard Worker         }
210*523fa7a6SAndroid Build Coastguard Worker         py::object o = reg->unflatten(list, spec.handle->custom_type_context);
211*523fa7a6SAndroid Build Coastguard Worker         return o;
212*523fa7a6SAndroid Build Coastguard Worker       }
213*523fa7a6SAndroid Build Coastguard Worker       case Kind::Dict: {
214*523fa7a6SAndroid Build Coastguard Worker         const size_t size = spec.size();
215*523fa7a6SAndroid Build Coastguard Worker         py::dict dict;
216*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = 0; i < size; ++i) {
217*523fa7a6SAndroid Build Coastguard Worker           auto& key = spec.key(i);
218*523fa7a6SAndroid Build Coastguard Worker           auto py_key = [&key]() -> py::handle {
219*523fa7a6SAndroid Build Coastguard Worker             switch (key.kind()) {
220*523fa7a6SAndroid Build Coastguard Worker               case Key::Kind::Int:
221*523fa7a6SAndroid Build Coastguard Worker                 return py::cast(key.as_int()).release();
222*523fa7a6SAndroid Build Coastguard Worker               case Key::Kind::Str:
223*523fa7a6SAndroid Build Coastguard Worker                 return py::cast(key.as_str()).release();
224*523fa7a6SAndroid Build Coastguard Worker               case Key::Kind::None:
225*523fa7a6SAndroid Build Coastguard Worker                 pytree_assert(false);
226*523fa7a6SAndroid Build Coastguard Worker             }
227*523fa7a6SAndroid Build Coastguard Worker             pytree_assert(false);
228*523fa7a6SAndroid Build Coastguard Worker             return py::none();
229*523fa7a6SAndroid Build Coastguard Worker           }();
230*523fa7a6SAndroid Build Coastguard Worker           dict[py_key] = unflatten_internal(spec[i], leaves_it);
231*523fa7a6SAndroid Build Coastguard Worker         }
232*523fa7a6SAndroid Build Coastguard Worker         return std::move(dict);
233*523fa7a6SAndroid Build Coastguard Worker       }
234*523fa7a6SAndroid Build Coastguard Worker       case Kind::Leaf: {
235*523fa7a6SAndroid Build Coastguard Worker         py::object o =
236*523fa7a6SAndroid Build Coastguard Worker             py::reinterpret_borrow<py::object>(*std::forward<T>(leaves_it));
237*523fa7a6SAndroid Build Coastguard Worker         leaves_it++;
238*523fa7a6SAndroid Build Coastguard Worker         return o;
239*523fa7a6SAndroid Build Coastguard Worker       }
240*523fa7a6SAndroid Build Coastguard Worker       case Kind::None: {
241*523fa7a6SAndroid Build Coastguard Worker         return py::none();
242*523fa7a6SAndroid Build Coastguard Worker       }
243*523fa7a6SAndroid Build Coastguard Worker     }
244*523fa7a6SAndroid Build Coastguard Worker     pytree_assert(false);
245*523fa7a6SAndroid Build Coastguard Worker   }
246*523fa7a6SAndroid Build Coastguard Worker 
247*523fa7a6SAndroid Build Coastguard Worker  public:
PyTree(PyTreeSpec spec)248*523fa7a6SAndroid Build Coastguard Worker   explicit PyTree(PyTreeSpec spec) : spec_(std::move(spec)) {}
249*523fa7a6SAndroid Build Coastguard Worker 
spec() const250*523fa7a6SAndroid Build Coastguard Worker   const PyTreeSpec& spec() const {
251*523fa7a6SAndroid Build Coastguard Worker     return spec_;
252*523fa7a6SAndroid Build Coastguard Worker   }
253*523fa7a6SAndroid Build Coastguard Worker 
py_from_str(std::string spec)254*523fa7a6SAndroid Build Coastguard Worker   static PyTree py_from_str(std::string spec) {
255*523fa7a6SAndroid Build Coastguard Worker     return PyTree(from_str<PyAux>(spec));
256*523fa7a6SAndroid Build Coastguard Worker   }
257*523fa7a6SAndroid Build Coastguard Worker 
py_to_str() const258*523fa7a6SAndroid Build Coastguard Worker   StrTreeSpec py_to_str() const {
259*523fa7a6SAndroid Build Coastguard Worker     return to_str(spec_);
260*523fa7a6SAndroid Build Coastguard Worker   }
261*523fa7a6SAndroid Build Coastguard Worker 
262*523fa7a6SAndroid Build Coastguard Worker   static std::pair<std::vector<py::object>, std::unique_ptr<PyTree>>
tree_flatten(py::handle x)263*523fa7a6SAndroid Build Coastguard Worker   tree_flatten(py::handle x) {
264*523fa7a6SAndroid Build Coastguard Worker     std::vector<py::object> leaves{};
265*523fa7a6SAndroid Build Coastguard Worker     PyTreeSpec spec{};
266*523fa7a6SAndroid Build Coastguard Worker     flatten_internal(x, leaves, spec);
267*523fa7a6SAndroid Build Coastguard Worker     refresh_leaves_num(spec);
268*523fa7a6SAndroid Build Coastguard Worker     return {std::move(leaves), std::make_unique<PyTree>(std::move(spec))};
269*523fa7a6SAndroid Build Coastguard Worker   }
270*523fa7a6SAndroid Build Coastguard Worker 
tree_unflatten(py::iterable leaves,py::object o)271*523fa7a6SAndroid Build Coastguard Worker   static py::object tree_unflatten(py::iterable leaves, py::object o) {
272*523fa7a6SAndroid Build Coastguard Worker     return o.cast<PyTree*>()->tree_unflatten(leaves);
273*523fa7a6SAndroid Build Coastguard Worker   }
274*523fa7a6SAndroid Build Coastguard Worker 
275*523fa7a6SAndroid Build Coastguard Worker   template <typename T>
tree_unflatten(T leaves) const276*523fa7a6SAndroid Build Coastguard Worker   py::object tree_unflatten(T leaves) const {
277*523fa7a6SAndroid Build Coastguard Worker     return unflatten_internal(spec_, leaves.begin());
278*523fa7a6SAndroid Build Coastguard Worker   }
279*523fa7a6SAndroid Build Coastguard Worker 
operator ==(const PyTree & rhs)280*523fa7a6SAndroid Build Coastguard Worker   bool operator==(const PyTree& rhs) {
281*523fa7a6SAndroid Build Coastguard Worker     return spec_ == rhs.spec_;
282*523fa7a6SAndroid Build Coastguard Worker   }
283*523fa7a6SAndroid Build Coastguard Worker 
leaves_num() const284*523fa7a6SAndroid Build Coastguard Worker   size_t leaves_num() const {
285*523fa7a6SAndroid Build Coastguard Worker     return refresh_leaves_num(spec_);
286*523fa7a6SAndroid Build Coastguard Worker   }
287*523fa7a6SAndroid Build Coastguard Worker };
288*523fa7a6SAndroid Build Coastguard Worker 
tree_flatten(py::handle x)289*523fa7a6SAndroid Build Coastguard Worker inline std::pair<std::vector<py::object>, std::unique_ptr<PyTree>> tree_flatten(
290*523fa7a6SAndroid Build Coastguard Worker     py::handle x) {
291*523fa7a6SAndroid Build Coastguard Worker   return PyTree::tree_flatten(x);
292*523fa7a6SAndroid Build Coastguard Worker }
293*523fa7a6SAndroid Build Coastguard Worker 
tree_unflatten(py::iterable leaves,py::object o)294*523fa7a6SAndroid Build Coastguard Worker inline py::object tree_unflatten(py::iterable leaves, py::object o) {
295*523fa7a6SAndroid Build Coastguard Worker   return PyTree::tree_unflatten(leaves, o);
296*523fa7a6SAndroid Build Coastguard Worker }
297*523fa7a6SAndroid Build Coastguard Worker 
tree_map(py::function & fn,py::handle x)298*523fa7a6SAndroid Build Coastguard Worker static py::object tree_map(py::function& fn, py::handle x) {
299*523fa7a6SAndroid Build Coastguard Worker   auto p = tree_flatten(x);
300*523fa7a6SAndroid Build Coastguard Worker   const auto& leaves = p.first;
301*523fa7a6SAndroid Build Coastguard Worker   const auto& pytree = p.second;
302*523fa7a6SAndroid Build Coastguard Worker   std::vector<py::handle> vec;
303*523fa7a6SAndroid Build Coastguard Worker   for (const py::handle& h : leaves) {
304*523fa7a6SAndroid Build Coastguard Worker     vec.push_back(fn(h));
305*523fa7a6SAndroid Build Coastguard Worker   }
306*523fa7a6SAndroid Build Coastguard Worker   return pytree->tree_unflatten(vec);
307*523fa7a6SAndroid Build Coastguard Worker }
308*523fa7a6SAndroid Build Coastguard Worker 
py_from_str(std::string spec)309*523fa7a6SAndroid Build Coastguard Worker static std::unique_ptr<PyTree> py_from_str(std::string spec) {
310*523fa7a6SAndroid Build Coastguard Worker   return std::make_unique<PyTree>(from_str<PyAux>(spec));
311*523fa7a6SAndroid Build Coastguard Worker }
312*523fa7a6SAndroid Build Coastguard Worker 
broadcast_to_and_flatten(py::object x,py::object py_tree_spec)313*523fa7a6SAndroid Build Coastguard Worker static py::object broadcast_to_and_flatten(
314*523fa7a6SAndroid Build Coastguard Worker     py::object x,
315*523fa7a6SAndroid Build Coastguard Worker     py::object py_tree_spec) {
316*523fa7a6SAndroid Build Coastguard Worker   auto p = tree_flatten(x);
317*523fa7a6SAndroid Build Coastguard Worker   const auto& x_leaves = p.first;
318*523fa7a6SAndroid Build Coastguard Worker   const auto& x_spec = p.second->spec();
319*523fa7a6SAndroid Build Coastguard Worker 
320*523fa7a6SAndroid Build Coastguard Worker   PyTree* tree_spec = py_tree_spec.cast<PyTree*>();
321*523fa7a6SAndroid Build Coastguard Worker 
322*523fa7a6SAndroid Build Coastguard Worker   py::list ret;
323*523fa7a6SAndroid Build Coastguard Worker   struct StackItem {
324*523fa7a6SAndroid Build Coastguard Worker     const PyTreeSpec* tree_spec_node;
325*523fa7a6SAndroid Build Coastguard Worker     const PyTreeSpec* x_spec_node;
326*523fa7a6SAndroid Build Coastguard Worker     const size_t x_leaves_offset;
327*523fa7a6SAndroid Build Coastguard Worker   };
328*523fa7a6SAndroid Build Coastguard Worker   std::stack<StackItem> stack;
329*523fa7a6SAndroid Build Coastguard Worker   stack.push({&tree_spec->spec(), &x_spec, 0u});
330*523fa7a6SAndroid Build Coastguard Worker   while (!stack.empty()) {
331*523fa7a6SAndroid Build Coastguard Worker     const auto top = stack.top();
332*523fa7a6SAndroid Build Coastguard Worker     stack.pop();
333*523fa7a6SAndroid Build Coastguard Worker     if (top.x_spec_node->isLeaf()) {
334*523fa7a6SAndroid Build Coastguard Worker       for (size_t i = 0; i < top.tree_spec_node->leaves_num(); ++i) {
335*523fa7a6SAndroid Build Coastguard Worker         ret.append(x_leaves[top.x_leaves_offset]);
336*523fa7a6SAndroid Build Coastguard Worker       }
337*523fa7a6SAndroid Build Coastguard Worker     } else {
338*523fa7a6SAndroid Build Coastguard Worker       const auto kind = top.tree_spec_node->kind();
339*523fa7a6SAndroid Build Coastguard Worker       if (kind != top.x_spec_node->kind()) {
340*523fa7a6SAndroid Build Coastguard Worker         return py::none();
341*523fa7a6SAndroid Build Coastguard Worker       }
342*523fa7a6SAndroid Build Coastguard Worker       pytree_assert(top.tree_spec_node->kind() == top.x_spec_node->kind());
343*523fa7a6SAndroid Build Coastguard Worker       const size_t child_num = top.tree_spec_node->size();
344*523fa7a6SAndroid Build Coastguard Worker       if (child_num != top.x_spec_node->size()) {
345*523fa7a6SAndroid Build Coastguard Worker         return py::none();
346*523fa7a6SAndroid Build Coastguard Worker       }
347*523fa7a6SAndroid Build Coastguard Worker       pytree_assert(child_num == top.x_spec_node->size());
348*523fa7a6SAndroid Build Coastguard Worker 
349*523fa7a6SAndroid Build Coastguard Worker       size_t x_leaves_offset =
350*523fa7a6SAndroid Build Coastguard Worker           top.x_leaves_offset + top.x_spec_node->leaves_num();
351*523fa7a6SAndroid Build Coastguard Worker       auto fn_i = [&](size_t i) {
352*523fa7a6SAndroid Build Coastguard Worker         x_leaves_offset -= (*top.x_spec_node)[i].leaves_num();
353*523fa7a6SAndroid Build Coastguard Worker         stack.push(
354*523fa7a6SAndroid Build Coastguard Worker             {&(*top.tree_spec_node)[i],
355*523fa7a6SAndroid Build Coastguard Worker              &(*top.x_spec_node)[i],
356*523fa7a6SAndroid Build Coastguard Worker              x_leaves_offset});
357*523fa7a6SAndroid Build Coastguard Worker       };
358*523fa7a6SAndroid Build Coastguard Worker       if (Kind::Dict == kind) {
359*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = child_num - 1; i < child_num; --i) {
360*523fa7a6SAndroid Build Coastguard Worker           if (top.tree_spec_node->key(i) != top.x_spec_node->key(i)) {
361*523fa7a6SAndroid Build Coastguard Worker             return py::none();
362*523fa7a6SAndroid Build Coastguard Worker           }
363*523fa7a6SAndroid Build Coastguard Worker           fn_i(i);
364*523fa7a6SAndroid Build Coastguard Worker         }
365*523fa7a6SAndroid Build Coastguard Worker       } else {
366*523fa7a6SAndroid Build Coastguard Worker         for (size_t i = child_num - 1; i < child_num; --i) {
367*523fa7a6SAndroid Build Coastguard Worker           fn_i(i);
368*523fa7a6SAndroid Build Coastguard Worker         }
369*523fa7a6SAndroid Build Coastguard Worker       }
370*523fa7a6SAndroid Build Coastguard Worker     }
371*523fa7a6SAndroid Build Coastguard Worker   }
372*523fa7a6SAndroid Build Coastguard Worker   return std::move(ret);
373*523fa7a6SAndroid Build Coastguard Worker }
374*523fa7a6SAndroid Build Coastguard Worker 
375*523fa7a6SAndroid Build Coastguard Worker } // namespace
376*523fa7a6SAndroid Build Coastguard Worker 
// Python module `pybindings`: exposes flatten/unflatten/map helpers, custom
// type registration, and the TreeSpec class.
PYBIND11_MODULE(pybindings, m) {
  m.def("tree_flatten", &tree_flatten, py::arg("tree"));
  m.def("tree_unflatten", &tree_unflatten, py::arg("leaves"), py::arg("tree"));
  m.def("tree_map", &tree_map);
  m.def("from_str", &py_from_str);
  m.def("broadcast_to_and_flatten", &broadcast_to_and_flatten);
  m.def("register_custom", &PyTypeRegistry::register_custom_type);

  py::class_<PyTree>(m, "TreeSpec")
      // PyTree::py_from_str is a static function with no `self` parameter.
      // Binding it with .def made `spec.from_str(s)` fail. As a static
      // method, both TreeSpec.from_str(s) and spec.from_str(s) work.
      .def_static("from_str", &PyTree::py_from_str)
      .def(
          "tree_unflatten",
          static_cast<py::object (PyTree::*)(py::iterable leaves) const>(
              &PyTree::tree_unflatten))
      .def("__repr__", &PyTree::py_to_str)
      .def("__eq__", &PyTree::operator==)
      .def("to_str", &PyTree::py_to_str)
      .def("num_leaves", &PyTree::leaves_num);
}
396*523fa7a6SAndroid Build Coastguard Worker 
397*523fa7a6SAndroid Build Coastguard Worker } // namespace pytree
398*523fa7a6SAndroid Build Coastguard Worker } // namespace extension
399*523fa7a6SAndroid Build Coastguard Worker } // namespace executorch
400