xref: /aosp_15_r20/external/pytorch/test/test_pytree.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: pytree"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport collections
4*da0073e9SAndroid Build Coastguard Workerimport inspect
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport re
7*da0073e9SAndroid Build Coastguard Workerimport subprocess
8*da0073e9SAndroid Build Coastguard Workerimport sys
9*da0073e9SAndroid Build Coastguard Workerimport unittest
10*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
11*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
12*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, NamedTuple
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport torch
15*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree as py_pytree
16*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.immutable_collections import immutable_dict, immutable_list
17*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
18*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
19*da0073e9SAndroid Build Coastguard Worker    IS_FBCODE,
20*da0073e9SAndroid Build Coastguard Worker    parametrize,
21*da0073e9SAndroid Build Coastguard Worker    run_tests,
22*da0073e9SAndroid Build Coastguard Worker    skipIfTorchDynamo,
23*da0073e9SAndroid Build Coastguard Worker    subtest,
24*da0073e9SAndroid Build Coastguard Worker    TEST_WITH_TORCHDYNAMO,
25*da0073e9SAndroid Build Coastguard Worker    TestCase,
26*da0073e9SAndroid Build Coastguard Worker)
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerif IS_FBCODE:
30*da0073e9SAndroid Build Coastguard Worker    # optree is not yet enabled in fbcode, so just re-test the python implementation
31*da0073e9SAndroid Build Coastguard Worker    cxx_pytree = py_pytree
32*da0073e9SAndroid Build Coastguard Workerelse:
33*da0073e9SAndroid Build Coastguard Worker    import torch.utils._cxx_pytree as cxx_pytree
34*da0073e9SAndroid Build Coastguard Worker
# Two-field namedtuple (``x``, ``y``) declared at module scope rather than inside
# a test. NOTE(review): presumably global so it can be resolved by qualified name
# (e.g. by pytree serialization tests further down the file) — confirm.
GlobalPoint = namedtuple("GlobalPoint", ("x", "y"))
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
class GlobalDummyType:
    """Plain module-level holder for two public attributes, ``x`` and ``y``.

    No ``__eq__``/``__hash__`` is defined, so instances compare by identity.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Workerclass TestGenericPytree(TestCase):
45*da0073e9SAndroid Build Coastguard Worker    def test_aligned_public_apis(self):
46*da0073e9SAndroid Build Coastguard Worker        public_apis = py_pytree.__all__
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(public_apis, cxx_pytree.__all__)
49*da0073e9SAndroid Build Coastguard Worker
50*da0073e9SAndroid Build Coastguard Worker        for name in public_apis:
51*da0073e9SAndroid Build Coastguard Worker            cxx_api = getattr(cxx_pytree, name)
52*da0073e9SAndroid Build Coastguard Worker            py_api = getattr(py_pytree, name)
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api))
55*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api))
56*da0073e9SAndroid Build Coastguard Worker            if inspect.isfunction(cxx_api):
57*da0073e9SAndroid Build Coastguard Worker                cxx_signature = inspect.signature(cxx_api)
58*da0073e9SAndroid Build Coastguard Worker                py_signature = inspect.signature(py_api)
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker                # Check the parameter names are the same.
61*da0073e9SAndroid Build Coastguard Worker                cxx_param_names = list(cxx_signature.parameters)
62*da0073e9SAndroid Build Coastguard Worker                py_param_names = list(py_signature.parameters)
63*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cxx_param_names, py_param_names)
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker                # Check the positional parameters are the same.
66*da0073e9SAndroid Build Coastguard Worker                cxx_positional_param_names = [
67*da0073e9SAndroid Build Coastguard Worker                    n
68*da0073e9SAndroid Build Coastguard Worker                    for n, p in cxx_signature.parameters.items()
69*da0073e9SAndroid Build Coastguard Worker                    if (
70*da0073e9SAndroid Build Coastguard Worker                        p.kind
71*da0073e9SAndroid Build Coastguard Worker                        in {
72*da0073e9SAndroid Build Coastguard Worker                            inspect.Parameter.POSITIONAL_ONLY,
73*da0073e9SAndroid Build Coastguard Worker                            inspect.Parameter.POSITIONAL_OR_KEYWORD,
74*da0073e9SAndroid Build Coastguard Worker                        }
75*da0073e9SAndroid Build Coastguard Worker                    )
76*da0073e9SAndroid Build Coastguard Worker                ]
77*da0073e9SAndroid Build Coastguard Worker                py_positional_param_names = [
78*da0073e9SAndroid Build Coastguard Worker                    n
79*da0073e9SAndroid Build Coastguard Worker                    for n, p in py_signature.parameters.items()
80*da0073e9SAndroid Build Coastguard Worker                    if (
81*da0073e9SAndroid Build Coastguard Worker                        p.kind
82*da0073e9SAndroid Build Coastguard Worker                        in {
83*da0073e9SAndroid Build Coastguard Worker                            inspect.Parameter.POSITIONAL_ONLY,
84*da0073e9SAndroid Build Coastguard Worker                            inspect.Parameter.POSITIONAL_OR_KEYWORD,
85*da0073e9SAndroid Build Coastguard Worker                        }
86*da0073e9SAndroid Build Coastguard Worker                    )
87*da0073e9SAndroid Build Coastguard Worker                ]
88*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cxx_positional_param_names, py_positional_param_names)
89*da0073e9SAndroid Build Coastguard Worker
90*da0073e9SAndroid Build Coastguard Worker                for py_name, py_param in py_signature.parameters.items():
91*da0073e9SAndroid Build Coastguard Worker                    self.assertIn(py_name, cxx_signature.parameters)
92*da0073e9SAndroid Build Coastguard Worker                    cxx_param = cxx_signature.parameters[py_name]
93*da0073e9SAndroid Build Coastguard Worker
94*da0073e9SAndroid Build Coastguard Worker                    # Check parameter kinds and default values are the same.
95*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(cxx_param.kind, py_param.kind)
96*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(cxx_param.default, py_param.default)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker                    # Check parameter annotations are the same.
99*da0073e9SAndroid Build Coastguard Worker                    if "TreeSpec" in str(cxx_param.annotation):
100*da0073e9SAndroid Build Coastguard Worker                        self.assertIn("TreeSpec", str(py_param.annotation))
101*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
102*da0073e9SAndroid Build Coastguard Worker                            re.sub(
103*da0073e9SAndroid Build Coastguard Worker                                r"(?:\b)([\w\.]*)TreeSpec(?:\b)",
104*da0073e9SAndroid Build Coastguard Worker                                "TreeSpec",
105*da0073e9SAndroid Build Coastguard Worker                                str(cxx_param.annotation),
106*da0073e9SAndroid Build Coastguard Worker                            ),
107*da0073e9SAndroid Build Coastguard Worker                            re.sub(
108*da0073e9SAndroid Build Coastguard Worker                                r"(?:\b)([\w\.]*)TreeSpec(?:\b)",
109*da0073e9SAndroid Build Coastguard Worker                                "TreeSpec",
110*da0073e9SAndroid Build Coastguard Worker                                str(py_param.annotation),
111*da0073e9SAndroid Build Coastguard Worker                            ),
112*da0073e9SAndroid Build Coastguard Worker                            msg=(
113*da0073e9SAndroid Build Coastguard Worker                                f"C++ parameter {cxx_param} "
114*da0073e9SAndroid Build Coastguard Worker                                f"does not match Python parameter {py_param} "
115*da0073e9SAndroid Build Coastguard Worker                                f"for API `{name}`"
116*da0073e9SAndroid Build Coastguard Worker                            ),
117*da0073e9SAndroid Build Coastguard Worker                        )
118*da0073e9SAndroid Build Coastguard Worker                    else:
119*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
120*da0073e9SAndroid Build Coastguard Worker                            cxx_param.annotation,
121*da0073e9SAndroid Build Coastguard Worker                            py_param.annotation,
122*da0073e9SAndroid Build Coastguard Worker                            msg=(
123*da0073e9SAndroid Build Coastguard Worker                                f"C++ parameter {cxx_param} "
124*da0073e9SAndroid Build Coastguard Worker                                f"does not match Python parameter {py_param} "
125*da0073e9SAndroid Build Coastguard Worker                                f"for API `{name}`"
126*da0073e9SAndroid Build Coastguard Worker                            ),
127*da0073e9SAndroid Build Coastguard Worker                        )
128*da0073e9SAndroid Build Coastguard Worker
129*da0073e9SAndroid Build Coastguard Worker    @parametrize(
130*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
131*da0073e9SAndroid Build Coastguard Worker        [
132*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
133*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
134*da0073e9SAndroid Build Coastguard Worker        ],
135*da0073e9SAndroid Build Coastguard Worker    )
136*da0073e9SAndroid Build Coastguard Worker    def test_register_pytree_node(self, pytree_impl):
137*da0073e9SAndroid Build Coastguard Worker        class MyDict(UserDict):
138*da0073e9SAndroid Build Coastguard Worker            pass
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker        d = MyDict(a=1, b=2, c=3)
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker        # Custom types are leaf nodes by default
143*da0073e9SAndroid Build Coastguard Worker        values, spec = pytree_impl.tree_flatten(d)
144*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, [d])
145*da0073e9SAndroid Build Coastguard Worker        self.assertIs(values[0], d)
146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
147*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(spec.is_leaf())
148*da0073e9SAndroid Build Coastguard Worker
149*da0073e9SAndroid Build Coastguard Worker        # Register MyDict as a pytree node
150*da0073e9SAndroid Build Coastguard Worker        pytree_impl.register_pytree_node(
151*da0073e9SAndroid Build Coastguard Worker            MyDict,
152*da0073e9SAndroid Build Coastguard Worker            lambda d: (list(d.values()), list(d.keys())),
153*da0073e9SAndroid Build Coastguard Worker            lambda values, keys: MyDict(zip(keys, values)),
154*da0073e9SAndroid Build Coastguard Worker        )
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker        values, spec = pytree_impl.tree_flatten(d)
157*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(values, [1, 2, 3])
158*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
159*da0073e9SAndroid Build Coastguard Worker
160*da0073e9SAndroid Build Coastguard Worker        # Do not allow registering the same type twice
161*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "already registered"):
162*da0073e9SAndroid Build Coastguard Worker            pytree_impl.register_pytree_node(
163*da0073e9SAndroid Build Coastguard Worker                MyDict,
164*da0073e9SAndroid Build Coastguard Worker                lambda d: (list(d.values()), list(d.keys())),
165*da0073e9SAndroid Build Coastguard Worker                lambda values, keys: MyDict(zip(keys, values)),
166*da0073e9SAndroid Build Coastguard Worker            )
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker    @parametrize(
169*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
170*da0073e9SAndroid Build Coastguard Worker        [
171*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
172*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
173*da0073e9SAndroid Build Coastguard Worker        ],
174*da0073e9SAndroid Build Coastguard Worker    )
175*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_leaf(self, pytree_impl):
176*da0073e9SAndroid Build Coastguard Worker        def run_test_with_leaf(leaf):
177*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(leaf)
178*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, [leaf])
179*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, pytree_impl.LeafSpec())
180*da0073e9SAndroid Build Coastguard Worker
181*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
182*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, leaf)
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker        run_test_with_leaf(1)
185*da0073e9SAndroid Build Coastguard Worker        run_test_with_leaf(1.0)
186*da0073e9SAndroid Build Coastguard Worker        run_test_with_leaf(None)
187*da0073e9SAndroid Build Coastguard Worker        run_test_with_leaf(bool)
188*da0073e9SAndroid Build Coastguard Worker        run_test_with_leaf(torch.randn(3, 3))
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker    @parametrize(
191*da0073e9SAndroid Build Coastguard Worker        "pytree_impl,gen_expected_fn",
192*da0073e9SAndroid Build Coastguard Worker        [
193*da0073e9SAndroid Build Coastguard Worker            subtest(
194*da0073e9SAndroid Build Coastguard Worker                (
195*da0073e9SAndroid Build Coastguard Worker                    py_pytree,
196*da0073e9SAndroid Build Coastguard Worker                    lambda tup: py_pytree.TreeSpec(
197*da0073e9SAndroid Build Coastguard Worker                        tuple, None, [py_pytree.LeafSpec() for _ in tup]
198*da0073e9SAndroid Build Coastguard Worker                    ),
199*da0073e9SAndroid Build Coastguard Worker                ),
200*da0073e9SAndroid Build Coastguard Worker                name="py",
201*da0073e9SAndroid Build Coastguard Worker            ),
202*da0073e9SAndroid Build Coastguard Worker            subtest(
203*da0073e9SAndroid Build Coastguard Worker                (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
204*da0073e9SAndroid Build Coastguard Worker                name="cxx",
205*da0073e9SAndroid Build Coastguard Worker            ),
206*da0073e9SAndroid Build Coastguard Worker        ],
207*da0073e9SAndroid Build Coastguard Worker    )
208*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
209*da0073e9SAndroid Build Coastguard Worker        def run_test(tup):
210*da0073e9SAndroid Build Coastguard Worker            expected_spec = gen_expected_fn(tup)
211*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(tup)
212*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
213*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, list(tup))
214*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
217*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, tup)
218*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, tuple)
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker        run_test(())
221*da0073e9SAndroid Build Coastguard Worker        run_test((1.0,))
222*da0073e9SAndroid Build Coastguard Worker        run_test((1.0, 2))
223*da0073e9SAndroid Build Coastguard Worker        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
224*da0073e9SAndroid Build Coastguard Worker
225*da0073e9SAndroid Build Coastguard Worker    @parametrize(
226*da0073e9SAndroid Build Coastguard Worker        "pytree_impl,gen_expected_fn",
227*da0073e9SAndroid Build Coastguard Worker        [
228*da0073e9SAndroid Build Coastguard Worker            subtest(
229*da0073e9SAndroid Build Coastguard Worker                (
230*da0073e9SAndroid Build Coastguard Worker                    py_pytree,
231*da0073e9SAndroid Build Coastguard Worker                    lambda lst: py_pytree.TreeSpec(
232*da0073e9SAndroid Build Coastguard Worker                        list, None, [py_pytree.LeafSpec() for _ in lst]
233*da0073e9SAndroid Build Coastguard Worker                    ),
234*da0073e9SAndroid Build Coastguard Worker                ),
235*da0073e9SAndroid Build Coastguard Worker                name="py",
236*da0073e9SAndroid Build Coastguard Worker            ),
237*da0073e9SAndroid Build Coastguard Worker            subtest(
238*da0073e9SAndroid Build Coastguard Worker                (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
239*da0073e9SAndroid Build Coastguard Worker                name="cxx",
240*da0073e9SAndroid Build Coastguard Worker            ),
241*da0073e9SAndroid Build Coastguard Worker        ],
242*da0073e9SAndroid Build Coastguard Worker    )
243*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
244*da0073e9SAndroid Build Coastguard Worker        def run_test(lst):
245*da0073e9SAndroid Build Coastguard Worker            expected_spec = gen_expected_fn(lst)
246*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(lst)
247*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
248*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, lst)
249*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
252*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, lst)
253*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, list)
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker        run_test([])
256*da0073e9SAndroid Build Coastguard Worker        run_test([1.0, 2])
257*da0073e9SAndroid Build Coastguard Worker        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker    @parametrize(
260*da0073e9SAndroid Build Coastguard Worker        "pytree_impl,gen_expected_fn",
261*da0073e9SAndroid Build Coastguard Worker        [
262*da0073e9SAndroid Build Coastguard Worker            subtest(
263*da0073e9SAndroid Build Coastguard Worker                (
264*da0073e9SAndroid Build Coastguard Worker                    py_pytree,
265*da0073e9SAndroid Build Coastguard Worker                    lambda dct: py_pytree.TreeSpec(
266*da0073e9SAndroid Build Coastguard Worker                        dict,
267*da0073e9SAndroid Build Coastguard Worker                        list(dct.keys()),
268*da0073e9SAndroid Build Coastguard Worker                        [py_pytree.LeafSpec() for _ in dct.values()],
269*da0073e9SAndroid Build Coastguard Worker                    ),
270*da0073e9SAndroid Build Coastguard Worker                ),
271*da0073e9SAndroid Build Coastguard Worker                name="py",
272*da0073e9SAndroid Build Coastguard Worker            ),
273*da0073e9SAndroid Build Coastguard Worker            subtest(
274*da0073e9SAndroid Build Coastguard Worker                (
275*da0073e9SAndroid Build Coastguard Worker                    cxx_pytree,
276*da0073e9SAndroid Build Coastguard Worker                    lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
277*da0073e9SAndroid Build Coastguard Worker                ),
278*da0073e9SAndroid Build Coastguard Worker                name="cxx",
279*da0073e9SAndroid Build Coastguard Worker            ),
280*da0073e9SAndroid Build Coastguard Worker        ],
281*da0073e9SAndroid Build Coastguard Worker    )
282*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
283*da0073e9SAndroid Build Coastguard Worker        def run_test(dct):
284*da0073e9SAndroid Build Coastguard Worker            expected_spec = gen_expected_fn(dct)
285*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(dct)
286*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, list(dct.values()))
288*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
289*da0073e9SAndroid Build Coastguard Worker
290*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
291*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, dct)
292*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, dict)
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker        run_test({})
295*da0073e9SAndroid Build Coastguard Worker        run_test({"a": 1})
296*da0073e9SAndroid Build Coastguard Worker        run_test({"abcdefg": torch.randn(2, 3)})
297*da0073e9SAndroid Build Coastguard Worker        run_test({1: torch.randn(2, 3)})
298*da0073e9SAndroid Build Coastguard Worker        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
299*da0073e9SAndroid Build Coastguard Worker
300*da0073e9SAndroid Build Coastguard Worker    @parametrize(
301*da0073e9SAndroid Build Coastguard Worker        "pytree_impl,gen_expected_fn",
302*da0073e9SAndroid Build Coastguard Worker        [
303*da0073e9SAndroid Build Coastguard Worker            subtest(
304*da0073e9SAndroid Build Coastguard Worker                (
305*da0073e9SAndroid Build Coastguard Worker                    py_pytree,
306*da0073e9SAndroid Build Coastguard Worker                    lambda odict: py_pytree.TreeSpec(
307*da0073e9SAndroid Build Coastguard Worker                        OrderedDict,
308*da0073e9SAndroid Build Coastguard Worker                        list(odict.keys()),
309*da0073e9SAndroid Build Coastguard Worker                        [py_pytree.LeafSpec() for _ in odict.values()],
310*da0073e9SAndroid Build Coastguard Worker                    ),
311*da0073e9SAndroid Build Coastguard Worker                ),
312*da0073e9SAndroid Build Coastguard Worker                name="py",
313*da0073e9SAndroid Build Coastguard Worker            ),
314*da0073e9SAndroid Build Coastguard Worker            subtest(
315*da0073e9SAndroid Build Coastguard Worker                (
316*da0073e9SAndroid Build Coastguard Worker                    cxx_pytree,
317*da0073e9SAndroid Build Coastguard Worker                    lambda odict: cxx_pytree.tree_structure(
318*da0073e9SAndroid Build Coastguard Worker                        OrderedDict.fromkeys(odict, 0)
319*da0073e9SAndroid Build Coastguard Worker                    ),
320*da0073e9SAndroid Build Coastguard Worker                ),
321*da0073e9SAndroid Build Coastguard Worker                name="cxx",
322*da0073e9SAndroid Build Coastguard Worker            ),
323*da0073e9SAndroid Build Coastguard Worker        ],
324*da0073e9SAndroid Build Coastguard Worker    )
325*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
326*da0073e9SAndroid Build Coastguard Worker        def run_test(odict):
327*da0073e9SAndroid Build Coastguard Worker            expected_spec = gen_expected_fn(odict)
328*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(odict)
329*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
330*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, list(odict.values()))
331*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
334*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, odict)
335*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, OrderedDict)
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker        od = OrderedDict()
338*da0073e9SAndroid Build Coastguard Worker        run_test(od)
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker        od["b"] = 1
341*da0073e9SAndroid Build Coastguard Worker        od["a"] = torch.tensor(3.14)
342*da0073e9SAndroid Build Coastguard Worker        run_test(od)
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker    @parametrize(
345*da0073e9SAndroid Build Coastguard Worker        "pytree_impl,gen_expected_fn",
346*da0073e9SAndroid Build Coastguard Worker        [
347*da0073e9SAndroid Build Coastguard Worker            subtest(
348*da0073e9SAndroid Build Coastguard Worker                (
349*da0073e9SAndroid Build Coastguard Worker                    py_pytree,
350*da0073e9SAndroid Build Coastguard Worker                    lambda ddct: py_pytree.TreeSpec(
351*da0073e9SAndroid Build Coastguard Worker                        defaultdict,
352*da0073e9SAndroid Build Coastguard Worker                        [ddct.default_factory, list(ddct.keys())],
353*da0073e9SAndroid Build Coastguard Worker                        [py_pytree.LeafSpec() for _ in ddct.values()],
354*da0073e9SAndroid Build Coastguard Worker                    ),
355*da0073e9SAndroid Build Coastguard Worker                ),
356*da0073e9SAndroid Build Coastguard Worker                name="py",
357*da0073e9SAndroid Build Coastguard Worker            ),
358*da0073e9SAndroid Build Coastguard Worker            subtest(
359*da0073e9SAndroid Build Coastguard Worker                (
360*da0073e9SAndroid Build Coastguard Worker                    cxx_pytree,
361*da0073e9SAndroid Build Coastguard Worker                    lambda ddct: cxx_pytree.tree_structure(
362*da0073e9SAndroid Build Coastguard Worker                        defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
363*da0073e9SAndroid Build Coastguard Worker                    ),
364*da0073e9SAndroid Build Coastguard Worker                ),
365*da0073e9SAndroid Build Coastguard Worker                name="cxx",
366*da0073e9SAndroid Build Coastguard Worker            ),
367*da0073e9SAndroid Build Coastguard Worker        ],
368*da0073e9SAndroid Build Coastguard Worker    )
369*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
370*da0073e9SAndroid Build Coastguard Worker        def run_test(ddct):
371*da0073e9SAndroid Build Coastguard Worker            expected_spec = gen_expected_fn(ddct)
372*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(ddct)
373*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
374*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, list(ddct.values()))
375*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
376*da0073e9SAndroid Build Coastguard Worker
377*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
378*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, ddct)
379*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened.default_factory, ddct.default_factory)
380*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, defaultdict)
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker        run_test(defaultdict(list, {}))
383*da0073e9SAndroid Build Coastguard Worker        run_test(defaultdict(int, {"a": 1}))
384*da0073e9SAndroid Build Coastguard Worker        run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
385*da0073e9SAndroid Build Coastguard Worker        run_test(defaultdict(int, {1: torch.randn(2, 3)}))
386*da0073e9SAndroid Build Coastguard Worker        run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))
387*da0073e9SAndroid Build Coastguard Worker
388*da0073e9SAndroid Build Coastguard Worker    @parametrize(
389*da0073e9SAndroid Build Coastguard Worker        "pytree_impl,gen_expected_fn",
390*da0073e9SAndroid Build Coastguard Worker        [
391*da0073e9SAndroid Build Coastguard Worker            subtest(
392*da0073e9SAndroid Build Coastguard Worker                (
393*da0073e9SAndroid Build Coastguard Worker                    py_pytree,
394*da0073e9SAndroid Build Coastguard Worker                    lambda deq: py_pytree.TreeSpec(
395*da0073e9SAndroid Build Coastguard Worker                        deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
396*da0073e9SAndroid Build Coastguard Worker                    ),
397*da0073e9SAndroid Build Coastguard Worker                ),
398*da0073e9SAndroid Build Coastguard Worker                name="py",
399*da0073e9SAndroid Build Coastguard Worker            ),
400*da0073e9SAndroid Build Coastguard Worker            subtest(
401*da0073e9SAndroid Build Coastguard Worker                (
402*da0073e9SAndroid Build Coastguard Worker                    cxx_pytree,
403*da0073e9SAndroid Build Coastguard Worker                    lambda deq: cxx_pytree.tree_structure(
404*da0073e9SAndroid Build Coastguard Worker                        deque(deq, maxlen=deq.maxlen)
405*da0073e9SAndroid Build Coastguard Worker                    ),
406*da0073e9SAndroid Build Coastguard Worker                ),
407*da0073e9SAndroid Build Coastguard Worker                name="cxx",
408*da0073e9SAndroid Build Coastguard Worker            ),
409*da0073e9SAndroid Build Coastguard Worker        ],
410*da0073e9SAndroid Build Coastguard Worker    )
411*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn):
412*da0073e9SAndroid Build Coastguard Worker        def run_test(deq):
413*da0073e9SAndroid Build Coastguard Worker            expected_spec = gen_expected_fn(deq)
414*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(deq)
415*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
416*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, list(deq))
417*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
418*da0073e9SAndroid Build Coastguard Worker
419*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
420*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, deq)
421*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened.maxlen, deq.maxlen)
422*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, deque)
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker        run_test(deque([]))
425*da0073e9SAndroid Build Coastguard Worker        run_test(deque([1.0, 2]))
426*da0073e9SAndroid Build Coastguard Worker        run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
427*da0073e9SAndroid Build Coastguard Worker
428*da0073e9SAndroid Build Coastguard Worker    @parametrize(
429*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
430*da0073e9SAndroid Build Coastguard Worker        [
431*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
432*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
433*da0073e9SAndroid Build Coastguard Worker        ],
434*da0073e9SAndroid Build Coastguard Worker    )
435*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_namedtuple(self, pytree_impl):
436*da0073e9SAndroid Build Coastguard Worker        Point = namedtuple("Point", ["x", "y"])
437*da0073e9SAndroid Build Coastguard Worker
438*da0073e9SAndroid Build Coastguard Worker        def run_test(tup):
439*da0073e9SAndroid Build Coastguard Worker            if pytree_impl is py_pytree:
440*da0073e9SAndroid Build Coastguard Worker                expected_spec = py_pytree.TreeSpec(
441*da0073e9SAndroid Build Coastguard Worker                    namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
442*da0073e9SAndroid Build Coastguard Worker                )
443*da0073e9SAndroid Build Coastguard Worker            else:
444*da0073e9SAndroid Build Coastguard Worker                expected_spec = cxx_pytree.tree_structure(Point(0, 1))
445*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(tup)
446*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
447*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, list(tup))
448*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(treespec, expected_spec)
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
451*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, tup)
452*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(unflattened, Point)
453*da0073e9SAndroid Build Coastguard Worker
454*da0073e9SAndroid Build Coastguard Worker        run_test(Point(1.0, 2))
455*da0073e9SAndroid Build Coastguard Worker        run_test(Point(torch.tensor(1.0), 2))
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker    @parametrize(
458*da0073e9SAndroid Build Coastguard Worker        "op",
459*da0073e9SAndroid Build Coastguard Worker        [
460*da0073e9SAndroid Build Coastguard Worker            subtest(torch.max, name="max"),
461*da0073e9SAndroid Build Coastguard Worker            subtest(torch.min, name="min"),
462*da0073e9SAndroid Build Coastguard Worker        ],
463*da0073e9SAndroid Build Coastguard Worker    )
464*da0073e9SAndroid Build Coastguard Worker    @parametrize(
465*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
466*da0073e9SAndroid Build Coastguard Worker        [
467*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
468*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
469*da0073e9SAndroid Build Coastguard Worker        ],
470*da0073e9SAndroid Build Coastguard Worker    )
471*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_return_types(self, pytree_impl, op):
472*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3)
473*da0073e9SAndroid Build Coastguard Worker        expected = op(x, dim=0)
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker        values, spec = pytree_impl.tree_flatten(expected)
476*da0073e9SAndroid Build Coastguard Worker        # Check that values is actually List[Tensor] and not (ReturnType(...),)
477*da0073e9SAndroid Build Coastguard Worker        for value in values:
478*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(value, torch.Tensor)
479*da0073e9SAndroid Build Coastguard Worker        result = pytree_impl.tree_unflatten(values, spec)
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(result), type(expected))
482*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
483*da0073e9SAndroid Build Coastguard Worker
484*da0073e9SAndroid Build Coastguard Worker    @parametrize(
485*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
486*da0073e9SAndroid Build Coastguard Worker        [
487*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
488*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
489*da0073e9SAndroid Build Coastguard Worker        ],
490*da0073e9SAndroid Build Coastguard Worker    )
491*da0073e9SAndroid Build Coastguard Worker    def test_flatten_unflatten_nested(self, pytree_impl):
492*da0073e9SAndroid Build Coastguard Worker        def run_test(pytree):
493*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(pytree)
494*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
495*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(values), treespec.num_leaves)
496*da0073e9SAndroid Build Coastguard Worker
497*da0073e9SAndroid Build Coastguard Worker            # NB: python basic data structures (dict list tuple) all have
498*da0073e9SAndroid Build Coastguard Worker            # contents equality defined on them, so the following works for them.
499*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
500*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, pytree)
501*da0073e9SAndroid Build Coastguard Worker
502*da0073e9SAndroid Build Coastguard Worker        cases = [
503*da0073e9SAndroid Build Coastguard Worker            [()],
504*da0073e9SAndroid Build Coastguard Worker            ([],),
505*da0073e9SAndroid Build Coastguard Worker            {"a": ()},
506*da0073e9SAndroid Build Coastguard Worker            {"a": 0, "b": [{"c": 1}]},
507*da0073e9SAndroid Build Coastguard Worker            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
508*da0073e9SAndroid Build Coastguard Worker        ]
509*da0073e9SAndroid Build Coastguard Worker        for case in cases:
510*da0073e9SAndroid Build Coastguard Worker            run_test(case)
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker    @parametrize(
513*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
514*da0073e9SAndroid Build Coastguard Worker        [
515*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
516*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
517*da0073e9SAndroid Build Coastguard Worker        ],
518*da0073e9SAndroid Build Coastguard Worker    )
519*da0073e9SAndroid Build Coastguard Worker    def test_flatten_with_is_leaf(self, pytree_impl):
520*da0073e9SAndroid Build Coastguard Worker        def run_test(pytree, one_level_leaves):
521*da0073e9SAndroid Build Coastguard Worker            values, treespec = pytree_impl.tree_flatten(
522*da0073e9SAndroid Build Coastguard Worker                pytree, is_leaf=lambda x: x is not pytree
523*da0073e9SAndroid Build Coastguard Worker            )
524*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(values, list)
525*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(values), treespec.num_nodes - 1)
526*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(values), treespec.num_leaves)
527*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(values), treespec.num_children)
528*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(values, one_level_leaves)
529*da0073e9SAndroid Build Coastguard Worker
530*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
531*da0073e9SAndroid Build Coastguard Worker                treespec,
532*da0073e9SAndroid Build Coastguard Worker                pytree_impl.tree_structure(
533*da0073e9SAndroid Build Coastguard Worker                    pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec)
534*da0073e9SAndroid Build Coastguard Worker                ),
535*da0073e9SAndroid Build Coastguard Worker            )
536*da0073e9SAndroid Build Coastguard Worker
537*da0073e9SAndroid Build Coastguard Worker            unflattened = pytree_impl.tree_unflatten(values, treespec)
538*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unflattened, pytree)
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker        cases = [
541*da0073e9SAndroid Build Coastguard Worker            ([()], [()]),
542*da0073e9SAndroid Build Coastguard Worker            (([],), [[]]),
543*da0073e9SAndroid Build Coastguard Worker            ({"a": ()}, [()]),
544*da0073e9SAndroid Build Coastguard Worker            ({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]),
545*da0073e9SAndroid Build Coastguard Worker            (
546*da0073e9SAndroid Build Coastguard Worker                {
547*da0073e9SAndroid Build Coastguard Worker                    "a": 0,
548*da0073e9SAndroid Build Coastguard Worker                    "b": [1, {"c": 2}, torch.ones(3)],
549*da0073e9SAndroid Build Coastguard Worker                    "c": (torch.zeros(2, 3), 1),
550*da0073e9SAndroid Build Coastguard Worker                },
551*da0073e9SAndroid Build Coastguard Worker                [0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)],
552*da0073e9SAndroid Build Coastguard Worker            ),
553*da0073e9SAndroid Build Coastguard Worker        ]
554*da0073e9SAndroid Build Coastguard Worker        for case in cases:
555*da0073e9SAndroid Build Coastguard Worker            run_test(*case)
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker    @parametrize(
558*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
559*da0073e9SAndroid Build Coastguard Worker        [
560*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
561*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
562*da0073e9SAndroid Build Coastguard Worker        ],
563*da0073e9SAndroid Build Coastguard Worker    )
564*da0073e9SAndroid Build Coastguard Worker    def test_tree_map(self, pytree_impl):
565*da0073e9SAndroid Build Coastguard Worker        def run_test(pytree):
566*da0073e9SAndroid Build Coastguard Worker            def f(x):
567*da0073e9SAndroid Build Coastguard Worker                return x * 3
568*da0073e9SAndroid Build Coastguard Worker
569*da0073e9SAndroid Build Coastguard Worker            sm1 = sum(map(f, pytree_impl.tree_leaves(pytree)))
570*da0073e9SAndroid Build Coastguard Worker            sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree)))
571*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(sm1, sm2)
572*da0073e9SAndroid Build Coastguard Worker
573*da0073e9SAndroid Build Coastguard Worker            def invf(x):
574*da0073e9SAndroid Build Coastguard Worker                return x // 3
575*da0073e9SAndroid Build Coastguard Worker
576*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
577*da0073e9SAndroid Build Coastguard Worker                pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)),
578*da0073e9SAndroid Build Coastguard Worker                pytree,
579*da0073e9SAndroid Build Coastguard Worker            )
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker        cases = [
582*da0073e9SAndroid Build Coastguard Worker            [()],
583*da0073e9SAndroid Build Coastguard Worker            ([],),
584*da0073e9SAndroid Build Coastguard Worker            {"a": ()},
585*da0073e9SAndroid Build Coastguard Worker            {"a": 1, "b": [{"c": 2}]},
586*da0073e9SAndroid Build Coastguard Worker            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
587*da0073e9SAndroid Build Coastguard Worker        ]
588*da0073e9SAndroid Build Coastguard Worker        for case in cases:
589*da0073e9SAndroid Build Coastguard Worker            run_test(case)
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker    @parametrize(
592*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
593*da0073e9SAndroid Build Coastguard Worker        [
594*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
595*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
596*da0073e9SAndroid Build Coastguard Worker        ],
597*da0073e9SAndroid Build Coastguard Worker    )
598*da0073e9SAndroid Build Coastguard Worker    def test_tree_map_multi_inputs(self, pytree_impl):
599*da0073e9SAndroid Build Coastguard Worker        def run_test(pytree):
600*da0073e9SAndroid Build Coastguard Worker            def f(x, y, z):
601*da0073e9SAndroid Build Coastguard Worker                return x, [y, (z, 0)]
602*da0073e9SAndroid Build Coastguard Worker
603*da0073e9SAndroid Build Coastguard Worker            pytree_x = pytree
604*da0073e9SAndroid Build Coastguard Worker            pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree)
605*da0073e9SAndroid Build Coastguard Worker            pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree)
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
608*da0073e9SAndroid Build Coastguard Worker                pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z),
609*da0073e9SAndroid Build Coastguard Worker                pytree_impl.tree_map(
610*da0073e9SAndroid Build Coastguard Worker                    lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree
611*da0073e9SAndroid Build Coastguard Worker                ),
612*da0073e9SAndroid Build Coastguard Worker            )
613*da0073e9SAndroid Build Coastguard Worker
614*da0073e9SAndroid Build Coastguard Worker        cases = [
615*da0073e9SAndroid Build Coastguard Worker            [()],
616*da0073e9SAndroid Build Coastguard Worker            ([],),
617*da0073e9SAndroid Build Coastguard Worker            {"a": ()},
618*da0073e9SAndroid Build Coastguard Worker            {"a": 1, "b": [{"c": 2}]},
619*da0073e9SAndroid Build Coastguard Worker            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
620*da0073e9SAndroid Build Coastguard Worker        ]
621*da0073e9SAndroid Build Coastguard Worker        for case in cases:
622*da0073e9SAndroid Build Coastguard Worker            run_test(case)
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker    @parametrize(
625*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
626*da0073e9SAndroid Build Coastguard Worker        [
627*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
628*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
629*da0073e9SAndroid Build Coastguard Worker        ],
630*da0073e9SAndroid Build Coastguard Worker    )
631*da0073e9SAndroid Build Coastguard Worker    def test_tree_map_only(self, pytree_impl):
632*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
633*da0073e9SAndroid Build Coastguard Worker            pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
634*da0073e9SAndroid Build Coastguard Worker        )
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker    @parametrize(
637*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
638*da0073e9SAndroid Build Coastguard Worker        [
639*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
640*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
641*da0073e9SAndroid Build Coastguard Worker        ],
642*da0073e9SAndroid Build Coastguard Worker    )
643*da0073e9SAndroid Build Coastguard Worker    def test_tree_map_only_predicate_fn(self, pytree_impl):
644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
645*da0073e9SAndroid Build Coastguard Worker            pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
646*da0073e9SAndroid Build Coastguard Worker        )
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker    @parametrize(
649*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
650*da0073e9SAndroid Build Coastguard Worker        [
651*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
652*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
653*da0073e9SAndroid Build Coastguard Worker        ],
654*da0073e9SAndroid Build Coastguard Worker    )
655*da0073e9SAndroid Build Coastguard Worker    def test_tree_all_any(self, pytree_impl):
656*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3]))
657*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1]))
658*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1]))
659*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2]))
660*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
661*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
662*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
663*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
664*da0073e9SAndroid Build Coastguard Worker
665*da0073e9SAndroid Build Coastguard Worker    @parametrize(
666*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
667*da0073e9SAndroid Build Coastguard Worker        [
668*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
669*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
670*da0073e9SAndroid Build Coastguard Worker        ],
671*da0073e9SAndroid Build Coastguard Worker    )
672*da0073e9SAndroid Build Coastguard Worker    def test_broadcast_to_and_flatten(self, pytree_impl):
673*da0073e9SAndroid Build Coastguard Worker        cases = [
674*da0073e9SAndroid Build Coastguard Worker            (1, (), []),
675*da0073e9SAndroid Build Coastguard Worker            # Same (flat) structures
676*da0073e9SAndroid Build Coastguard Worker            ((1,), (0,), [1]),
677*da0073e9SAndroid Build Coastguard Worker            ([1], [0], [1]),
678*da0073e9SAndroid Build Coastguard Worker            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
679*da0073e9SAndroid Build Coastguard Worker            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
680*da0073e9SAndroid Build Coastguard Worker            # Mismatched (flat) structures
681*da0073e9SAndroid Build Coastguard Worker            ([1], (0,), None),
682*da0073e9SAndroid Build Coastguard Worker            ([1], (0,), None),
683*da0073e9SAndroid Build Coastguard Worker            ((1,), [0], None),
684*da0073e9SAndroid Build Coastguard Worker            ((1, 2, 3), (0, 0), None),
685*da0073e9SAndroid Build Coastguard Worker            ({"a": 1, "b": 2}, {"a": 0}, None),
686*da0073e9SAndroid Build Coastguard Worker            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
687*da0073e9SAndroid Build Coastguard Worker            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
688*da0073e9SAndroid Build Coastguard Worker            # Same (nested) structures
689*da0073e9SAndroid Build Coastguard Worker            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
690*da0073e9SAndroid Build Coastguard Worker            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
691*da0073e9SAndroid Build Coastguard Worker            # Mismatched (nested) structures
692*da0073e9SAndroid Build Coastguard Worker            ((1, [2, 3]), (0, (0, 0)), None),
693*da0073e9SAndroid Build Coastguard Worker            ((1, [2, 3]), (0, [0, 0, 0]), None),
694*da0073e9SAndroid Build Coastguard Worker            # Broadcasting single value
695*da0073e9SAndroid Build Coastguard Worker            (1, (0, 0, 0), [1, 1, 1]),
696*da0073e9SAndroid Build Coastguard Worker            (1, [0, 0, 0], [1, 1, 1]),
697*da0073e9SAndroid Build Coastguard Worker            (1, {"a": 0, "b": 0}, [1, 1]),
698*da0073e9SAndroid Build Coastguard Worker            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
699*da0073e9SAndroid Build Coastguard Worker            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
700*da0073e9SAndroid Build Coastguard Worker            # Broadcast multiple things
701*da0073e9SAndroid Build Coastguard Worker            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
702*da0073e9SAndroid Build Coastguard Worker            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
703*da0073e9SAndroid Build Coastguard Worker            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
704*da0073e9SAndroid Build Coastguard Worker        ]
705*da0073e9SAndroid Build Coastguard Worker        for pytree, to_pytree, expected in cases:
706*da0073e9SAndroid Build Coastguard Worker            _, to_spec = pytree_impl.tree_flatten(to_pytree)
707*da0073e9SAndroid Build Coastguard Worker            result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
708*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    @parametrize(
711*da0073e9SAndroid Build Coastguard Worker        "pytree_impl",
712*da0073e9SAndroid Build Coastguard Worker        [
713*da0073e9SAndroid Build Coastguard Worker            subtest(py_pytree, name="py"),
714*da0073e9SAndroid Build Coastguard Worker            subtest(cxx_pytree, name="cxx"),
715*da0073e9SAndroid Build Coastguard Worker        ],
716*da0073e9SAndroid Build Coastguard Worker    )
717*da0073e9SAndroid Build Coastguard Worker    def test_pytree_serialize_bad_input(self, pytree_impl):
718*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
719*da0073e9SAndroid Build Coastguard Worker            pytree_impl.treespec_dumps("random_blurb")
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker
722*da0073e9SAndroid Build Coastguard Workerclass TestPythonPytree(TestCase):
723*da0073e9SAndroid Build Coastguard Worker    def test_deprecated_register_pytree_node(self):
724*da0073e9SAndroid Build Coastguard Worker        class DummyType:
725*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x, y):
726*da0073e9SAndroid Build Coastguard Worker                self.x = x
727*da0073e9SAndroid Build Coastguard Worker                self.y = y
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
730*da0073e9SAndroid Build Coastguard Worker            FutureWarning, "torch.utils._pytree._register_pytree_node"
731*da0073e9SAndroid Build Coastguard Worker        ):
732*da0073e9SAndroid Build Coastguard Worker            py_pytree._register_pytree_node(
733*da0073e9SAndroid Build Coastguard Worker                DummyType,
734*da0073e9SAndroid Build Coastguard Worker                lambda dummy: ([dummy.x, dummy.y], None),
735*da0073e9SAndroid Build Coastguard Worker                lambda xs, _: DummyType(*xs),
736*da0073e9SAndroid Build Coastguard Worker            )
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "already registered"):
739*da0073e9SAndroid Build Coastguard Worker            py_pytree._register_pytree_node(
740*da0073e9SAndroid Build Coastguard Worker                DummyType,
741*da0073e9SAndroid Build Coastguard Worker                lambda dummy: ([dummy.x, dummy.y], None),
742*da0073e9SAndroid Build Coastguard Worker                lambda xs, _: DummyType(*xs),
743*da0073e9SAndroid Build Coastguard Worker            )
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker    def test_import_pytree_doesnt_import_optree(self):
746*da0073e9SAndroid Build Coastguard Worker        # importing torch.utils._pytree shouldn't import optree.
747*da0073e9SAndroid Build Coastguard Worker        # only importing torch.utils._cxx_pytree should.
748*da0073e9SAndroid Build Coastguard Worker        script = """
749*da0073e9SAndroid Build Coastguard Workerimport sys
750*da0073e9SAndroid Build Coastguard Workerimport torch
751*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree
752*da0073e9SAndroid Build Coastguard Workerassert "torch.utils._pytree" in sys.modules
753*da0073e9SAndroid Build Coastguard Workerif "torch.utils._cxx_pytree" in sys.modules:
754*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
755*da0073e9SAndroid Build Coastguard Workerif "optree" in sys.modules:
756*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError("importing torch.utils._pytree should not import optree")
757*da0073e9SAndroid Build Coastguard Worker"""
758*da0073e9SAndroid Build Coastguard Worker        try:
759*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output(
760*da0073e9SAndroid Build Coastguard Worker                [sys.executable, "-c", script],
761*da0073e9SAndroid Build Coastguard Worker                stderr=subprocess.STDOUT,
762*da0073e9SAndroid Build Coastguard Worker                # On Windows, opening the subprocess with the default CWD makes `import torch`
763*da0073e9SAndroid Build Coastguard Worker                # fail, so just set CWD to this script's directory
764*da0073e9SAndroid Build Coastguard Worker                cwd=os.path.dirname(os.path.realpath(__file__)),
765*da0073e9SAndroid Build Coastguard Worker            )
766*da0073e9SAndroid Build Coastguard Worker        except subprocess.CalledProcessError as e:
767*da0073e9SAndroid Build Coastguard Worker            self.fail(
768*da0073e9SAndroid Build Coastguard Worker                msg=(
769*da0073e9SAndroid Build Coastguard Worker                    "Subprocess exception while attempting to run test: "
770*da0073e9SAndroid Build Coastguard Worker                    + e.output.decode("utf-8")
771*da0073e9SAndroid Build Coastguard Worker                )
772*da0073e9SAndroid Build Coastguard Worker            )
773*da0073e9SAndroid Build Coastguard Worker
774*da0073e9SAndroid Build Coastguard Worker    def test_treespec_equality(self):
775*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
776*da0073e9SAndroid Build Coastguard Worker            py_pytree.LeafSpec(),
777*da0073e9SAndroid Build Coastguard Worker            py_pytree.LeafSpec(),
778*da0073e9SAndroid Build Coastguard Worker        )
779*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
780*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(list, None, []),
781*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(list, None, []),
782*da0073e9SAndroid Build Coastguard Worker        )
783*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
784*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
785*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
786*da0073e9SAndroid Build Coastguard Worker        )
787*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(
788*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
789*da0073e9SAndroid Build Coastguard Worker        )
790*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
791*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
792*da0073e9SAndroid Build Coastguard Worker        )
793*da0073e9SAndroid Build Coastguard Worker
794*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
795*da0073e9SAndroid Build Coastguard Worker    def test_treespec_repr(self):
796*da0073e9SAndroid Build Coastguard Worker        # Check that it looks sane
797*da0073e9SAndroid Build Coastguard Worker        pytree = (0, [0, 0, [0]])
798*da0073e9SAndroid Build Coastguard Worker        _, spec = py_pytree.tree_flatten(pytree)
799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
800*da0073e9SAndroid Build Coastguard Worker            repr(spec),
801*da0073e9SAndroid Build Coastguard Worker            (
802*da0073e9SAndroid Build Coastguard Worker                "TreeSpec(tuple, None, [*,\n"
803*da0073e9SAndroid Build Coastguard Worker                "  TreeSpec(list, None, [*,\n"
804*da0073e9SAndroid Build Coastguard Worker                "    *,\n"
805*da0073e9SAndroid Build Coastguard Worker                "    TreeSpec(list, None, [*])])])"
806*da0073e9SAndroid Build Coastguard Worker            ),
807*da0073e9SAndroid Build Coastguard Worker        )
808*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        # Dynamo-mode counterpart of test_treespec_repr: same tree, same expected
        # repr, but checked with assertExpectedInline.
        # NOTE(review): assertExpectedInline can rewrite the inline literal below
        # at this call site when expected output is re-accepted, so keep the
        # literal as the direct argument of the call.
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = py_pytree.tree_flatten(pytree)
        self.assertExpectedInline(
            repr(spec),
            """\
TreeSpec(tuple, None, [*,
  TreeSpec(list, None, [*,
    *,
    TreeSpec(list, None, [*])])])""",
        )
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker    @parametrize(
824*da0073e9SAndroid Build Coastguard Worker        "spec",
825*da0073e9SAndroid Build Coastguard Worker        [
826*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure([])
827*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(list, None, []),
828*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure(())
829*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(tuple, None, []),
830*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure({})
831*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(dict, [], []),
832*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure([0])
833*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
834*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure([0, 1])
835*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(
836*da0073e9SAndroid Build Coastguard Worker                list,
837*da0073e9SAndroid Build Coastguard Worker                None,
838*da0073e9SAndroid Build Coastguard Worker                [
839*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
840*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
841*da0073e9SAndroid Build Coastguard Worker                ],
842*da0073e9SAndroid Build Coastguard Worker            ),
843*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure((0, 1, 2))
844*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(
845*da0073e9SAndroid Build Coastguard Worker                tuple,
846*da0073e9SAndroid Build Coastguard Worker                None,
847*da0073e9SAndroid Build Coastguard Worker                [
848*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
849*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
850*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
851*da0073e9SAndroid Build Coastguard Worker                ],
852*da0073e9SAndroid Build Coastguard Worker            ),
853*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
854*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(
855*da0073e9SAndroid Build Coastguard Worker                dict,
856*da0073e9SAndroid Build Coastguard Worker                ["a", "b", "c"],
857*da0073e9SAndroid Build Coastguard Worker                [
858*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
859*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
860*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
861*da0073e9SAndroid Build Coastguard Worker                ],
862*da0073e9SAndroid Build Coastguard Worker            ),
863*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
864*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(
865*da0073e9SAndroid Build Coastguard Worker                OrderedDict,
866*da0073e9SAndroid Build Coastguard Worker                ["a", "b", "c"],
867*da0073e9SAndroid Build Coastguard Worker                [
868*da0073e9SAndroid Build Coastguard Worker                    py_pytree.TreeSpec(
869*da0073e9SAndroid Build Coastguard Worker                        tuple,
870*da0073e9SAndroid Build Coastguard Worker                        None,
871*da0073e9SAndroid Build Coastguard Worker                        [
872*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
873*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
874*da0073e9SAndroid Build Coastguard Worker                        ],
875*da0073e9SAndroid Build Coastguard Worker                    ),
876*da0073e9SAndroid Build Coastguard Worker                    py_pytree.LeafSpec(),
877*da0073e9SAndroid Build Coastguard Worker                    py_pytree.TreeSpec(
878*da0073e9SAndroid Build Coastguard Worker                        dict,
879*da0073e9SAndroid Build Coastguard Worker                        ["a", "b", "c"],
880*da0073e9SAndroid Build Coastguard Worker                        [
881*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
882*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
883*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
884*da0073e9SAndroid Build Coastguard Worker                        ],
885*da0073e9SAndroid Build Coastguard Worker                    ),
886*da0073e9SAndroid Build Coastguard Worker                ],
887*da0073e9SAndroid Build Coastguard Worker            ),
888*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure([(0, 1, [2, 3])])
889*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(
890*da0073e9SAndroid Build Coastguard Worker                list,
891*da0073e9SAndroid Build Coastguard Worker                None,
892*da0073e9SAndroid Build Coastguard Worker                [
893*da0073e9SAndroid Build Coastguard Worker                    py_pytree.TreeSpec(
894*da0073e9SAndroid Build Coastguard Worker                        tuple,
895*da0073e9SAndroid Build Coastguard Worker                        None,
896*da0073e9SAndroid Build Coastguard Worker                        [
897*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
898*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
899*da0073e9SAndroid Build Coastguard Worker                            py_pytree.TreeSpec(
900*da0073e9SAndroid Build Coastguard Worker                                list,
901*da0073e9SAndroid Build Coastguard Worker                                None,
902*da0073e9SAndroid Build Coastguard Worker                                [
903*da0073e9SAndroid Build Coastguard Worker                                    py_pytree.LeafSpec(),
904*da0073e9SAndroid Build Coastguard Worker                                    py_pytree.LeafSpec(),
905*da0073e9SAndroid Build Coastguard Worker                                ],
906*da0073e9SAndroid Build Coastguard Worker                            ),
907*da0073e9SAndroid Build Coastguard Worker                        ],
908*da0073e9SAndroid Build Coastguard Worker                    ),
909*da0073e9SAndroid Build Coastguard Worker                ],
910*da0073e9SAndroid Build Coastguard Worker            ),
911*da0073e9SAndroid Build Coastguard Worker            # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}))
912*da0073e9SAndroid Build Coastguard Worker            py_pytree.TreeSpec(
913*da0073e9SAndroid Build Coastguard Worker                defaultdict,
914*da0073e9SAndroid Build Coastguard Worker                [list, ["a", "b", "c"]],
915*da0073e9SAndroid Build Coastguard Worker                [
916*da0073e9SAndroid Build Coastguard Worker                    py_pytree.TreeSpec(
917*da0073e9SAndroid Build Coastguard Worker                        list,
918*da0073e9SAndroid Build Coastguard Worker                        None,
919*da0073e9SAndroid Build Coastguard Worker                        [
920*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
921*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
922*da0073e9SAndroid Build Coastguard Worker                        ],
923*da0073e9SAndroid Build Coastguard Worker                    ),
924*da0073e9SAndroid Build Coastguard Worker                    py_pytree.TreeSpec(
925*da0073e9SAndroid Build Coastguard Worker                        list,
926*da0073e9SAndroid Build Coastguard Worker                        None,
927*da0073e9SAndroid Build Coastguard Worker                        [
928*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
929*da0073e9SAndroid Build Coastguard Worker                            py_pytree.LeafSpec(),
930*da0073e9SAndroid Build Coastguard Worker                        ],
931*da0073e9SAndroid Build Coastguard Worker                    ),
932*da0073e9SAndroid Build Coastguard Worker                    py_pytree.TreeSpec(dict, [], []),
933*da0073e9SAndroid Build Coastguard Worker                ],
934*da0073e9SAndroid Build Coastguard Worker            ),
935*da0073e9SAndroid Build Coastguard Worker        ],
936*da0073e9SAndroid Build Coastguard Worker    )
937*da0073e9SAndroid Build Coastguard Worker    def test_pytree_serialize(self, spec):
938*da0073e9SAndroid Build Coastguard Worker        # Ensure that the spec is valid
939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
940*da0073e9SAndroid Build Coastguard Worker            spec,
941*da0073e9SAndroid Build Coastguard Worker            py_pytree.tree_structure(
942*da0073e9SAndroid Build Coastguard Worker                py_pytree.tree_unflatten([0] * spec.num_leaves, spec)
943*da0073e9SAndroid Build Coastguard Worker            ),
944*da0073e9SAndroid Build Coastguard Worker        )
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Worker        serialized_spec = py_pytree.treespec_dumps(spec)
947*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(serialized_spec, str)
948*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
949*da0073e9SAndroid Build Coastguard Worker
950*da0073e9SAndroid Build Coastguard Worker    def test_pytree_serialize_namedtuple(self):
951*da0073e9SAndroid Build Coastguard Worker        Point1 = namedtuple("Point1", ["x", "y"])
952*da0073e9SAndroid Build Coastguard Worker        py_pytree._register_namedtuple(
953*da0073e9SAndroid Build Coastguard Worker            Point1,
954*da0073e9SAndroid Build Coastguard Worker            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
955*da0073e9SAndroid Build Coastguard Worker        )
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
958*da0073e9SAndroid Build Coastguard Worker            namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
959*da0073e9SAndroid Build Coastguard Worker        )
960*da0073e9SAndroid Build Coastguard Worker        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(spec, roundtrip_spec)
962*da0073e9SAndroid Build Coastguard Worker
963*da0073e9SAndroid Build Coastguard Worker        class Point2(NamedTuple):
964*da0073e9SAndroid Build Coastguard Worker            x: int
965*da0073e9SAndroid Build Coastguard Worker            y: int
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker        py_pytree._register_namedtuple(
968*da0073e9SAndroid Build Coastguard Worker            Point2,
969*da0073e9SAndroid Build Coastguard Worker            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
970*da0073e9SAndroid Build Coastguard Worker        )
971*da0073e9SAndroid Build Coastguard Worker
972*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
973*da0073e9SAndroid Build Coastguard Worker            namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
974*da0073e9SAndroid Build Coastguard Worker        )
975*da0073e9SAndroid Build Coastguard Worker        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
976*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(spec, roundtrip_spec)
977*da0073e9SAndroid Build Coastguard Worker
978*da0073e9SAndroid Build Coastguard Worker    def test_pytree_serialize_namedtuple_bad(self):
979*da0073e9SAndroid Build Coastguard Worker        DummyType = namedtuple("DummyType", ["x", "y"])
980*da0073e9SAndroid Build Coastguard Worker
981*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
982*da0073e9SAndroid Build Coastguard Worker            namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
983*da0073e9SAndroid Build Coastguard Worker        )
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
986*da0073e9SAndroid Build Coastguard Worker            NotImplementedError, "Please register using `_register_namedtuple`"
987*da0073e9SAndroid Build Coastguard Worker        ):
988*da0073e9SAndroid Build Coastguard Worker            py_pytree.treespec_dumps(spec)
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker    def test_pytree_custom_type_serialize_bad(self):
991*da0073e9SAndroid Build Coastguard Worker        class DummyType:
992*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x, y):
993*da0073e9SAndroid Build Coastguard Worker                self.x = x
994*da0073e9SAndroid Build Coastguard Worker                self.y = y
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker        py_pytree.register_pytree_node(
997*da0073e9SAndroid Build Coastguard Worker            DummyType,
998*da0073e9SAndroid Build Coastguard Worker            lambda dummy: ([dummy.x, dummy.y], None),
999*da0073e9SAndroid Build Coastguard Worker            lambda xs, _: DummyType(*xs),
1000*da0073e9SAndroid Build Coastguard Worker        )
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
1003*da0073e9SAndroid Build Coastguard Worker            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1004*da0073e9SAndroid Build Coastguard Worker        )
1005*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1006*da0073e9SAndroid Build Coastguard Worker            NotImplementedError, "No registered serialization name"
1007*da0073e9SAndroid Build Coastguard Worker        ):
1008*da0073e9SAndroid Build Coastguard Worker            roundtrip_spec = py_pytree.treespec_dumps(spec)
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker    def test_pytree_custom_type_serialize(self):
1011*da0073e9SAndroid Build Coastguard Worker        class DummyType:
1012*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x, y):
1013*da0073e9SAndroid Build Coastguard Worker                self.x = x
1014*da0073e9SAndroid Build Coastguard Worker                self.y = y
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker        py_pytree.register_pytree_node(
1017*da0073e9SAndroid Build Coastguard Worker            DummyType,
1018*da0073e9SAndroid Build Coastguard Worker            lambda dummy: ([dummy.x, dummy.y], None),
1019*da0073e9SAndroid Build Coastguard Worker            lambda xs, _: DummyType(*xs),
1020*da0073e9SAndroid Build Coastguard Worker            serialized_type_name="test_pytree_custom_type_serialize.DummyType",
1021*da0073e9SAndroid Build Coastguard Worker            to_dumpable_context=lambda context: "moo",
1022*da0073e9SAndroid Build Coastguard Worker            from_dumpable_context=lambda dumpable_context: None,
1023*da0073e9SAndroid Build Coastguard Worker        )
1024*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
1025*da0073e9SAndroid Build Coastguard Worker            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1026*da0073e9SAndroid Build Coastguard Worker        )
1027*da0073e9SAndroid Build Coastguard Worker        serialized_spec = py_pytree.treespec_dumps(spec, 1)
1028*da0073e9SAndroid Build Coastguard Worker        self.assertIn("moo", serialized_spec)
1029*da0073e9SAndroid Build Coastguard Worker        roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
1030*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(roundtrip_spec, spec)
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker    def test_pytree_serialize_register_bad(self):
1033*da0073e9SAndroid Build Coastguard Worker        class DummyType:
1034*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x, y):
1035*da0073e9SAndroid Build Coastguard Worker                self.x = x
1036*da0073e9SAndroid Build Coastguard Worker                self.y = y
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1039*da0073e9SAndroid Build Coastguard Worker            ValueError, "Both to_dumpable_context and from_dumpable_context"
1040*da0073e9SAndroid Build Coastguard Worker        ):
1041*da0073e9SAndroid Build Coastguard Worker            py_pytree.register_pytree_node(
1042*da0073e9SAndroid Build Coastguard Worker                DummyType,
1043*da0073e9SAndroid Build Coastguard Worker                lambda dummy: ([dummy.x, dummy.y], None),
1044*da0073e9SAndroid Build Coastguard Worker                lambda xs, _: DummyType(*xs),
1045*da0073e9SAndroid Build Coastguard Worker                serialized_type_name="test_pytree_serialize_register_bad.DummyType",
1046*da0073e9SAndroid Build Coastguard Worker                to_dumpable_context=lambda context: "moo",
1047*da0073e9SAndroid Build Coastguard Worker            )
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker    def test_pytree_context_serialize_bad(self):
1050*da0073e9SAndroid Build Coastguard Worker        class DummyType:
1051*da0073e9SAndroid Build Coastguard Worker            def __init__(self, x, y):
1052*da0073e9SAndroid Build Coastguard Worker                self.x = x
1053*da0073e9SAndroid Build Coastguard Worker                self.y = y
1054*da0073e9SAndroid Build Coastguard Worker
1055*da0073e9SAndroid Build Coastguard Worker        py_pytree.register_pytree_node(
1056*da0073e9SAndroid Build Coastguard Worker            DummyType,
1057*da0073e9SAndroid Build Coastguard Worker            lambda dummy: ([dummy.x, dummy.y], None),
1058*da0073e9SAndroid Build Coastguard Worker            lambda xs, _: DummyType(*xs),
1059*da0073e9SAndroid Build Coastguard Worker            serialized_type_name="test_pytree_serialize_serialize_bad.DummyType",
1060*da0073e9SAndroid Build Coastguard Worker            to_dumpable_context=lambda context: DummyType,
1061*da0073e9SAndroid Build Coastguard Worker            from_dumpable_context=lambda dumpable_context: None,
1062*da0073e9SAndroid Build Coastguard Worker        )
1063*da0073e9SAndroid Build Coastguard Worker
1064*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
1065*da0073e9SAndroid Build Coastguard Worker            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1066*da0073e9SAndroid Build Coastguard Worker        )
1067*da0073e9SAndroid Build Coastguard Worker
1068*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1069*da0073e9SAndroid Build Coastguard Worker            TypeError, "Object of type type is not JSON serializable"
1070*da0073e9SAndroid Build Coastguard Worker        ):
1071*da0073e9SAndroid Build Coastguard Worker            py_pytree.treespec_dumps(spec)
1072*da0073e9SAndroid Build Coastguard Worker
1073*da0073e9SAndroid Build Coastguard Worker    def test_pytree_serialize_bad_protocol(self):
1074*da0073e9SAndroid Build Coastguard Worker        import json
1075*da0073e9SAndroid Build Coastguard Worker
1076*da0073e9SAndroid Build Coastguard Worker        Point = namedtuple("Point", ["x", "y"])
1077*da0073e9SAndroid Build Coastguard Worker        spec = py_pytree.TreeSpec(
1078*da0073e9SAndroid Build Coastguard Worker            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1079*da0073e9SAndroid Build Coastguard Worker        )
1080*da0073e9SAndroid Build Coastguard Worker        py_pytree._register_namedtuple(
1081*da0073e9SAndroid Build Coastguard Worker            Point,
1082*da0073e9SAndroid Build Coastguard Worker            serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
1083*da0073e9SAndroid Build Coastguard Worker        )
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
1086*da0073e9SAndroid Build Coastguard Worker            py_pytree.treespec_dumps(spec, -1)
1087*da0073e9SAndroid Build Coastguard Worker
1088*da0073e9SAndroid Build Coastguard Worker        serialized_spec = py_pytree.treespec_dumps(spec)
1089*da0073e9SAndroid Build Coastguard Worker        protocol, data = json.loads(serialized_spec)
1090*da0073e9SAndroid Build Coastguard Worker        bad_protocol_serialized_spec = json.dumps((-1, data))
1091*da0073e9SAndroid Build Coastguard Worker
1092*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
1093*da0073e9SAndroid Build Coastguard Worker            py_pytree.treespec_loads(bad_protocol_serialized_spec)
1094*da0073e9SAndroid Build Coastguard Worker
1095*da0073e9SAndroid Build Coastguard Worker    def test_saved_serialized(self):
1096*da0073e9SAndroid Build Coastguard Worker        # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})]))
1097*da0073e9SAndroid Build Coastguard Worker        complicated_spec = py_pytree.TreeSpec(
1098*da0073e9SAndroid Build Coastguard Worker            OrderedDict,
1099*da0073e9SAndroid Build Coastguard Worker            [1, 2, 3],
1100*da0073e9SAndroid Build Coastguard Worker            [
1101*da0073e9SAndroid Build Coastguard Worker                py_pytree.TreeSpec(
1102*da0073e9SAndroid Build Coastguard Worker                    tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1103*da0073e9SAndroid Build Coastguard Worker                ),
1104*da0073e9SAndroid Build Coastguard Worker                py_pytree.LeafSpec(),
1105*da0073e9SAndroid Build Coastguard Worker                py_pytree.TreeSpec(
1106*da0073e9SAndroid Build Coastguard Worker                    dict,
1107*da0073e9SAndroid Build Coastguard Worker                    [4, 5, 6],
1108*da0073e9SAndroid Build Coastguard Worker                    [
1109*da0073e9SAndroid Build Coastguard Worker                        py_pytree.LeafSpec(),
1110*da0073e9SAndroid Build Coastguard Worker                        py_pytree.LeafSpec(),
1111*da0073e9SAndroid Build Coastguard Worker                        py_pytree.LeafSpec(),
1112*da0073e9SAndroid Build Coastguard Worker                    ],
1113*da0073e9SAndroid Build Coastguard Worker                ),
1114*da0073e9SAndroid Build Coastguard Worker            ],
1115*da0073e9SAndroid Build Coastguard Worker        )
1116*da0073e9SAndroid Build Coastguard Worker        # Ensure that the spec is valid
1117*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1118*da0073e9SAndroid Build Coastguard Worker            complicated_spec,
1119*da0073e9SAndroid Build Coastguard Worker            py_pytree.tree_structure(
1120*da0073e9SAndroid Build Coastguard Worker                py_pytree.tree_unflatten(
1121*da0073e9SAndroid Build Coastguard Worker                    [0] * complicated_spec.num_leaves, complicated_spec
1122*da0073e9SAndroid Build Coastguard Worker                )
1123*da0073e9SAndroid Build Coastguard Worker            ),
1124*da0073e9SAndroid Build Coastguard Worker        )
1125*da0073e9SAndroid Build Coastguard Worker
1126*da0073e9SAndroid Build Coastguard Worker        serialized_spec = py_pytree.treespec_dumps(complicated_spec)
1127*da0073e9SAndroid Build Coastguard Worker        saved_spec = (
1128*da0073e9SAndroid Build Coastguard Worker            '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
1129*da0073e9SAndroid Build Coastguard Worker            '"children_spec": [{"type": "builtins.tuple", "context": "null", '
1130*da0073e9SAndroid Build Coastguard Worker            '"children_spec": [{"type": null, "context": null, '
1131*da0073e9SAndroid Build Coastguard Worker            '"children_spec": []}, {"type": null, "context": null, '
1132*da0073e9SAndroid Build Coastguard Worker            '"children_spec": []}]}, {"type": null, "context": null, '
1133*da0073e9SAndroid Build Coastguard Worker            '"children_spec": []}, {"type": "builtins.dict", "context": '
1134*da0073e9SAndroid Build Coastguard Worker            '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
1135*da0073e9SAndroid Build Coastguard Worker            '"children_spec": []}, {"type": null, "context": null, "children_spec": '
1136*da0073e9SAndroid Build Coastguard Worker            '[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
1137*da0073e9SAndroid Build Coastguard Worker        )
1138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(serialized_spec, saved_spec)
1139*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker    def test_tree_map_with_path(self):
1142*da0073e9SAndroid Build Coastguard Worker        tree = [{i: i for i in range(10)}]
1143*da0073e9SAndroid Build Coastguard Worker        all_zeros = py_pytree.tree_map_with_path(
1144*da0073e9SAndroid Build Coastguard Worker            lambda kp, val: val - kp[1].key + kp[0].idx, tree
1145*da0073e9SAndroid Build Coastguard Worker        )
1146*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker    def test_tree_map_with_path_multiple_trees(self):
1149*da0073e9SAndroid Build Coastguard Worker        @dataclass
1150*da0073e9SAndroid Build Coastguard Worker        class ACustomPytree:
1151*da0073e9SAndroid Build Coastguard Worker            x: Any
1152*da0073e9SAndroid Build Coastguard Worker            y: Any
1153*da0073e9SAndroid Build Coastguard Worker            z: Any
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker        tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
1156*da0073e9SAndroid Build Coastguard Worker        tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2]
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker        py_pytree.register_pytree_node(
1159*da0073e9SAndroid Build Coastguard Worker            ACustomPytree,
1160*da0073e9SAndroid Build Coastguard Worker            flatten_fn=lambda f: ([f.x, f.y], f.z),
1161*da0073e9SAndroid Build Coastguard Worker            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1162*da0073e9SAndroid Build Coastguard Worker            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1163*da0073e9SAndroid Build Coastguard Worker        )
1164*da0073e9SAndroid Build Coastguard Worker        from_two_trees = py_pytree.tree_map_with_path(
1165*da0073e9SAndroid Build Coastguard Worker            lambda kp, a, b: a + b, tree1, tree2
1166*da0073e9SAndroid Build Coastguard Worker        )
1167*da0073e9SAndroid Build Coastguard Worker        from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
1168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(from_two_trees, from_one_tree)
1169*da0073e9SAndroid Build Coastguard Worker
1170*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
1171*da0073e9SAndroid Build Coastguard Worker    def test_tree_flatten_with_path_is_leaf(self):
1172*da0073e9SAndroid Build Coastguard Worker        leaf_dict = {"foo": [(3)]}
1173*da0073e9SAndroid Build Coastguard Worker        pytree = (["hello", [1, 2], leaf_dict],)
1174*da0073e9SAndroid Build Coastguard Worker        key_leaves, spec = py_pytree.tree_flatten_with_path(
1175*da0073e9SAndroid Build Coastguard Worker            pytree, is_leaf=lambda x: isinstance(x, dict)
1176*da0073e9SAndroid Build Coastguard Worker        )
1177*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(key_leaves[-1][1] is leaf_dict)
1178*da0073e9SAndroid Build Coastguard Worker
1179*da0073e9SAndroid Build Coastguard Worker    def test_tree_flatten_with_path_roundtrip(self):
1180*da0073e9SAndroid Build Coastguard Worker        class ANamedTuple(NamedTuple):
1181*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor
1182*da0073e9SAndroid Build Coastguard Worker            y: int
1183*da0073e9SAndroid Build Coastguard Worker            z: str
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        @dataclass
1186*da0073e9SAndroid Build Coastguard Worker        class ACustomPytree:
1187*da0073e9SAndroid Build Coastguard Worker            x: Any
1188*da0073e9SAndroid Build Coastguard Worker            y: Any
1189*da0073e9SAndroid Build Coastguard Worker            z: Any
1190*da0073e9SAndroid Build Coastguard Worker
1191*da0073e9SAndroid Build Coastguard Worker        py_pytree.register_pytree_node(
1192*da0073e9SAndroid Build Coastguard Worker            ACustomPytree,
1193*da0073e9SAndroid Build Coastguard Worker            flatten_fn=lambda f: ([f.x, f.y], f.z),
1194*da0073e9SAndroid Build Coastguard Worker            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1195*da0073e9SAndroid Build Coastguard Worker            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1196*da0073e9SAndroid Build Coastguard Worker        )
1197*da0073e9SAndroid Build Coastguard Worker
1198*da0073e9SAndroid Build Coastguard Worker        SOME_PYTREES = [
1199*da0073e9SAndroid Build Coastguard Worker            (None,),
1200*da0073e9SAndroid Build Coastguard Worker            ["hello", [1, 2], {"foo": [(3)]}],
1201*da0073e9SAndroid Build Coastguard Worker            [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1202*da0073e9SAndroid Build Coastguard Worker            [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1203*da0073e9SAndroid Build Coastguard Worker        ]
1204*da0073e9SAndroid Build Coastguard Worker        for pytree in SOME_PYTREES:
1205*da0073e9SAndroid Build Coastguard Worker            key_leaves, spec = py_pytree.tree_flatten_with_path(pytree)
1206*da0073e9SAndroid Build Coastguard Worker            actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec)
1207*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(actual, pytree)
1208*da0073e9SAndroid Build Coastguard Worker
1209*da0073e9SAndroid Build Coastguard Worker    def test_tree_leaves_with_path(self):
1210*da0073e9SAndroid Build Coastguard Worker        class ANamedTuple(NamedTuple):
1211*da0073e9SAndroid Build Coastguard Worker            x: torch.Tensor
1212*da0073e9SAndroid Build Coastguard Worker            y: int
1213*da0073e9SAndroid Build Coastguard Worker            z: str
1214*da0073e9SAndroid Build Coastguard Worker
1215*da0073e9SAndroid Build Coastguard Worker        @dataclass
1216*da0073e9SAndroid Build Coastguard Worker        class ACustomPytree:
1217*da0073e9SAndroid Build Coastguard Worker            x: Any
1218*da0073e9SAndroid Build Coastguard Worker            y: Any
1219*da0073e9SAndroid Build Coastguard Worker            z: Any
1220*da0073e9SAndroid Build Coastguard Worker
1221*da0073e9SAndroid Build Coastguard Worker        py_pytree.register_pytree_node(
1222*da0073e9SAndroid Build Coastguard Worker            ACustomPytree,
1223*da0073e9SAndroid Build Coastguard Worker            flatten_fn=lambda f: ([f.x, f.y], f.z),
1224*da0073e9SAndroid Build Coastguard Worker            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1225*da0073e9SAndroid Build Coastguard Worker            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1226*da0073e9SAndroid Build Coastguard Worker        )
1227*da0073e9SAndroid Build Coastguard Worker
1228*da0073e9SAndroid Build Coastguard Worker        SOME_PYTREES = [
1229*da0073e9SAndroid Build Coastguard Worker            (None,),
1230*da0073e9SAndroid Build Coastguard Worker            ["hello", [1, 2], {"foo": [(3)]}],
1231*da0073e9SAndroid Build Coastguard Worker            [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1232*da0073e9SAndroid Build Coastguard Worker            [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1233*da0073e9SAndroid Build Coastguard Worker        ]
1234*da0073e9SAndroid Build Coastguard Worker        for pytree in SOME_PYTREES:
1235*da0073e9SAndroid Build Coastguard Worker            flat_out, _ = py_pytree.tree_flatten_with_path(pytree)
1236*da0073e9SAndroid Build Coastguard Worker            leaves_out = py_pytree.tree_leaves_with_path(pytree)
1237*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(flat_out, leaves_out)
1238*da0073e9SAndroid Build Coastguard Worker
1239*da0073e9SAndroid Build Coastguard Worker    def test_key_str(self):
1240*da0073e9SAndroid Build Coastguard Worker        class ANamedTuple(NamedTuple):
1241*da0073e9SAndroid Build Coastguard Worker            x: str
1242*da0073e9SAndroid Build Coastguard Worker            y: int
1243*da0073e9SAndroid Build Coastguard Worker
1244*da0073e9SAndroid Build Coastguard Worker        tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1245*da0073e9SAndroid Build Coastguard Worker        flat, _ = py_pytree.tree_flatten_with_path(tree)
1246*da0073e9SAndroid Build Coastguard Worker        paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat]
1247*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1248*da0073e9SAndroid Build Coastguard Worker            paths,
1249*da0073e9SAndroid Build Coastguard Worker            [
1250*da0073e9SAndroid Build Coastguard Worker                "[0][0]: hello",
1251*da0073e9SAndroid Build Coastguard Worker                "[0][1][0]: 1",
1252*da0073e9SAndroid Build Coastguard Worker                "[0][1][1]: 2",
1253*da0073e9SAndroid Build Coastguard Worker                "[0][2]['foo'][0]: 3",
1254*da0073e9SAndroid Build Coastguard Worker                "[0][2]['bar'][0].x: baz",
1255*da0073e9SAndroid Build Coastguard Worker                "[0][2]['bar'][0].y: 10",
1256*da0073e9SAndroid Build Coastguard Worker            ],
1257*da0073e9SAndroid Build Coastguard Worker        )
1258*da0073e9SAndroid Build Coastguard Worker
1259*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("AssertionError in dynamo")
1260*da0073e9SAndroid Build Coastguard Worker    def test_flatten_flatten_with_key_consistency(self):
1261*da0073e9SAndroid Build Coastguard Worker        """Check that flatten and flatten_with_key produces consistent leaves/context."""
1262*da0073e9SAndroid Build Coastguard Worker        reg = py_pytree.SUPPORTED_NODES
1263*da0073e9SAndroid Build Coastguard Worker
1264*da0073e9SAndroid Build Coastguard Worker        EXAMPLE_TREE = {
1265*da0073e9SAndroid Build Coastguard Worker            list: [1, 2, 3],
1266*da0073e9SAndroid Build Coastguard Worker            tuple: (1, 2, 3),
1267*da0073e9SAndroid Build Coastguard Worker            dict: {"foo": 1, "bar": 2},
1268*da0073e9SAndroid Build Coastguard Worker            namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
1269*da0073e9SAndroid Build Coastguard Worker            OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
1270*da0073e9SAndroid Build Coastguard Worker            defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
1271*da0073e9SAndroid Build Coastguard Worker            deque: deque([1, 2, 3]),
1272*da0073e9SAndroid Build Coastguard Worker            torch.Size: torch.Size([1, 2, 3]),
1273*da0073e9SAndroid Build Coastguard Worker            immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
1274*da0073e9SAndroid Build Coastguard Worker            immutable_list: immutable_list([1, 2, 3]),
1275*da0073e9SAndroid Build Coastguard Worker        }
1276*da0073e9SAndroid Build Coastguard Worker
1277*da0073e9SAndroid Build Coastguard Worker        for typ in reg:
1278*da0073e9SAndroid Build Coastguard Worker            example = EXAMPLE_TREE.get(typ)
1279*da0073e9SAndroid Build Coastguard Worker            if example is None:
1280*da0073e9SAndroid Build Coastguard Worker                continue
1281*da0073e9SAndroid Build Coastguard Worker            flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example)
1282*da0073e9SAndroid Build Coastguard Worker            flat, spec2 = py_pytree.tree_flatten(example)
1283*da0073e9SAndroid Build Coastguard Worker
1284*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(flat, [x[1] for x in flat_with_path])
1285*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(spec1, spec2)
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker    def test_key_access(self):
1288*da0073e9SAndroid Build Coastguard Worker        class ANamedTuple(NamedTuple):
1289*da0073e9SAndroid Build Coastguard Worker            x: str
1290*da0073e9SAndroid Build Coastguard Worker            y: int
1291*da0073e9SAndroid Build Coastguard Worker
1292*da0073e9SAndroid Build Coastguard Worker        tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1293*da0073e9SAndroid Build Coastguard Worker        flat, _ = py_pytree.tree_flatten_with_path(tree)
1294*da0073e9SAndroid Build Coastguard Worker        for kp, val in flat:
1295*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(py_pytree.key_get(tree, kp), val)
1296*da0073e9SAndroid Build Coastguard Worker
1297*da0073e9SAndroid Build Coastguard Worker
class TestCxxPytree(TestCase):
    """Tests for ``cxx_pytree``, the C++ pytree implementation.

    Every test is skipped in fbcode (see ``setUp``).
    """

    def setUp(self):
        # Decide on skipping first so no base-class setup runs for a test that
        # will not execute.
        if IS_FBCODE:
            raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
        # Chain to the base class. Overriding setUp without calling it
        # silently drops whatever per-test initialization TestCase performs.
        super().setUp()

    def test_treespec_equality(self):
        """Two independently constructed leaf specs compare equal."""
        self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        """The spec repr shows leaves as ``*`` and keeps container nesting."""
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)")

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        """Dynamo counterpart of ``test_treespec_repr`` using an inline expectation."""
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        _, spec = cxx_pytree.tree_flatten(pytree)
        self.assertExpectedInline(
            repr(spec),
            "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
        )

    @parametrize(
        "spec",
        [
            cxx_pytree.tree_structure([]),
            cxx_pytree.tree_structure(()),
            cxx_pytree.tree_structure({}),
            cxx_pytree.tree_structure([0]),
            cxx_pytree.tree_structure([0, 1]),
            cxx_pytree.tree_structure((0, 1, 2)),
            cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
            cxx_pytree.tree_structure(
                OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
            ),
            cxx_pytree.tree_structure([(0, 1, [2, 3])]),
            cxx_pytree.tree_structure(
                defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
            ),
        ],
    )
    def test_pytree_serialize(self, spec):
        """A spec survives an unflatten/restructure round trip and a
        treespec_dumps/treespec_loads round trip."""
        # Rebuilding a tree from placeholder leaves must reproduce the same spec.
        self.assertEqual(
            spec,
            cxx_pytree.tree_structure(
                cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec)
            ),
        )

        serialized_spec = cxx_pytree.treespec_dumps(spec)
        self.assertIsInstance(serialized_spec, str)
        self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec))

    def test_pytree_serialize_namedtuple(self):
        """Namedtuple specs, both module-level and function-local, keep their
        fields through a treespec_dumps/treespec_loads round trip."""
        # The namedtuple types are registered through py_pytree with an
        # explicit serialized_type_name before the cxx_pytree round trip.
        py_pytree._register_namedtuple(
            GlobalPoint,
            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
        )
        spec = cxx_pytree.tree_structure(GlobalPoint(0, 1))

        roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

        LocalPoint = namedtuple("LocalPoint", ["x", "y"])
        py_pytree._register_namedtuple(
            LocalPoint,
            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
        )
        spec = cxx_pytree.tree_structure(LocalPoint(0, 1))

        roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec))
        self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

    def test_pytree_custom_type_serialize(self):
        """Custom node types registered with a serialized_type_name round-trip
        through treespec_dumps/treespec_loads, both module-level and local."""
        cxx_pytree.register_pytree_node(
            GlobalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: GlobalDummyType(*xs),
            serialized_type_name="GlobalDummyType",
        )
        spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)

        class LocalDummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        cxx_pytree.register_pytree_node(
            LocalDummyType,
            lambda dummy: ([dummy.x, dummy.y], None),
            lambda xs, _: LocalDummyType(*xs),
            serialized_type_name="LocalDummyType",
        )
        spec = cxx_pytree.tree_structure(LocalDummyType(0, 1))
        serialized_spec = cxx_pytree.treespec_dumps(spec)
        roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec)
        self.assertEqual(roundtrip_spec, spec)
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker
# Materialize the @parametrize-decorated tests of each pytree test class.
for _pytree_test_cls in (TestGenericPytree, TestPythonPytree, TestCxxPytree):
    instantiate_parametrized_tests(_pytree_test_cls)
# Drop the loop variable so test discovery does not find a second module-level
# name bound to one of the TestCase classes.
del _pytree_test_cls


if __name__ == "__main__":
    run_tests()
1410