xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/pytree.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Caution: this code uses exceptions. Their use is confined to the binding
// code, where throwing is the idiomatic way to raise Python exceptions.
18 
19 #include "tensorflow/compiler/xla/python/pytree.h"
20 
21 #include <algorithm>
22 #include <memory>
23 #include <stdexcept>
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/hash/hash.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "pybind11/pybind11.h"
35 #include "pybind11/pytypes.h"
36 #include "pybind11/stl.h"
37 #include "pybind11_abseil/absl_casters.h"  // from @pybind11_abseil
38 #include "tensorflow/compiler/xla/python/exceptions.h"
39 #include "tensorflow/core/platform/logging.h"
40 
41 namespace xla {
42 
43 namespace py = pybind11;
44 
Singleton()45 /*static*/ PyTreeTypeRegistry* PyTreeTypeRegistry::Singleton() {
46   static auto* registry = []() -> PyTreeTypeRegistry* {
47     auto* registry = new PyTreeTypeRegistry;
48 
49     auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) {
50       py::object type = py::reinterpret_borrow<py::object>(
51           reinterpret_cast<PyObject*>(type_obj));
52       auto registration = std::make_unique<Registration>();
53       registration->kind = kind;
54       registration->type = type;
55       CHECK(registry->registrations_.emplace(type, std::move(registration))
56                 .second);
57     };
58     add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone);
59     add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple);
60     add_builtin_type(&PyList_Type, PyTreeKind::kList);
61     add_builtin_type(&PyDict_Type, PyTreeKind::kDict);
62     return registry;
63   }();
64   return registry;
65 }
66 
Register(py::object type,py::function to_iterable,py::function from_iterable)67 /*static*/ void PyTreeTypeRegistry::Register(py::object type,
68                                              py::function to_iterable,
69                                              py::function from_iterable) {
70   PyTreeTypeRegistry* registry = Singleton();
71   auto registration = std::make_unique<Registration>();
72   registration->kind = PyTreeKind::kCustom;
73   registration->type = type;
74   registration->to_iterable = std::move(to_iterable);
75   registration->from_iterable = std::move(from_iterable);
76   auto it = registry->registrations_.emplace(type, std::move(registration));
77   if (!it.second) {
78     throw std::invalid_argument(
79         absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.",
80                         py::repr(type)));
81   }
82 }
83 
Lookup(py::handle type)84 /*static*/ const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup(
85     py::handle type) {
86   PyTreeTypeRegistry* registry = Singleton();
87   auto it = registry->registrations_.find(type);
88   return it == registry->registrations_.end() ? nullptr : it->second.get();
89 }
90 
operator ==(const PyTreeDef & other) const91 bool PyTreeDef::operator==(const PyTreeDef& other) const {
92   if (traversal_.size() != other.traversal_.size()) {
93     return false;
94   }
95   for (size_t i = 0; i < traversal_.size(); ++i) {
96     const Node& a = traversal_[i];
97     const Node& b = other.traversal_[i];
98     if (a.kind != b.kind || a.arity != b.arity ||
99         (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) ||
100         a.custom != b.custom) {
101       return false;
102     }
103     if (a.node_data && a.node_data.not_equal(b.node_data)) {
104       return false;
105     }
106     // We don't need to test equality of num_leaves and num_nodes since they
107     // are derivable from the other node data.
108   }
109   return true;
110 }
111 
GetKind(const py::handle & obj,PyTreeTypeRegistry::Registration const ** custom)112 /*static*/ PyTreeKind PyTreeDef::GetKind(
113     const py::handle& obj, PyTreeTypeRegistry::Registration const** custom) {
114   const PyTreeTypeRegistry::Registration* registration =
115       PyTreeTypeRegistry::Lookup(obj.get_type());
116   if (registration) {
117     if (registration->kind == PyTreeKind::kCustom) {
118       *custom = registration;
119     } else {
120       *custom = nullptr;
121     }
122     return registration->kind;
123   } else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
124     // We can only identify namedtuples heuristically, here by the presence of
125     // a _fields attribute.
126     return PyTreeKind::kNamedTuple;
127   } else {
128     return PyTreeKind::kLeaf;
129   }
130 }
131 
132 template <typename T>
FlattenIntoImpl(py::handle handle,T & leaves,const std::optional<py::function> & leaf_predicate)133 void PyTreeDef::FlattenIntoImpl(
134     py::handle handle, T& leaves,
135     const std::optional<py::function>& leaf_predicate) {
136   Node node;
137   int start_num_nodes = traversal_.size();
138   int start_num_leaves = leaves.size();
139   if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
140     leaves.push_back(py::reinterpret_borrow<py::object>(handle));
141   } else {
142     node.kind = GetKind(handle, &node.custom);
143     auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
144       FlattenInto(child, leaves, leaf_predicate);
145     };
146     switch (node.kind) {
147       case PyTreeKind::kNone:
148         // Nothing to do.
149         break;
150       case PyTreeKind::kTuple: {
151         node.arity = PyTuple_GET_SIZE(handle.ptr());
152         for (int i = 0; i < node.arity; ++i) {
153           recurse(PyTuple_GET_ITEM(handle.ptr(), i));
154         }
155         break;
156       }
157       case PyTreeKind::kList: {
158         node.arity = PyList_GET_SIZE(handle.ptr());
159         for (int i = 0; i < node.arity; ++i) {
160           recurse(PyList_GET_ITEM(handle.ptr(), i));
161         }
162         break;
163       }
164       case PyTreeKind::kDict: {
165         py::dict dict = py::reinterpret_borrow<py::dict>(handle);
166         py::list keys =
167             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
168         if (PyList_Sort(keys.ptr())) {
169           throw py::error_already_set();
170         }
171         for (py::handle key : keys) {
172           recurse(dict[key]);
173         }
174         node.arity = dict.size();
175         node.node_data = std::move(keys);
176         break;
177       }
178       case PyTreeKind::kCustom: {
179         py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
180         if (out.size() != 2) {
181           throw xla::XlaRuntimeError(
182               "PyTree custom to_iterable function should return a pair");
183         }
184         node.node_data = out[1];
185         node.arity = 0;
186         for (py::handle entry : py::cast<py::iterable>(out[0])) {
187           ++node.arity;
188           recurse(entry);
189         }
190         break;
191       }
192       case PyTreeKind::kNamedTuple: {
193         py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
194         node.arity = tuple.size();
195         node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
196         for (py::handle entry : tuple) {
197           recurse(entry);
198         }
199         break;
200       }
201       default:
202         DCHECK(node.kind == PyTreeKind::kLeaf);
203         leaves.push_back(py::reinterpret_borrow<py::object>(handle));
204     }
205   }
206   node.num_nodes = traversal_.size() - start_num_nodes + 1;
207   node.num_leaves = leaves.size() - start_num_leaves;
208   traversal_.push_back(std::move(node));
209 }
210 
FlattenInto(py::handle handle,absl::InlinedVector<py::object,2> & leaves,std::optional<py::function> leaf_predicate)211 void PyTreeDef::FlattenInto(py::handle handle,
212                             absl::InlinedVector<py::object, 2>& leaves,
213                             std::optional<py::function> leaf_predicate) {
214   FlattenIntoImpl(handle, leaves, leaf_predicate);
215 }
216 
FlattenInto(py::handle handle,std::vector<py::object> & leaves,std::optional<py::function> leaf_predicate)217 void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
218                             std::optional<py::function> leaf_predicate) {
219   FlattenIntoImpl(handle, leaves, leaf_predicate);
220 }
221 
222 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
Flatten(py::handle x,std::optional<py::function> leaf_predicate)223 PyTreeDef::Flatten(py::handle x, std::optional<py::function> leaf_predicate) {
224   std::vector<py::object> leaves;
225   auto tree = std::make_unique<PyTreeDef>();
226   tree->FlattenInto(x, leaves, leaf_predicate);
227   return std::make_pair(std::move(leaves), std::move(tree));
228 }
229 
AllLeaves(const py::iterable & x)230 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
231   const PyTreeTypeRegistry::Registration* custom;
232   for (const py::handle& h : x) {
233     if (GetKind(h, &custom) != PyTreeKind::kLeaf) return false;
234   }
235   return true;
236 }
237 
238 template <typename T>
UnflattenImpl(T leaves) const239 py::object PyTreeDef::UnflattenImpl(T leaves) const {
240   absl::InlinedVector<py::object, 4> agenda;
241   auto it = leaves.begin();
242   int leaf_count = 0;
243   for (const Node& node : traversal_) {
244     if (agenda.size() < node.arity) {
245       throw std::logic_error("Too few elements for TreeDef node.");
246     }
247     switch (node.kind) {
248       case PyTreeKind::kLeaf:
249         if (it == leaves.end()) {
250           throw std::invalid_argument(absl::StrFormat(
251               "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(),
252               leaf_count));
253         }
254         agenda.push_back(py::reinterpret_borrow<py::object>(*it));
255         ++it;
256         ++leaf_count;
257         break;
258 
259       case PyTreeKind::kNone:
260       case PyTreeKind::kTuple:
261       case PyTreeKind::kNamedTuple:
262       case PyTreeKind::kList:
263       case PyTreeKind::kDict:
264       case PyTreeKind::kCustom: {
265         const int size = agenda.size();
266         absl::Span<py::object> span;
267         if (node.arity > 0) {
268           span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
269         }
270         py::object o = MakeNode(node, span);
271         agenda.resize(size - node.arity);
272         agenda.push_back(o);
273         break;
274       }
275     }
276   }
277   if (it != leaves.end()) {
278     throw std::invalid_argument(absl::StrFormat(
279         "Too many leaves for PyTreeDef; expected %d.", num_leaves()));
280   }
281   if (agenda.size() != 1) {
282     throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
283   }
284   return std::move(agenda.back());
285 }
286 
Unflatten(py::iterable leaves) const287 py::object PyTreeDef::Unflatten(py::iterable leaves) const {
288   return UnflattenImpl(leaves);
289 }
290 
Unflatten(absl::Span<const py::object> leaves) const291 py::object PyTreeDef::Unflatten(absl::Span<const py::object> leaves) const {
292   return UnflattenImpl(leaves);
293 }
294 
MakeNode(const PyTreeDef::Node & node,absl::Span<py::object> children)295 /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node,
296                                           absl::Span<py::object> children) {
297   if (children.size() != node.arity) {
298     throw std::logic_error("Node arity mismatch.");
299   }
300   switch (node.kind) {
301     case PyTreeKind::kLeaf:
302       throw std::logic_error("MakeNode not implemented for leaves.");
303 
304     case PyTreeKind::kNone:
305       return py::none();
306 
307     case PyTreeKind::kTuple:
308     case PyTreeKind::kNamedTuple: {
309       py::tuple tuple(node.arity);
310       for (int i = 0; i < node.arity; ++i) {
311         tuple[i] = std::move(children[i]);
312       }
313       if (node.kind == PyTreeKind::kNamedTuple) {
314         return node.node_data(*tuple);
315       } else {
316         return std::move(tuple);
317       }
318     }
319 
320     case PyTreeKind::kList: {
321       py::list list(node.arity);
322       for (int i = 0; i < node.arity; ++i) {
323         list[i] = std::move(children[i]);
324       }
325       return std::move(list);
326     }
327 
328     case PyTreeKind::kDict: {
329       py::dict dict;
330       py::list keys = py::reinterpret_borrow<py::list>(node.node_data);
331       for (int i = 0; i < node.arity; ++i) {
332         dict[keys[i]] = std::move(children[i]);
333       }
334       return std::move(dict);
335       break;
336     }
337     case PyTreeKind::kCustom: {
338       py::tuple tuple(node.arity);
339       for (int i = 0; i < node.arity; ++i) {
340         tuple[i] = std::move(children[i]);
341       }
342       return node.custom->from_iterable(node.node_data, tuple);
343     }
344   }
345   throw std::logic_error("Unreachable code.");
346 }
347 
FlattenUpTo(py::handle xs) const348 py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
349   py::list leaves(num_leaves());
350   std::vector<py::object> agenda;
351   agenda.push_back(py::reinterpret_borrow<py::object>(xs));
352   auto it = traversal_.rbegin();
353   int leaf = num_leaves() - 1;
354   while (!agenda.empty()) {
355     if (it == traversal_.rend()) {
356       throw std::invalid_argument(absl::StrFormat(
357           "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
358     }
359     const Node& node = *it;
360     py::object object = agenda.back();
361     agenda.pop_back();
362     ++it;
363 
364     switch (node.kind) {
365       case PyTreeKind::kLeaf:
366         if (leaf < 0) {
367           throw std::logic_error("Leaf count mismatch.");
368         }
369         leaves[leaf] = py::reinterpret_borrow<py::object>(object);
370         --leaf;
371         break;
372 
373       case PyTreeKind::kNone:
374         break;
375 
376       case PyTreeKind::kTuple: {
377         if (!PyTuple_CheckExact(object.ptr())) {
378           throw std::invalid_argument(
379               absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
380         }
381         py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
382         if (tuple.size() != node.arity) {
383           throw std::invalid_argument(
384               absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.",
385                               tuple.size(), node.arity, py::repr(object)));
386         }
387         for (py::handle entry : tuple) {
388           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
389         }
390         break;
391       }
392 
393       case PyTreeKind::kList: {
394         if (!PyList_CheckExact(object.ptr())) {
395           throw std::invalid_argument(
396               absl::StrFormat("Expected list, got %s.", py::repr(object)));
397         }
398         py::list list = py::reinterpret_borrow<py::list>(object);
399         if (list.size() != node.arity) {
400           throw std::invalid_argument(
401               absl::StrFormat("List arity mismatch: %d != %d; list: %s.",
402                               list.size(), node.arity, py::repr(object)));
403         }
404         for (py::handle entry : list) {
405           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
406         }
407         break;
408       }
409 
410       case PyTreeKind::kDict: {
411         if (!PyDict_CheckExact(object.ptr())) {
412           throw std::invalid_argument(
413               absl::StrFormat("Expected dict, got %s.", py::repr(object)));
414         }
415         py::dict dict = py::reinterpret_borrow<py::dict>(object);
416         py::list keys =
417             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
418         if (PyList_Sort(keys.ptr())) {
419           throw xla::XlaRuntimeError("Dictionary key sort failed.");
420         }
421         if (keys.not_equal(node.node_data)) {
422           throw std::invalid_argument(
423               absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.",
424                               py::repr(node.node_data), py::repr(object)));
425         }
426         for (py::handle key : keys) {
427           agenda.push_back(dict[key]);
428         }
429         break;
430       }
431 
432       case PyTreeKind::kNamedTuple: {
433         if (!py::isinstance<py::tuple>(object) ||
434             !py::hasattr(object, "_fields")) {
435           throw std::invalid_argument(absl::StrFormat(
436               "Expected named tuple, got %s.", py::repr(object)));
437         }
438         py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
439         if (tuple.size() != node.arity) {
440           throw std::invalid_argument(absl::StrFormat(
441               "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(),
442               node.arity, py::repr(object)));
443         }
444         if (tuple.get_type().not_equal(node.node_data)) {
445           throw std::invalid_argument(absl::StrFormat(
446               "Named tuple type mismatch: expected type: %s, tuple: %s.",
447               py::repr(node.node_data), py::repr(object)));
448         }
449         for (py::handle entry : tuple) {
450           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
451         }
452         break;
453       }
454 
455       case PyTreeKind::kCustom: {
456         auto* registration = PyTreeTypeRegistry::Lookup(object.get_type());
457         if (registration != node.custom) {
458           throw std::invalid_argument(absl::StrFormat(
459               "Custom node type mismatch: expected type: %s, value: %s.",
460               py::repr(node.custom->type), py::repr(object)));
461         }
462         py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
463         if (out.size() != 2) {
464           throw xla::XlaRuntimeError(
465               "PyTree custom to_iterable function should return a pair");
466         }
467         if (node.node_data.not_equal(out[1])) {
468           throw std::invalid_argument(absl::StrFormat(
469               "Mismatch custom node data: %s != %s; value: %s.",
470               py::repr(node.node_data), py::repr(out[1]), py::repr(object)));
471         }
472         int arity = 0;
473         for (py::handle entry : py::cast<py::iterable>(out[0])) {
474           ++arity;
475           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
476         }
477         if (arity != node.arity) {
478           throw std::invalid_argument(absl::StrFormat(
479               "Custom type arity mismatch: %d != %d; value: %s.", arity,
480               node.arity, py::repr(object)));
481         }
482         break;
483       }
484     }
485   }
486   if (it != traversal_.rend() || leaf != -1) {
487     throw std::invalid_argument(absl::StrFormat(
488         "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
489   }
490   return leaves;
491 }
492 
py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
                           py::iterable leaves) const {
  // Post-order fold over the tree. Each leaf taken from `leaves` is pushed,
  // mapped through f_leaf unless f_leaf is None. Each interior node pops its
  // `arity` results into a tuple and pushes f_node(children, node_data), with
  // None standing in for absent node data.
  std::vector<py::object> stack;
  auto next = leaves.begin();
  for (const Node& node : traversal_) {
    switch (node.kind) {
      case PyTreeKind::kLeaf: {
        if (next == leaves.end()) {
          throw std::invalid_argument("Too few leaves for PyTreeDef");
        }
        py::object value = py::reinterpret_borrow<py::object>(*next);
        if (!f_leaf.is_none()) {
          value = f_leaf(std::move(value));
        }
        stack.push_back(std::move(value));
        ++next;
        break;
      }

      case PyTreeKind::kNone:
      case PyTreeKind::kTuple:
      case PyTreeKind::kNamedTuple:
      case PyTreeKind::kList:
      case PyTreeKind::kDict:
      case PyTreeKind::kCustom: {
        if (stack.size() < node.arity) {
          throw std::logic_error("Too few elements for custom type.");
        }
        const auto first_child = stack.end() - node.arity;
        py::tuple children(node.arity);
        for (int i = 0; i < node.arity; ++i) {
          children[i] = first_child[i];
        }
        stack.erase(first_child, stack.end());
        stack.push_back(
            f_node(children, node.node_data ? node.node_data : py::none()));
        break;
      }
    }
  }
  if (next != leaves.end()) {
    throw std::invalid_argument("Too many leaves for PyTreeDef");
  }
  if (stack.size() != 1) {
    throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
  }
  return std::move(stack.back());
}
538 
FromIterableTreeHelper(py::handle xs,absl::InlinedVector<PyTreeDef::Node,1>::const_reverse_iterator * it) const539 py::object PyTreeDef::FromIterableTreeHelper(
540     py::handle xs,
541     absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it) const {
542   if (*it == traversal_.rend()) {
543     throw std::invalid_argument("Tree structures did not match.");
544   }
545   const Node& node = **it;
546   ++*it;
547   if (node.kind == PyTreeKind::kLeaf) {
548     return py::reinterpret_borrow<py::object>(xs);
549   }
550   py::iterable iterable = py::reinterpret_borrow<py::iterable>(xs);
551   std::vector<py::object> ys;
552   ys.reserve(node.arity);
553   for (py::handle x : iterable) {
554     ys.push_back(py::reinterpret_borrow<py::object>(x));
555   }
556   if (ys.size() != node.arity) {
557     throw std::invalid_argument("Arity mismatch between trees");
558   }
559   for (int j = node.arity - 1; j >= 0; --j) {
560     ys[j] = FromIterableTreeHelper(ys[j], it);
561   }
562 
563   return MakeNode(node, absl::MakeSpan(ys));
564 }
565 
FromIterableTree(py::handle xs) const566 py::object PyTreeDef::FromIterableTree(py::handle xs) const {
567   auto it = traversal_.rbegin();
568   py::object out = FromIterableTreeHelper(xs, &it);
569   if (it != traversal_.rend()) {
570     throw std::invalid_argument("Tree structures did not match.");
571   }
572   return out;
573 }
574 
Compose(const PyTreeDef & inner) const575 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
576   auto out = std::make_unique<PyTreeDef>();
577   for (const Node& n : traversal_) {
578     if (n.kind == PyTreeKind::kLeaf) {
579       absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
580     } else {
581       out->traversal_.push_back(n);
582     }
583   }
584   const auto& root = traversal_.back();
585   const auto& inner_root = inner.traversal_.back();
586   // TODO(tomhennigan): This should update all nodes in the traversal.
587   auto& out_root = out->traversal_.back();
588   out_root.num_nodes = (root.num_nodes - root.num_leaves) +
589                        (inner_root.num_nodes * root.num_leaves);
590   out_root.num_leaves *= inner_root.num_leaves;
591   return out;
592 }
593 
Tuple(const std::vector<PyTreeDef> & defs)594 /*static*/ std::unique_ptr<PyTreeDef> PyTreeDef::Tuple(
595     const std::vector<PyTreeDef>& defs) {
596   auto out = std::make_unique<PyTreeDef>();
597   int num_leaves = 0;
598   for (const PyTreeDef& def : defs) {
599     absl::c_copy(def.traversal_, std::back_inserter(out->traversal_));
600     num_leaves += def.num_leaves();
601   }
602   Node node;
603   node.kind = PyTreeKind::kTuple;
604   node.arity = defs.size();
605   node.num_leaves = num_leaves;
606   node.num_nodes = out->traversal_.size() + 1;
607   out->traversal_.push_back(node);
608   return out;
609 }
610 
Children() const611 std::vector<std::unique_ptr<PyTreeDef>> PyTreeDef::Children() const {
612   std::vector<std::unique_ptr<PyTreeDef>> children;
613   if (traversal_.empty()) {
614     return children;
615   }
616   Node const& root = traversal_.back();
617   children.resize(root.arity);
618   int pos = traversal_.size() - 1;
619   for (int i = root.arity - 1; i >= 0; --i) {
620     children[i] = std::make_unique<PyTreeDef>();
621     const Node& node = traversal_.at(pos - 1);
622     if (pos < node.num_nodes) {
623       throw std::logic_error("children() walked off start of array");
624     }
625     std::copy(traversal_.begin() + pos - node.num_nodes,
626               traversal_.begin() + pos,
627               std::back_inserter(children[i]->traversal_));
628     pos -= node.num_nodes;
629   }
630   if (pos != 0) {
631     throw std::logic_error("pos != 0 at end of PyTreeDef::Children");
632   }
633   return children;
634 }
635 
ToString() const636 std::string PyTreeDef::ToString() const {
637   std::vector<std::string> agenda;
638   for (const Node& node : traversal_) {
639     if (agenda.size() < node.arity) {
640       throw std::logic_error("Too few elements for container.");
641     }
642 
643     std::string children =
644         absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", ");
645     std::string representation;
646     switch (node.kind) {
647       case PyTreeKind::kLeaf:
648         agenda.push_back("*");
649         continue;
650       case PyTreeKind::kNone:
651         representation = "None";
652         break;
653       case PyTreeKind::kTuple:
654         // Tuples with only one element must have a trailing comma.
655         if (node.arity == 1) children += ",";
656         representation = absl::StrCat("(", children, ")");
657         break;
658       case PyTreeKind::kList:
659         representation = absl::StrCat("[", children, "]");
660         break;
661       case PyTreeKind::kDict: {
662         if (py::len(node.node_data) != node.arity) {
663           throw std::logic_error("Number of keys and entries does not match.");
664         }
665         representation = "{";
666         std::string separator;
667         auto child_iter = agenda.end() - node.arity;
668         for (const py::handle& key : node.node_data) {
669           absl::StrAppendFormat(&representation, "%s%s: %s", separator,
670                                 py::repr(key), *child_iter);
671           child_iter++;
672           separator = ", ";
673         }
674         representation += "}";
675         break;
676       }
677 
678       case PyTreeKind::kNamedTuple:
679       case PyTreeKind::kCustom: {
680         std::string kind;
681         std::string data;
682         if (node.kind == PyTreeKind::kNamedTuple) {
683           kind = "namedtuple";
684           if (node.node_data) {
685             // Node data for named tuples is the type.
686             data = absl::StrFormat(
687                 "[%s]", py::str(py::getattr(node.node_data, "__name__")));
688           }
689         } else {
690           kind = static_cast<std::string>(
691               py::str(py::getattr(node.custom->type, "__name__")));
692           if (node.node_data) {
693             data = absl::StrFormat("[%s]", py::str(node.node_data));
694           }
695         }
696 
697         representation =
698             absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children);
699         break;
700       }
701     }
702     agenda.erase(agenda.end() - node.arity, agenda.end());
703     agenda.push_back(std::move(representation));
704   }
705   if (agenda.size() != 1) {
706     throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
707   }
708   return absl::StrCat("PyTreeDef(", agenda.back(), ")");
709 }
710 
ToPickleable() const711 py::object PyTreeDef::ToPickleable() const {
712   py::list traversal;
713   for (const auto& node : traversal_) {
714     traversal.append(
715         py::make_tuple(static_cast<int>(node.kind), node.arity,
716                        node.node_data ? node.node_data : py::none(),
717                        node.custom != nullptr ? node.custom->type : py::none(),
718                        node.num_leaves, node.num_nodes));
719   }
720   return traversal;
721 }
722 
FromPickleable(py::object pickleable)723 PyTreeDef PyTreeDef::FromPickleable(py::object pickleable) {
724   PyTreeDef tree;
725   for (const auto& item : pickleable.cast<py::list>()) {
726     auto t = item.cast<py::tuple>();
727     if (t.size() != 6) {
728       throw xla::XlaRuntimeError("Malformed pickled PyTreeDef");
729     }
730     Node& node = tree.traversal_.emplace_back();
731     node.kind = static_cast<PyTreeKind>(t[0].cast<int>());
732     node.arity = t[1].cast<int>();
733     switch (node.kind) {
734       case PyTreeKind::kNamedTuple:
735         node.node_data = t[2].cast<py::type>();
736         break;
737       case PyTreeKind::kDict:
738         node.node_data = t[2].cast<py::list>();
739         break;
740       case PyTreeKind::kCustom:
741         node.node_data = t[2];
742         break;
743       default:
744         if (!t[2].is_none()) {
745           throw xla::XlaRuntimeError("Malformed pickled PyTreeDef");
746         }
747         break;
748     }
749     if (node.kind == PyTreeKind::kCustom) {
750       node.custom = t[3].is_none() ? nullptr : PyTreeTypeRegistry::Lookup(t[3]);
751       if (node.custom == nullptr) {
752         throw xla::XlaRuntimeError(
753             absl::StrCat("Unknown custom type in pickled PyTreeDef: ",
754                          static_cast<std::string>(py::repr(t[3]))));
755       }
756     } else {
757       if (!t[3].is_none()) {
758         throw xla::XlaRuntimeError("Malformed pickled PyTreeDef");
759       }
760     }
761     node.num_leaves = t[4].cast<int>();
762     node.num_nodes = t[5].cast<int>();
763   }
764   return tree;
765 }
766 
BuildPytreeSubmodule(py::module & m)767 void BuildPytreeSubmodule(py::module& m) {
768   py::module pytree = m.def_submodule("pytree", "Python tree library");
769   pytree.attr("version") = py::int_(3);
770   pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
771              py::arg("leaf_predicate") = std::nullopt);
772   pytree.def("tuple", &PyTreeDef::Tuple);
773   pytree.def("all_leaves", &PyTreeDef::AllLeaves);
774 
775   py::class_<PyTreeDef>(pytree, "PyTreeDef")
776       .def("unflatten",
777            static_cast<pybind11::object (PyTreeDef::*)(
778                pybind11::iterable leaves) const>(&PyTreeDef::Unflatten))
779       .def("flatten_up_to", &PyTreeDef::FlattenUpTo)
780       .def("compose", &PyTreeDef::Compose)
781       .def("walk", &PyTreeDef::Walk,
782            "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf "
783            "at leaves",
784            py::arg("f_node"), py::arg("f_leaf"), py::arg("leaves"))
785       .def("from_iterable_tree", &PyTreeDef::FromIterableTree)
786       .def("children", &PyTreeDef::Children)
787       .def_property_readonly("num_leaves", &PyTreeDef::num_leaves)
788       .def_property_readonly("num_nodes", &PyTreeDef::num_nodes)
789       .def("__repr__", &PyTreeDef::ToString)
790       .def("__eq__",
791            [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; })
792       .def("__ne__",
793            [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; })
794       .def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); })
795       .def(py::pickle(
796           [](const PyTreeDef& t) { return t.ToPickleable(); },
797           [](py::object o) { return PyTreeDef::FromPickleable(o); }));
798 
799   pytree.def("register_node", [](py::object type, py::function to_iterable,
800                                  py::function from_iterable) {
801     return PyTreeTypeRegistry::Register(type, to_iterable, from_iterable);
802   });
803 }
804 
805 }  // namespace xla
806