xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/python/pytree.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
18 
19 // See https://jax.readthedocs.io/en/latest/pytrees.html for the documentation
20 // about pytree.
21 
22 // Caution: this code uses exceptions. The exception use is local to the
23 // binding code and the idiomatic way to emit Python exceptions.
24 
#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
37 
38 namespace xla {
39 
// Classifies a single node of a PyTree.
enum class PyTreeKind {
  // An opaque leaf node.
  kLeaf,
  // Python `None`.
  kNone,
  // A tuple.
  kTuple,
  // A collections.namedtuple.
  kNamedTuple,
  // A list.
  kList,
  // A dict.
  kDict,
  // A custom type.
  kCustom,
};
49 
50 // Registry of custom node types.
51 class PyTreeTypeRegistry {
52  public:
53   struct Registration {
54     PyTreeKind kind;
55 
56     // The following values are populated for custom types.
57     // The Python type object, used to identify the type.
58     pybind11::object type;
59     // A function with signature: object -> (iterable, aux_data)
60     pybind11::function to_iterable;
61     // A function with signature: (aux_data, iterable) -> object
62     pybind11::function from_iterable;
63   };
64 
65   // Registers a new custom type. Objects of `type` will be treated as container
66   // node types in PyTrees.
67   static void Register(pybind11::object type, pybind11::function to_iterable,
68                        pybind11::function from_iterable);
69 
70   // Finds the custom type registration for `type`. Returns nullptr if none
71   // exists.
72   static const Registration* Lookup(pybind11::handle type);
73 
74  private:
75   static PyTreeTypeRegistry* Singleton();
76 
77   struct TypeHash {
78     using is_transparent = void;
operatorTypeHash79     size_t operator()(const pybind11::object& t) const {
80       return absl::HashOf(t.ptr());
81     }
operatorTypeHash82     size_t operator()(const pybind11::handle& t) const {
83       return absl::HashOf(t.ptr());
84     }
85   };
86   struct TypeEq {
87     using is_transparent = void;
operatorTypeEq88     bool operator()(const pybind11::object& a,
89                     const pybind11::object& b) const {
90       return a.ptr() == b.ptr();
91     }
operatorTypeEq92     bool operator()(const pybind11::object& a,
93                     const pybind11::handle& b) const {
94       return a.ptr() == b.ptr();
95     }
96   };
97   absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
98                       TypeEq>
99       registrations_;
100 };
101 
102 // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
103 // Python values, where the interior nodes are tuples, lists, dictionaries, or
104 // user-defined containers, and the leaves are other objects.
105 class PyTreeDef {
106  public:
107   PyTreeDef() = default;
108 
109   // Flattens a Pytree into a list of leaves and a PyTreeDef.
110   // Returns references to the flattened objects, which might be temporary
111   // objects in the case of custom pytype handlers.
112   static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
113   Flatten(pybind11::handle x,
114           std::optional<pybind11::function> leaf_predicate = std::nullopt);
115 
116   // Recursive helper used to implement Flatten().
117   void FlattenInto(
118       pybind11::handle handle, std::vector<pybind11::object>& leaves,
119       std::optional<pybind11::function> leaf_predicate = std::nullopt);
120   void FlattenInto(
121       pybind11::handle handle, absl::InlinedVector<pybind11::object, 2>& leaves,
122       std::optional<pybind11::function> leaf_predicate = std::nullopt);
123 
124   // Tests whether the given list is a flat list of leaves.
125   static bool AllLeaves(const pybind11::iterable& x);
126 
127   // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
128   // the tree-structure of 'x'. For example, if we flatten a value
129   // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
130   // list of leaves [1, (2, 3), {"foo": 4}].
131   pybind11::list FlattenUpTo(pybind11::handle x) const;
132 
133   // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
134   pybind11::object Unflatten(pybind11::iterable leaves) const;
135   pybind11::object Unflatten(absl::Span<const pybind11::object> leaves) const;
136 
137   // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
138   // `inner`.
139   std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
140 
141   // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
142   static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);
143 
144   std::vector<std::unique_ptr<PyTreeDef>> Children() const;
145 
146   // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
147   // f_node(node, node_data) to each container node.
148   pybind11::object Walk(const pybind11::function& f_node,
149                         pybind11::handle f_leaf,
150                         pybind11::iterable leaves) const;
151 
152   // Given a tree of iterables with the same node/leaf structure as this PyTree,
153   // build the corresponding PyTree.
154   // TODO(phawkins): use flattening everywhere instead and delete this method.
155   pybind11::object FromIterableTree(pybind11::handle xs) const;
156 
num_leaves()157   int num_leaves() const {
158     if (traversal_.empty()) {
159       return 0;
160     }
161     return traversal_.back().num_leaves;
162   }
163 
num_nodes()164   int num_nodes() const { return traversal_.size(); }
165 
166   size_t Hash() const;
167 
168   bool operator==(const PyTreeDef& other) const;
169   bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
170 
171   std::string ToString() const;
172 
173   // Transforms the PyTreeDef into a pickleable object. Used to implement
174   // `PyTreeDef.__getstate__`.
175   pybind11::object ToPickleable() const;
176 
177   // Transforms the object returned by `ToPickleable()` back to PyTreeDef. Used
178   // to implement `PyTreeDef.__setstate__`.
179   static PyTreeDef FromPickleable(pybind11::object pickleable);
180 
181  private:
182   struct Node {
183     PyTreeKind kind = PyTreeKind::kLeaf;
184 
185     // Arity for non-kLeaf types.
186     int arity = 0;
187 
188     // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
189     // object. For a kDict, contains a sorted list of keys. For a kCustom type,
190     // contains the auxiliary data returned by the `to_iterable` function.
191     pybind11::object node_data;
192 
193     // Custom type registration. Must be null for non-custom types.
194     const PyTreeTypeRegistry::Registration* custom = nullptr;
195 
196     // Number of leaf nodes in the subtree rooted at this node.
197     int num_leaves = 0;
198 
199     // Number of leaf and interior nodes in the subtree rooted at this node.
200     int num_nodes = 0;
201   };
202   template <typename H>
203   friend H AbslHashValue(H h, const Node& n);
204 
205   template <typename H>
206   friend H AbslHashValue(H h, const PyTreeDef& t);
207 
208   // Helper that manufactures an instance of a node given its children.
209   static pybind11::object MakeNode(const Node& node,
210                                    absl::Span<pybind11::object> children);
211 
212   // Recursive helper used to implement FromIterableTree()
213   pybind11::object FromIterableTreeHelper(
214       pybind11::handle xs,
215       absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it)
216       const;
217 
218   // Computes the node kind of a given Python object.
219   static PyTreeKind GetKind(const pybind11::handle& obj,
220                             PyTreeTypeRegistry::Registration const** custom);
221 
222   template <typename T>
223   void FlattenIntoImpl(pybind11::handle handle, T& leaves,
224                        const std::optional<pybind11::function>& leaf_predicate);
225 
226   template <typename T>
227   pybind11::object UnflattenImpl(T leaves) const;
228 
229   // Nodes, in a post-order traversal. We use an ordered traversal to minimize
230   // allocations, and post-order corresponds to the order we need to rebuild the
231   // tree structure.
232   absl::InlinedVector<Node, 1> traversal_;
233 };
234 
235 template <typename H>
AbslHashValue(H h,const PyTreeDef::Node & n)236 H AbslHashValue(H h, const PyTreeDef::Node& n) {
237   h = H::combine(std::move(h), n.kind, n.arity, n.custom);
238   return h;
239 }
240 
241 template <typename H>
AbslHashValue(H h,const PyTreeDef & t)242 H AbslHashValue(H h, const PyTreeDef& t) {
243   h = H::combine(std::move(h), t.traversal_);
244   return h;
245 }
246 
247 void BuildPytreeSubmodule(pybind11::module& m);
248 
249 }  // namespace xla
250 
251 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
252