xref: /aosp_15_r20/external/executorch/extension/pytree/test/test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport unittest
8*523fa7a6SAndroid Build Coastguard Workerfrom collections import namedtuple
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport torch
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Worker# @manual=//executorch/extension/pytree:pybindings
14*523fa7a6SAndroid Build Coastguard Workerfrom executorch.extension.pytree import (
15*523fa7a6SAndroid Build Coastguard Worker    broadcast_to_and_flatten,
16*523fa7a6SAndroid Build Coastguard Worker    register_custom,
17*523fa7a6SAndroid Build Coastguard Worker    tree_flatten,
18*523fa7a6SAndroid Build Coastguard Worker    tree_map,
19*523fa7a6SAndroid Build Coastguard Worker    tree_unflatten,
20*523fa7a6SAndroid Build Coastguard Worker    TreeSpec,
21*523fa7a6SAndroid Build Coastguard Worker)
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Worker
# pyre-fixme[11]: Annotation `TreeSpec` is not defined as a type.
def _spec(o: Any) -> TreeSpec:
    """Flatten ``o`` and hand back only its TreeSpec; the leaves are discarded."""
    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
    _leaves, treespec = tree_flatten(o)
    return treespec
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Worker
31*523fa7a6SAndroid Build Coastguard Worker# Constructs string representation of pytree spec of type specified by type_char (can be 'T' for tuple, 'L' for List) argument, that contains n children, each with single leaf.
32*523fa7a6SAndroid Build Coastguard Worker# e.g. ('T', 3) -> 'T3#1#1#1($,$,$)'
33*523fa7a6SAndroid Build Coastguard Workerdef _spec_str(type_char, n: int) -> str:
34*523fa7a6SAndroid Build Coastguard Worker    spec = type_char + str(n)
35*523fa7a6SAndroid Build Coastguard Worker    for _ in range(n):
36*523fa7a6SAndroid Build Coastguard Worker        spec += "#1"
37*523fa7a6SAndroid Build Coastguard Worker    spec += "("
38*523fa7a6SAndroid Build Coastguard Worker    for i in range(n):
39*523fa7a6SAndroid Build Coastguard Worker        if i > 0:
40*523fa7a6SAndroid Build Coastguard Worker            spec += ","
41*523fa7a6SAndroid Build Coastguard Worker        spec += "$"
42*523fa7a6SAndroid Build Coastguard Worker    spec += ")"
43*523fa7a6SAndroid Build Coastguard Worker    return spec
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Worker
46*523fa7a6SAndroid Build Coastguard Worker# Constructs string representation of pytree spec of Dict, keys can be str or int, every value is leaf.
47*523fa7a6SAndroid Build Coastguard Worker# e.g.: {'a': 1, 2: 2} -> D2#1#1('a':$,2:$)
48*523fa7a6SAndroid Build Coastguard Workerdef _spec_str_dict(d: Dict[Any, Any]) -> str:
49*523fa7a6SAndroid Build Coastguard Worker    n = len(d)
50*523fa7a6SAndroid Build Coastguard Worker    spec = "D" + str(n)
51*523fa7a6SAndroid Build Coastguard Worker    for _ in range(n):
52*523fa7a6SAndroid Build Coastguard Worker        spec += "#1"
53*523fa7a6SAndroid Build Coastguard Worker    spec += "("
54*523fa7a6SAndroid Build Coastguard Worker    i = 0
55*523fa7a6SAndroid Build Coastguard Worker    for key in d.keys():
56*523fa7a6SAndroid Build Coastguard Worker        if i > 0:
57*523fa7a6SAndroid Build Coastguard Worker            spec += ","
58*523fa7a6SAndroid Build Coastguard Worker        if isinstance(key, str):
59*523fa7a6SAndroid Build Coastguard Worker            spec += "'" + key + "'"
60*523fa7a6SAndroid Build Coastguard Worker        else:
61*523fa7a6SAndroid Build Coastguard Worker            spec += str(key)
62*523fa7a6SAndroid Build Coastguard Worker        spec += ":$"
63*523fa7a6SAndroid Build Coastguard Worker        i += 1
64*523fa7a6SAndroid Build Coastguard Worker    spec += ")"
65*523fa7a6SAndroid Build Coastguard Worker    return spec
66*523fa7a6SAndroid Build Coastguard Worker
67*523fa7a6SAndroid Build Coastguard Worker
class TestPytree(unittest.TestCase):
    """Unit tests for the executorch pytree bindings.

    The tests cover:

    * flatten/unflatten round-trips for leaves, lists, tuples, namedtuples,
      dicts and nested containers;
    * the ``TreeSpec`` string format (``to_str``, ``from_str`` and ``repr``);
    * ``tree_map``;
    * custom node registration;
    * ``broadcast_to_and_flatten``.
    """

    def test(self):
        """Flatten a mixed nested dict, check leaves and spec, then rebuild it."""
        SPEC = "D4#2#1#2#2('a':L2#1#1($,$),1:$,2:T2#1#1($,$),'str':D2#1#1('str':$,'str2':$))"
        d = {}
        d["a"] = [777, 1]
        d[1] = 4
        d[2] = ("ta", 2)
        d["str"] = {"str": 23, "str2": "47str"}
        (leaves, pytree) = tree_flatten(d)
        self.assertEqual(leaves, [777, 1, 4, "ta", 2, 23, "47str"])
        pytree_str = pytree.to_str()
        self.assertEqual(pytree_str, SPEC)

        # Replacement leaves alternate int / str, so the rebuilt tree shows
        # exactly which slot each leaf landed in.
        leaves_test = [
            i + 13 if i % 2 == 0 else str(i + 13) for i in range(len(leaves))
        ]

        tree_test = pytree.tree_unflatten(leaves_test)
        self.assertEqual(
            tree_test,
            {"a": [13, "14"], 1: 15, 2: ("16", 17), "str": {"str": "18", "str2": 19}},
        )

        # Parsing the spec string and serializing it again must round-trip.
        pytree_from = TreeSpec.from_str(SPEC)
        spec_str_to = pytree_from.to_str()
        self.assertEqual(SPEC, spec_str_to)

    def test_extract_nested_list(self):
        """A list nested in a tuple is serialized as an inner ``L`` node."""
        nested_struct = (1, 2, [3, 4])
        (_, pytree) = tree_flatten(nested_struct)
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,L2#1#1($,$))")

    def test_extract_nested_dict(self):
        """A dict nested in a tuple is serialized with quoted string keys."""
        nested_struct = (1, 2, {3: 4, "str": 6})
        (_, pytree) = tree_flatten(nested_struct)
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,D2#1#1(3:$,'str':$))")

    def test_extracted_scalar(self):
        """A bare scalar flattens to a single-leaf spec ``$``."""
        struct = 4
        (_, pytree) = tree_flatten(struct)
        self.assertEqual(pytree.to_str(), "$")

    def test_map(self):
        """tree_map applies the function to every leaf of a nested container."""
        struct = (1, 2, [3, 4])
        struct_map = tree_map(lambda x: 2 * x, struct)
        self.assertEqual(struct_map, (2, 4, [6, 8]))

    def test_treespec_equality(self):
        """TreeSpecs compare equal by structure, independent of leaf values."""
        self.assertEqual(TreeSpec.from_str("$"), TreeSpec.from_str("$"))
        self.assertEqual(_spec([1]), TreeSpec.from_str("L1#1($)"))
        # NB: ``(1)`` is just the int 1, not a tuple. A one-element tuple
        # needs the trailing comma, so both the leaf and the tuple cases are
        # spelled out explicitly.
        self.assertNotEqual(_spec(1), _spec([1]))
        self.assertNotEqual(_spec((1,)), _spec([1]))
        self.assertEqual(_spec(1), _spec(2))
        self.assertEqual(_spec((1,)), _spec((2,)))

    def test_flatten_unflatten_leaf(self):
        """Non-container values flatten to themselves with spec ``$``."""

        def run_test_with_leaf(leaf):
            values, treespec = tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, TreeSpec.from_str("$"))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        run_test_with_leaf(1)
        run_test_with_leaf(1.0)
        run_test_with_leaf(None)
        run_test_with_leaf(bool)

    def test_flatten_unflatten_list(self):
        """Flat lists round-trip and keep the ``list`` type."""

        def run_test(lst):
            spec = _spec_str("L", len(lst))

            expected_spec = TreeSpec.from_str(spec)
            values, treespec = tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    def test_flatten_unflatten_tuple(self):
        """Flat tuples round-trip and keep the ``tuple`` type."""

        def run_test(tup):
            spec = _spec_str("T", len(tup))

            expected_spec = TreeSpec.from_str(spec)
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    def test_flatten_unflatten_namedtuple(self):
        """Namedtuples are serialized with the ``N`` tag and round-trip."""
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            spec = _spec_str("N", len(tup))
            expected_spec = TreeSpec.from_str(spec)

            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)

            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    def test_flatten_unflatten_torch_namedtuple_return_type(self):
        """torch's structseq return types (e.g. torch.max) keep their type."""
        x = torch.randn(3, 3)
        expected = torch.max(x, dim=0)

        values, spec = tree_flatten(expected)
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    def test_flatten_unflatten_dict(self):
        """Flat dicts with str or int keys round-trip and keep the ``dict`` type."""

        def run_test(d):
            spec = _spec_str_dict(d)

            values, treespec = tree_flatten(d)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(d.values()))
            self.assertEqual(treespec, TreeSpec.from_str(spec))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, d)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    def test_flatten_unflatten_nested(self):
        """Arbitrarily nested containers, including empty ones, round-trip."""

        def run_test(pytree):
            values, treespec = tree_flatten(pytree)
            self.assertIsInstance(values, list)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            run_test(case)

    def test_treemap(self):
        """tree_map visits every leaf exactly once and preserves structure."""

        def f(x):
            return x * 3

        def invf(x):
            return x // 3

        def run_test(pytree):
            leaves, _ = tree_flatten(pytree)
            # Mapping then flattening must see the same leaves as
            # flattening then mapping.
            sm1 = sum(map(f, leaves))
            sm2 = sum(tree_flatten(tree_map(f, pytree))[0])
            self.assertEqual(sm1, sm2)

            # invf undoes f on ints, so the composed map must be the identity.
            self.assertEqual(tree_map(invf, tree_map(f, pytree)), pytree)

        # The cases loop runs at method level (not inside run_test), so
        # run_test is actually exercised.
        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            run_test(case)

    def test_treespec_repr(self):
        """repr(TreeSpec) yields the same serialized form as to_str."""
        pytree = (0, [0, 0, 0])
        _, spec = tree_flatten(pytree)
        self.assertEqual(repr(spec), "T2#1#3($,L3#1#1#1($,$,$))")

    def test_custom_tree_node(self):
        """A class registered via register_custom flattens and rebuilds."""

        class Point:
            def __init__(self, x, y, name):
                self.x = x
                self.y = y
                self.name = name

            def __repr__(self):
                return "Point(x:{}, y:{}, name: {})".format(self.x, self.y, self.name)

        def custom_flatten(p):
            # x and y are children; name travels as non-leaf extra data.
            children = [p.x, p.y]
            extra_data = p.name
            return (children, extra_data)

        def custom_unflatten(children, extra_data):
            return Point(*children, extra_data)

        register_custom(Point, custom_flatten, custom_unflatten)

        point = Point((1.0, 1.0, 1), 2.0, "point_name")
        children, spec = tree_flatten(point)
        point2 = tree_unflatten(children, spec)
        self.assertEqual(str(point), str(point2))

    def test_broadcast_to_and_flatten(self):
        """broadcast_to_and_flatten expands leaves to a target spec.

        It returns None when the structures cannot be matched.
        """
        # Each case is (pytree, to_pytree, expected flattened result or None).
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = tree_flatten(to_pytree)
            result = broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
332