1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Caution: this code uses exceptions. The exception use is local to the
17 // binding code and the idiomatic way to emit Python exceptions.
18
19 #include "tensorflow/compiler/xla/python/pytree.h"
20
21 #include <algorithm>
22 #include <memory>
23 #include <stdexcept>
24 #include <string>
25 #include <utility>
26 #include <vector>
27
28 #include "absl/algorithm/container.h"
29 #include "absl/container/flat_hash_map.h"
30 #include "absl/hash/hash.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_format.h"
33 #include "absl/strings/str_join.h"
34 #include "pybind11/pybind11.h"
35 #include "pybind11/pytypes.h"
36 #include "pybind11/stl.h"
37 #include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil
38 #include "tensorflow/compiler/xla/python/exceptions.h"
39 #include "tensorflow/core/platform/logging.h"
40
41 namespace xla {
42
43 namespace py = pybind11;
44
Singleton()45 /*static*/ PyTreeTypeRegistry* PyTreeTypeRegistry::Singleton() {
46 static auto* registry = []() -> PyTreeTypeRegistry* {
47 auto* registry = new PyTreeTypeRegistry;
48
49 auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) {
50 py::object type = py::reinterpret_borrow<py::object>(
51 reinterpret_cast<PyObject*>(type_obj));
52 auto registration = std::make_unique<Registration>();
53 registration->kind = kind;
54 registration->type = type;
55 CHECK(registry->registrations_.emplace(type, std::move(registration))
56 .second);
57 };
58 add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone);
59 add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple);
60 add_builtin_type(&PyList_Type, PyTreeKind::kList);
61 add_builtin_type(&PyDict_Type, PyTreeKind::kDict);
62 return registry;
63 }();
64 return registry;
65 }
66
Register(py::object type,py::function to_iterable,py::function from_iterable)67 /*static*/ void PyTreeTypeRegistry::Register(py::object type,
68 py::function to_iterable,
69 py::function from_iterable) {
70 PyTreeTypeRegistry* registry = Singleton();
71 auto registration = std::make_unique<Registration>();
72 registration->kind = PyTreeKind::kCustom;
73 registration->type = type;
74 registration->to_iterable = std::move(to_iterable);
75 registration->from_iterable = std::move(from_iterable);
76 auto it = registry->registrations_.emplace(type, std::move(registration));
77 if (!it.second) {
78 throw std::invalid_argument(
79 absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.",
80 py::repr(type)));
81 }
82 }
83
Lookup(py::handle type)84 /*static*/ const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup(
85 py::handle type) {
86 PyTreeTypeRegistry* registry = Singleton();
87 auto it = registry->registrations_.find(type);
88 return it == registry->registrations_.end() ? nullptr : it->second.get();
89 }
90
operator ==(const PyTreeDef & other) const91 bool PyTreeDef::operator==(const PyTreeDef& other) const {
92 if (traversal_.size() != other.traversal_.size()) {
93 return false;
94 }
95 for (size_t i = 0; i < traversal_.size(); ++i) {
96 const Node& a = traversal_[i];
97 const Node& b = other.traversal_[i];
98 if (a.kind != b.kind || a.arity != b.arity ||
99 (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) ||
100 a.custom != b.custom) {
101 return false;
102 }
103 if (a.node_data && a.node_data.not_equal(b.node_data)) {
104 return false;
105 }
106 // We don't need to test equality of num_leaves and num_nodes since they
107 // are derivable from the other node data.
108 }
109 return true;
110 }
111
GetKind(const py::handle & obj,PyTreeTypeRegistry::Registration const ** custom)112 /*static*/ PyTreeKind PyTreeDef::GetKind(
113 const py::handle& obj, PyTreeTypeRegistry::Registration const** custom) {
114 const PyTreeTypeRegistry::Registration* registration =
115 PyTreeTypeRegistry::Lookup(obj.get_type());
116 if (registration) {
117 if (registration->kind == PyTreeKind::kCustom) {
118 *custom = registration;
119 } else {
120 *custom = nullptr;
121 }
122 return registration->kind;
123 } else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
124 // We can only identify namedtuples heuristically, here by the presence of
125 // a _fields attribute.
126 return PyTreeKind::kNamedTuple;
127 } else {
128 return PyTreeKind::kLeaf;
129 }
130 }
131
132 template <typename T>
FlattenIntoImpl(py::handle handle,T & leaves,const std::optional<py::function> & leaf_predicate)133 void PyTreeDef::FlattenIntoImpl(
134 py::handle handle, T& leaves,
135 const std::optional<py::function>& leaf_predicate) {
136 Node node;
137 int start_num_nodes = traversal_.size();
138 int start_num_leaves = leaves.size();
139 if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
140 leaves.push_back(py::reinterpret_borrow<py::object>(handle));
141 } else {
142 node.kind = GetKind(handle, &node.custom);
143 auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
144 FlattenInto(child, leaves, leaf_predicate);
145 };
146 switch (node.kind) {
147 case PyTreeKind::kNone:
148 // Nothing to do.
149 break;
150 case PyTreeKind::kTuple: {
151 node.arity = PyTuple_GET_SIZE(handle.ptr());
152 for (int i = 0; i < node.arity; ++i) {
153 recurse(PyTuple_GET_ITEM(handle.ptr(), i));
154 }
155 break;
156 }
157 case PyTreeKind::kList: {
158 node.arity = PyList_GET_SIZE(handle.ptr());
159 for (int i = 0; i < node.arity; ++i) {
160 recurse(PyList_GET_ITEM(handle.ptr(), i));
161 }
162 break;
163 }
164 case PyTreeKind::kDict: {
165 py::dict dict = py::reinterpret_borrow<py::dict>(handle);
166 py::list keys =
167 py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
168 if (PyList_Sort(keys.ptr())) {
169 throw py::error_already_set();
170 }
171 for (py::handle key : keys) {
172 recurse(dict[key]);
173 }
174 node.arity = dict.size();
175 node.node_data = std::move(keys);
176 break;
177 }
178 case PyTreeKind::kCustom: {
179 py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
180 if (out.size() != 2) {
181 throw xla::XlaRuntimeError(
182 "PyTree custom to_iterable function should return a pair");
183 }
184 node.node_data = out[1];
185 node.arity = 0;
186 for (py::handle entry : py::cast<py::iterable>(out[0])) {
187 ++node.arity;
188 recurse(entry);
189 }
190 break;
191 }
192 case PyTreeKind::kNamedTuple: {
193 py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
194 node.arity = tuple.size();
195 node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
196 for (py::handle entry : tuple) {
197 recurse(entry);
198 }
199 break;
200 }
201 default:
202 DCHECK(node.kind == PyTreeKind::kLeaf);
203 leaves.push_back(py::reinterpret_borrow<py::object>(handle));
204 }
205 }
206 node.num_nodes = traversal_.size() - start_num_nodes + 1;
207 node.num_leaves = leaves.size() - start_num_leaves;
208 traversal_.push_back(std::move(node));
209 }
210
FlattenInto(py::handle handle,absl::InlinedVector<py::object,2> & leaves,std::optional<py::function> leaf_predicate)211 void PyTreeDef::FlattenInto(py::handle handle,
212 absl::InlinedVector<py::object, 2>& leaves,
213 std::optional<py::function> leaf_predicate) {
214 FlattenIntoImpl(handle, leaves, leaf_predicate);
215 }
216
FlattenInto(py::handle handle,std::vector<py::object> & leaves,std::optional<py::function> leaf_predicate)217 void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
218 std::optional<py::function> leaf_predicate) {
219 FlattenIntoImpl(handle, leaves, leaf_predicate);
220 }
221
222 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
Flatten(py::handle x,std::optional<py::function> leaf_predicate)223 PyTreeDef::Flatten(py::handle x, std::optional<py::function> leaf_predicate) {
224 std::vector<py::object> leaves;
225 auto tree = std::make_unique<PyTreeDef>();
226 tree->FlattenInto(x, leaves, leaf_predicate);
227 return std::make_pair(std::move(leaves), std::move(tree));
228 }
229
AllLeaves(const py::iterable & x)230 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
231 const PyTreeTypeRegistry::Registration* custom;
232 for (const py::handle& h : x) {
233 if (GetKind(h, &custom) != PyTreeKind::kLeaf) return false;
234 }
235 return true;
236 }
237
238 template <typename T>
UnflattenImpl(T leaves) const239 py::object PyTreeDef::UnflattenImpl(T leaves) const {
240 absl::InlinedVector<py::object, 4> agenda;
241 auto it = leaves.begin();
242 int leaf_count = 0;
243 for (const Node& node : traversal_) {
244 if (agenda.size() < node.arity) {
245 throw std::logic_error("Too few elements for TreeDef node.");
246 }
247 switch (node.kind) {
248 case PyTreeKind::kLeaf:
249 if (it == leaves.end()) {
250 throw std::invalid_argument(absl::StrFormat(
251 "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(),
252 leaf_count));
253 }
254 agenda.push_back(py::reinterpret_borrow<py::object>(*it));
255 ++it;
256 ++leaf_count;
257 break;
258
259 case PyTreeKind::kNone:
260 case PyTreeKind::kTuple:
261 case PyTreeKind::kNamedTuple:
262 case PyTreeKind::kList:
263 case PyTreeKind::kDict:
264 case PyTreeKind::kCustom: {
265 const int size = agenda.size();
266 absl::Span<py::object> span;
267 if (node.arity > 0) {
268 span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
269 }
270 py::object o = MakeNode(node, span);
271 agenda.resize(size - node.arity);
272 agenda.push_back(o);
273 break;
274 }
275 }
276 }
277 if (it != leaves.end()) {
278 throw std::invalid_argument(absl::StrFormat(
279 "Too many leaves for PyTreeDef; expected %d.", num_leaves()));
280 }
281 if (agenda.size() != 1) {
282 throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
283 }
284 return std::move(agenda.back());
285 }
286
Unflatten(py::iterable leaves) const287 py::object PyTreeDef::Unflatten(py::iterable leaves) const {
288 return UnflattenImpl(leaves);
289 }
290
Unflatten(absl::Span<const py::object> leaves) const291 py::object PyTreeDef::Unflatten(absl::Span<const py::object> leaves) const {
292 return UnflattenImpl(leaves);
293 }
294
MakeNode(const PyTreeDef::Node & node,absl::Span<py::object> children)295 /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node,
296 absl::Span<py::object> children) {
297 if (children.size() != node.arity) {
298 throw std::logic_error("Node arity mismatch.");
299 }
300 switch (node.kind) {
301 case PyTreeKind::kLeaf:
302 throw std::logic_error("MakeNode not implemented for leaves.");
303
304 case PyTreeKind::kNone:
305 return py::none();
306
307 case PyTreeKind::kTuple:
308 case PyTreeKind::kNamedTuple: {
309 py::tuple tuple(node.arity);
310 for (int i = 0; i < node.arity; ++i) {
311 tuple[i] = std::move(children[i]);
312 }
313 if (node.kind == PyTreeKind::kNamedTuple) {
314 return node.node_data(*tuple);
315 } else {
316 return std::move(tuple);
317 }
318 }
319
320 case PyTreeKind::kList: {
321 py::list list(node.arity);
322 for (int i = 0; i < node.arity; ++i) {
323 list[i] = std::move(children[i]);
324 }
325 return std::move(list);
326 }
327
328 case PyTreeKind::kDict: {
329 py::dict dict;
330 py::list keys = py::reinterpret_borrow<py::list>(node.node_data);
331 for (int i = 0; i < node.arity; ++i) {
332 dict[keys[i]] = std::move(children[i]);
333 }
334 return std::move(dict);
335 break;
336 }
337 case PyTreeKind::kCustom: {
338 py::tuple tuple(node.arity);
339 for (int i = 0; i < node.arity; ++i) {
340 tuple[i] = std::move(children[i]);
341 }
342 return node.custom->from_iterable(node.node_data, tuple);
343 }
344 }
345 throw std::logic_error("Unreachable code.");
346 }
347
FlattenUpTo(py::handle xs) const348 py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
349 py::list leaves(num_leaves());
350 std::vector<py::object> agenda;
351 agenda.push_back(py::reinterpret_borrow<py::object>(xs));
352 auto it = traversal_.rbegin();
353 int leaf = num_leaves() - 1;
354 while (!agenda.empty()) {
355 if (it == traversal_.rend()) {
356 throw std::invalid_argument(absl::StrFormat(
357 "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
358 }
359 const Node& node = *it;
360 py::object object = agenda.back();
361 agenda.pop_back();
362 ++it;
363
364 switch (node.kind) {
365 case PyTreeKind::kLeaf:
366 if (leaf < 0) {
367 throw std::logic_error("Leaf count mismatch.");
368 }
369 leaves[leaf] = py::reinterpret_borrow<py::object>(object);
370 --leaf;
371 break;
372
373 case PyTreeKind::kNone:
374 break;
375
376 case PyTreeKind::kTuple: {
377 if (!PyTuple_CheckExact(object.ptr())) {
378 throw std::invalid_argument(
379 absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
380 }
381 py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
382 if (tuple.size() != node.arity) {
383 throw std::invalid_argument(
384 absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.",
385 tuple.size(), node.arity, py::repr(object)));
386 }
387 for (py::handle entry : tuple) {
388 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
389 }
390 break;
391 }
392
393 case PyTreeKind::kList: {
394 if (!PyList_CheckExact(object.ptr())) {
395 throw std::invalid_argument(
396 absl::StrFormat("Expected list, got %s.", py::repr(object)));
397 }
398 py::list list = py::reinterpret_borrow<py::list>(object);
399 if (list.size() != node.arity) {
400 throw std::invalid_argument(
401 absl::StrFormat("List arity mismatch: %d != %d; list: %s.",
402 list.size(), node.arity, py::repr(object)));
403 }
404 for (py::handle entry : list) {
405 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
406 }
407 break;
408 }
409
410 case PyTreeKind::kDict: {
411 if (!PyDict_CheckExact(object.ptr())) {
412 throw std::invalid_argument(
413 absl::StrFormat("Expected dict, got %s.", py::repr(object)));
414 }
415 py::dict dict = py::reinterpret_borrow<py::dict>(object);
416 py::list keys =
417 py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
418 if (PyList_Sort(keys.ptr())) {
419 throw xla::XlaRuntimeError("Dictionary key sort failed.");
420 }
421 if (keys.not_equal(node.node_data)) {
422 throw std::invalid_argument(
423 absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.",
424 py::repr(node.node_data), py::repr(object)));
425 }
426 for (py::handle key : keys) {
427 agenda.push_back(dict[key]);
428 }
429 break;
430 }
431
432 case PyTreeKind::kNamedTuple: {
433 if (!py::isinstance<py::tuple>(object) ||
434 !py::hasattr(object, "_fields")) {
435 throw std::invalid_argument(absl::StrFormat(
436 "Expected named tuple, got %s.", py::repr(object)));
437 }
438 py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
439 if (tuple.size() != node.arity) {
440 throw std::invalid_argument(absl::StrFormat(
441 "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(),
442 node.arity, py::repr(object)));
443 }
444 if (tuple.get_type().not_equal(node.node_data)) {
445 throw std::invalid_argument(absl::StrFormat(
446 "Named tuple type mismatch: expected type: %s, tuple: %s.",
447 py::repr(node.node_data), py::repr(object)));
448 }
449 for (py::handle entry : tuple) {
450 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
451 }
452 break;
453 }
454
455 case PyTreeKind::kCustom: {
456 auto* registration = PyTreeTypeRegistry::Lookup(object.get_type());
457 if (registration != node.custom) {
458 throw std::invalid_argument(absl::StrFormat(
459 "Custom node type mismatch: expected type: %s, value: %s.",
460 py::repr(node.custom->type), py::repr(object)));
461 }
462 py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
463 if (out.size() != 2) {
464 throw xla::XlaRuntimeError(
465 "PyTree custom to_iterable function should return a pair");
466 }
467 if (node.node_data.not_equal(out[1])) {
468 throw std::invalid_argument(absl::StrFormat(
469 "Mismatch custom node data: %s != %s; value: %s.",
470 py::repr(node.node_data), py::repr(out[1]), py::repr(object)));
471 }
472 int arity = 0;
473 for (py::handle entry : py::cast<py::iterable>(out[0])) {
474 ++arity;
475 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
476 }
477 if (arity != node.arity) {
478 throw std::invalid_argument(absl::StrFormat(
479 "Custom type arity mismatch: %d != %d; value: %s.", arity,
480 node.arity, py::repr(object)));
481 }
482 break;
483 }
484 }
485 }
486 if (it != traversal_.rend() || leaf != -1) {
487 throw std::invalid_argument(absl::StrFormat(
488 "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
489 }
490 return leaves;
491 }
492
// Folds over the tree bottom-up.
//
// For each leaf node the next element of `leaves` is taken and, unless
// `f_leaf` is None, replaced by `f_leaf(leaf)`. For every other node,
// `f_node(children, node_data)` is called, where `children` is a tuple of the
// already computed child results and `node_data` is None when the node
// carries no data. Returns the result computed for the root.
//
// Throws std::invalid_argument when `leaves` has too few or too many
// elements, and std::logic_error when the traversal is inconsistent.
py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
                           py::iterable leaves) const {
  std::vector<py::object> stack;
  auto next_leaf = leaves.begin();
  const bool transform_leaves = !f_leaf.is_none();

  for (const Node& node : traversal_) {
    switch (node.kind) {
      case PyTreeKind::kLeaf: {
        if (next_leaf == leaves.end()) {
          throw std::invalid_argument("Too few leaves for PyTreeDef");
        }
        py::object leaf = py::reinterpret_borrow<py::object>(*next_leaf);
        if (transform_leaves) {
          stack.push_back(f_leaf(std::move(leaf)));
        } else {
          stack.push_back(std::move(leaf));
        }
        ++next_leaf;
        break;
      }

      case PyTreeKind::kNone:
      case PyTreeKind::kTuple:
      case PyTreeKind::kNamedTuple:
      case PyTreeKind::kList:
      case PyTreeKind::kDict:
      case PyTreeKind::kCustom: {
        if (stack.size() < node.arity) {
          throw std::logic_error("Too few elements for custom type.");
        }
        // Pop the children off the stack, restoring left-to-right order.
        py::tuple child_results(node.arity);
        for (int slot = node.arity - 1; slot >= 0; --slot) {
          child_results[slot] = stack.back();
          stack.pop_back();
        }
        py::object data = py::none();
        if (node.node_data) {
          data = node.node_data;
        }
        stack.push_back(f_node(child_results, data));
        break;
      }
    }
  }
  if (next_leaf != leaves.end()) {
    throw std::invalid_argument("Too many leaves for PyTreeDef");
  }
  if (stack.size() != 1) {
    throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
  }
  return std::move(stack.back());
}
538
FromIterableTreeHelper(py::handle xs,absl::InlinedVector<PyTreeDef::Node,1>::const_reverse_iterator * it) const539 py::object PyTreeDef::FromIterableTreeHelper(
540 py::handle xs,
541 absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it) const {
542 if (*it == traversal_.rend()) {
543 throw std::invalid_argument("Tree structures did not match.");
544 }
545 const Node& node = **it;
546 ++*it;
547 if (node.kind == PyTreeKind::kLeaf) {
548 return py::reinterpret_borrow<py::object>(xs);
549 }
550 py::iterable iterable = py::reinterpret_borrow<py::iterable>(xs);
551 std::vector<py::object> ys;
552 ys.reserve(node.arity);
553 for (py::handle x : iterable) {
554 ys.push_back(py::reinterpret_borrow<py::object>(x));
555 }
556 if (ys.size() != node.arity) {
557 throw std::invalid_argument("Arity mismatch between trees");
558 }
559 for (int j = node.arity - 1; j >= 0; --j) {
560 ys[j] = FromIterableTreeHelper(ys[j], it);
561 }
562
563 return MakeNode(node, absl::MakeSpan(ys));
564 }
565
FromIterableTree(py::handle xs) const566 py::object PyTreeDef::FromIterableTree(py::handle xs) const {
567 auto it = traversal_.rbegin();
568 py::object out = FromIterableTreeHelper(xs, &it);
569 if (it != traversal_.rend()) {
570 throw std::invalid_argument("Tree structures did not match.");
571 }
572 return out;
573 }
574
Compose(const PyTreeDef & inner) const575 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
576 auto out = std::make_unique<PyTreeDef>();
577 for (const Node& n : traversal_) {
578 if (n.kind == PyTreeKind::kLeaf) {
579 absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
580 } else {
581 out->traversal_.push_back(n);
582 }
583 }
584 const auto& root = traversal_.back();
585 const auto& inner_root = inner.traversal_.back();
586 // TODO(tomhennigan): This should update all nodes in the traversal.
587 auto& out_root = out->traversal_.back();
588 out_root.num_nodes = (root.num_nodes - root.num_leaves) +
589 (inner_root.num_nodes * root.num_leaves);
590 out_root.num_leaves *= inner_root.num_leaves;
591 return out;
592 }
593
Tuple(const std::vector<PyTreeDef> & defs)594 /*static*/ std::unique_ptr<PyTreeDef> PyTreeDef::Tuple(
595 const std::vector<PyTreeDef>& defs) {
596 auto out = std::make_unique<PyTreeDef>();
597 int num_leaves = 0;
598 for (const PyTreeDef& def : defs) {
599 absl::c_copy(def.traversal_, std::back_inserter(out->traversal_));
600 num_leaves += def.num_leaves();
601 }
602 Node node;
603 node.kind = PyTreeKind::kTuple;
604 node.arity = defs.size();
605 node.num_leaves = num_leaves;
606 node.num_nodes = out->traversal_.size() + 1;
607 out->traversal_.push_back(node);
608 return out;
609 }
610
Children() const611 std::vector<std::unique_ptr<PyTreeDef>> PyTreeDef::Children() const {
612 std::vector<std::unique_ptr<PyTreeDef>> children;
613 if (traversal_.empty()) {
614 return children;
615 }
616 Node const& root = traversal_.back();
617 children.resize(root.arity);
618 int pos = traversal_.size() - 1;
619 for (int i = root.arity - 1; i >= 0; --i) {
620 children[i] = std::make_unique<PyTreeDef>();
621 const Node& node = traversal_.at(pos - 1);
622 if (pos < node.num_nodes) {
623 throw std::logic_error("children() walked off start of array");
624 }
625 std::copy(traversal_.begin() + pos - node.num_nodes,
626 traversal_.begin() + pos,
627 std::back_inserter(children[i]->traversal_));
628 pos -= node.num_nodes;
629 }
630 if (pos != 0) {
631 throw std::logic_error("pos != 0 at end of PyTreeDef::Children");
632 }
633 return children;
634 }
635
ToString() const636 std::string PyTreeDef::ToString() const {
637 std::vector<std::string> agenda;
638 for (const Node& node : traversal_) {
639 if (agenda.size() < node.arity) {
640 throw std::logic_error("Too few elements for container.");
641 }
642
643 std::string children =
644 absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", ");
645 std::string representation;
646 switch (node.kind) {
647 case PyTreeKind::kLeaf:
648 agenda.push_back("*");
649 continue;
650 case PyTreeKind::kNone:
651 representation = "None";
652 break;
653 case PyTreeKind::kTuple:
654 // Tuples with only one element must have a trailing comma.
655 if (node.arity == 1) children += ",";
656 representation = absl::StrCat("(", children, ")");
657 break;
658 case PyTreeKind::kList:
659 representation = absl::StrCat("[", children, "]");
660 break;
661 case PyTreeKind::kDict: {
662 if (py::len(node.node_data) != node.arity) {
663 throw std::logic_error("Number of keys and entries does not match.");
664 }
665 representation = "{";
666 std::string separator;
667 auto child_iter = agenda.end() - node.arity;
668 for (const py::handle& key : node.node_data) {
669 absl::StrAppendFormat(&representation, "%s%s: %s", separator,
670 py::repr(key), *child_iter);
671 child_iter++;
672 separator = ", ";
673 }
674 representation += "}";
675 break;
676 }
677
678 case PyTreeKind::kNamedTuple:
679 case PyTreeKind::kCustom: {
680 std::string kind;
681 std::string data;
682 if (node.kind == PyTreeKind::kNamedTuple) {
683 kind = "namedtuple";
684 if (node.node_data) {
685 // Node data for named tuples is the type.
686 data = absl::StrFormat(
687 "[%s]", py::str(py::getattr(node.node_data, "__name__")));
688 }
689 } else {
690 kind = static_cast<std::string>(
691 py::str(py::getattr(node.custom->type, "__name__")));
692 if (node.node_data) {
693 data = absl::StrFormat("[%s]", py::str(node.node_data));
694 }
695 }
696
697 representation =
698 absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children);
699 break;
700 }
701 }
702 agenda.erase(agenda.end() - node.arity, agenda.end());
703 agenda.push_back(std::move(representation));
704 }
705 if (agenda.size() != 1) {
706 throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
707 }
708 return absl::StrCat("PyTreeDef(", agenda.back(), ")");
709 }
710
ToPickleable() const711 py::object PyTreeDef::ToPickleable() const {
712 py::list traversal;
713 for (const auto& node : traversal_) {
714 traversal.append(
715 py::make_tuple(static_cast<int>(node.kind), node.arity,
716 node.node_data ? node.node_data : py::none(),
717 node.custom != nullptr ? node.custom->type : py::none(),
718 node.num_leaves, node.num_nodes));
719 }
720 return traversal;
721 }
722
FromPickleable(py::object pickleable)723 PyTreeDef PyTreeDef::FromPickleable(py::object pickleable) {
724 PyTreeDef tree;
725 for (const auto& item : pickleable.cast<py::list>()) {
726 auto t = item.cast<py::tuple>();
727 if (t.size() != 6) {
728 throw xla::XlaRuntimeError("Malformed pickled PyTreeDef");
729 }
730 Node& node = tree.traversal_.emplace_back();
731 node.kind = static_cast<PyTreeKind>(t[0].cast<int>());
732 node.arity = t[1].cast<int>();
733 switch (node.kind) {
734 case PyTreeKind::kNamedTuple:
735 node.node_data = t[2].cast<py::type>();
736 break;
737 case PyTreeKind::kDict:
738 node.node_data = t[2].cast<py::list>();
739 break;
740 case PyTreeKind::kCustom:
741 node.node_data = t[2];
742 break;
743 default:
744 if (!t[2].is_none()) {
745 throw xla::XlaRuntimeError("Malformed pickled PyTreeDef");
746 }
747 break;
748 }
749 if (node.kind == PyTreeKind::kCustom) {
750 node.custom = t[3].is_none() ? nullptr : PyTreeTypeRegistry::Lookup(t[3]);
751 if (node.custom == nullptr) {
752 throw xla::XlaRuntimeError(
753 absl::StrCat("Unknown custom type in pickled PyTreeDef: ",
754 static_cast<std::string>(py::repr(t[3]))));
755 }
756 } else {
757 if (!t[3].is_none()) {
758 throw xla::XlaRuntimeError("Malformed pickled PyTreeDef");
759 }
760 }
761 node.num_leaves = t[4].cast<int>();
762 node.num_nodes = t[5].cast<int>();
763 }
764 return tree;
765 }
766
BuildPytreeSubmodule(py::module & m)767 void BuildPytreeSubmodule(py::module& m) {
768 py::module pytree = m.def_submodule("pytree", "Python tree library");
769 pytree.attr("version") = py::int_(3);
770 pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
771 py::arg("leaf_predicate") = std::nullopt);
772 pytree.def("tuple", &PyTreeDef::Tuple);
773 pytree.def("all_leaves", &PyTreeDef::AllLeaves);
774
775 py::class_<PyTreeDef>(pytree, "PyTreeDef")
776 .def("unflatten",
777 static_cast<pybind11::object (PyTreeDef::*)(
778 pybind11::iterable leaves) const>(&PyTreeDef::Unflatten))
779 .def("flatten_up_to", &PyTreeDef::FlattenUpTo)
780 .def("compose", &PyTreeDef::Compose)
781 .def("walk", &PyTreeDef::Walk,
782 "Walk pytree, calling f_node(node, node_data) at nodes, and f_leaf "
783 "at leaves",
784 py::arg("f_node"), py::arg("f_leaf"), py::arg("leaves"))
785 .def("from_iterable_tree", &PyTreeDef::FromIterableTree)
786 .def("children", &PyTreeDef::Children)
787 .def_property_readonly("num_leaves", &PyTreeDef::num_leaves)
788 .def_property_readonly("num_nodes", &PyTreeDef::num_nodes)
789 .def("__repr__", &PyTreeDef::ToString)
790 .def("__eq__",
791 [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; })
792 .def("__ne__",
793 [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; })
794 .def("__hash__", [](const PyTreeDef& t) { return absl::HashOf(t); })
795 .def(py::pickle(
796 [](const PyTreeDef& t) { return t.ToPickleable(); },
797 [](py::object o) { return PyTreeDef::FromPickleable(o); }));
798
799 pytree.def("register_node", [](py::object type, py::function to_iterable,
800 py::function from_iterable) {
801 return PyTreeTypeRegistry::Register(type, to_iterable, from_iterable);
802 });
803 }
804
805 } // namespace xla
806