1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
18
19 // See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation
20 // about pytree.
21
22 // Caution: this code uses exceptions. The exception use is local to the
23 // binding code and the idiomatic way to emit Python exceptions.
24
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
37
38 namespace xla {
39
// The kinds of node that can appear in a PyTree.
enum class PyTreeKind {
  kLeaf,        // Opaque leaf node.
  kNone,        // Python `None`.
  kTuple,       // Python tuple.
  kNamedTuple,  // collections.namedtuple instance.
  kList,        // Python list.
  kDict,        // Python dict.
  kCustom,      // Custom (registered) type.
};
49
50 // Registry of custom node types.
51 class PyTreeTypeRegistry {
52 public:
53 struct Registration {
54 PyTreeKind kind;
55
56 // The following values are populated for custom types.
57 // The Python type object, used to identify the type.
58 pybind11::object type;
59 // A function with signature: object -> (iterable, aux_data)
60 pybind11::function to_iterable;
61 // A function with signature: (aux_data, iterable) -> object
62 pybind11::function from_iterable;
63 };
64
65 // Registers a new custom type. Objects of `type` will be treated as container
66 // node types in PyTrees.
67 static void Register(pybind11::object type, pybind11::function to_iterable,
68 pybind11::function from_iterable);
69
70 // Finds the custom type registration for `type`. Returns nullptr if none
71 // exists.
72 static const Registration* Lookup(pybind11::handle type);
73
74 private:
75 static PyTreeTypeRegistry* Singleton();
76
77 struct TypeHash {
78 using is_transparent = void;
operatorTypeHash79 size_t operator()(const pybind11::object& t) const {
80 return absl::HashOf(t.ptr());
81 }
operatorTypeHash82 size_t operator()(const pybind11::handle& t) const {
83 return absl::HashOf(t.ptr());
84 }
85 };
86 struct TypeEq {
87 using is_transparent = void;
operatorTypeEq88 bool operator()(const pybind11::object& a,
89 const pybind11::object& b) const {
90 return a.ptr() == b.ptr();
91 }
operatorTypeEq92 bool operator()(const pybind11::object& a,
93 const pybind11::handle& b) const {
94 return a.ptr() == b.ptr();
95 }
96 };
97 absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
98 TypeEq>
99 registrations_;
100 };
101
102 // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
103 // Python values, where the interior nodes are tuples, lists, dictionaries, or
104 // user-defined containers, and the leaves are other objects.
105 class PyTreeDef {
106 public:
107 PyTreeDef() = default;
108
109 // Flattens a Pytree into a list of leaves and a PyTreeDef.
110 // Returns references to the flattened objects, which might be temporary
111 // objects in the case of custom pytype handlers.
112 static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
113 Flatten(pybind11::handle x,
114 std::optional<pybind11::function> leaf_predicate = std::nullopt);
115
116 // Recursive helper used to implement Flatten().
117 void FlattenInto(
118 pybind11::handle handle, std::vector<pybind11::object>& leaves,
119 std::optional<pybind11::function> leaf_predicate = std::nullopt);
120 void FlattenInto(
121 pybind11::handle handle, absl::InlinedVector<pybind11::object, 2>& leaves,
122 std::optional<pybind11::function> leaf_predicate = std::nullopt);
123
124 // Tests whether the given list is a flat list of leaves.
125 static bool AllLeaves(const pybind11::iterable& x);
126
127 // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
128 // the tree-structure of 'x'. For example, if we flatten a value
129 // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
130 // list of leaves [1, (2, 3), {"foo": 4}].
131 pybind11::list FlattenUpTo(pybind11::handle x) const;
132
133 // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
134 pybind11::object Unflatten(pybind11::iterable leaves) const;
135 pybind11::object Unflatten(absl::Span<const pybind11::object> leaves) const;
136
137 // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
138 // `inner`.
139 std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
140
141 // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
142 static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);
143
144 std::vector<std::unique_ptr<PyTreeDef>> Children() const;
145
146 // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
147 // f_node(node, node_data) to each container node.
148 pybind11::object Walk(const pybind11::function& f_node,
149 pybind11::handle f_leaf,
150 pybind11::iterable leaves) const;
151
152 // Given a tree of iterables with the same node/leaf structure as this PyTree,
153 // build the corresponding PyTree.
154 // TODO(phawkins): use flattening everywhere instead and delete this method.
155 pybind11::object FromIterableTree(pybind11::handle xs) const;
156
num_leaves()157 int num_leaves() const {
158 if (traversal_.empty()) {
159 return 0;
160 }
161 return traversal_.back().num_leaves;
162 }
163
num_nodes()164 int num_nodes() const { return traversal_.size(); }
165
166 size_t Hash() const;
167
168 bool operator==(const PyTreeDef& other) const;
169 bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
170
171 std::string ToString() const;
172
173 // Transforms the PyTreeDef into a pickleable object. Used to implement
174 // `PyTreeDef.__getstate__`.
175 pybind11::object ToPickleable() const;
176
177 // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used
178 // to implement `PyTreeDef.__setstate__`.
179 static PyTreeDef FromPickleable(pybind11::object pickleable);
180
181 private:
182 struct Node {
183 PyTreeKind kind = PyTreeKind::kLeaf;
184
185 // Arity for non-kLeaf types.
186 int arity = 0;
187
188 // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
189 // object. For a kDict, contains a sorted list of keys. For a kCustom type,
190 // contains the auxiliary data returned by the `to_iterable` function.
191 pybind11::object node_data;
192
193 // Custom type registration. Must be null for non-custom types.
194 const PyTreeTypeRegistry::Registration* custom = nullptr;
195
196 // Number of leaf nodes in the subtree rooted at this node.
197 int num_leaves = 0;
198
199 // Number of leaf and interior nodes in the subtree rooted at this node.
200 int num_nodes = 0;
201 };
202 template <typename H>
203 friend H AbslHashValue(H h, const Node& n);
204
205 template <typename H>
206 friend H AbslHashValue(H h, const PyTreeDef& t);
207
208 // Helper that manufactures an instance of a node given its children.
209 static pybind11::object MakeNode(const Node& node,
210 absl::Span<pybind11::object> children);
211
212 // Recursive helper used to implement FromIterableTree()
213 pybind11::object FromIterableTreeHelper(
214 pybind11::handle xs,
215 absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it)
216 const;
217
218 // Computes the node kind of a given Python object.
219 static PyTreeKind GetKind(const pybind11::handle& obj,
220 PyTreeTypeRegistry::Registration const** custom);
221
222 template <typename T>
223 void FlattenIntoImpl(pybind11::handle handle, T& leaves,
224 const std::optional<pybind11::function>& leaf_predicate);
225
226 template <typename T>
227 pybind11::object UnflattenImpl(T leaves) const;
228
229 // Nodes, in a post-order traversal. We use an ordered traversal to minimize
230 // allocations, and post-order corresponds to the order we need to rebuild the
231 // tree structure.
232 absl::InlinedVector<Node, 1> traversal_;
233 };
234
235 template <typename H>
AbslHashValue(H h,const PyTreeDef::Node & n)236 H AbslHashValue(H h, const PyTreeDef::Node& n) {
237 h = H::combine(std::move(h), n.kind, n.arity, n.custom);
238 return h;
239 }
240
241 template <typename H>
AbslHashValue(H h,const PyTreeDef & t)242 H AbslHashValue(H h, const PyTreeDef& t) {
243 h = H::combine(std::move(h), t.traversal_);
244 return h;
245 }
246
247 void BuildPytreeSubmodule(pybind11::module& m);
248
249 } // namespace xla
250
251 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
252