xref: /aosp_15_r20/external/pytorch/test/test_pytree.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: pytree"]
2
3import collections
4import inspect
5import os
6import re
7import subprocess
8import sys
9import unittest
10from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
11from dataclasses import dataclass
12from typing import Any, NamedTuple
13
14import torch
15import torch.utils._pytree as py_pytree
16from torch.fx.immutable_collections import immutable_dict, immutable_list
17from torch.testing._internal.common_utils import (
18    instantiate_parametrized_tests,
19    IS_FBCODE,
20    parametrize,
21    run_tests,
22    skipIfTorchDynamo,
23    subtest,
24    TEST_WITH_TORCHDYNAMO,
25    TestCase,
26)
27
28
# optree is not yet enabled in fbcode, so in that environment the "cxx" name
# simply aliases the python implementation and the python code is re-tested.
if not IS_FBCODE:
    import torch.utils._cxx_pytree as cxx_pytree
else:
    cxx_pytree = py_pytree

# Module-level namedtuple type (presumably used by tests further down the file).
GlobalPoint = namedtuple("GlobalPoint", ["x", "y"])
36
37
class GlobalDummyType:
    # Minimal module-level container that stores two attributes, `x` and `y`.
    def __init__(self, x, y):
        self.x, self.y = x, y
42
43
44class TestGenericPytree(TestCase):
45    def test_aligned_public_apis(self):
46        public_apis = py_pytree.__all__
47
48        self.assertEqual(public_apis, cxx_pytree.__all__)
49
50        for name in public_apis:
51            cxx_api = getattr(cxx_pytree, name)
52            py_api = getattr(py_pytree, name)
53
54            self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api))
55            self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api))
56            if inspect.isfunction(cxx_api):
57                cxx_signature = inspect.signature(cxx_api)
58                py_signature = inspect.signature(py_api)
59
60                # Check the parameter names are the same.
61                cxx_param_names = list(cxx_signature.parameters)
62                py_param_names = list(py_signature.parameters)
63                self.assertEqual(cxx_param_names, py_param_names)
64
65                # Check the positional parameters are the same.
66                cxx_positional_param_names = [
67                    n
68                    for n, p in cxx_signature.parameters.items()
69                    if (
70                        p.kind
71                        in {
72                            inspect.Parameter.POSITIONAL_ONLY,
73                            inspect.Parameter.POSITIONAL_OR_KEYWORD,
74                        }
75                    )
76                ]
77                py_positional_param_names = [
78                    n
79                    for n, p in py_signature.parameters.items()
80                    if (
81                        p.kind
82                        in {
83                            inspect.Parameter.POSITIONAL_ONLY,
84                            inspect.Parameter.POSITIONAL_OR_KEYWORD,
85                        }
86                    )
87                ]
88                self.assertEqual(cxx_positional_param_names, py_positional_param_names)
89
90                for py_name, py_param in py_signature.parameters.items():
91                    self.assertIn(py_name, cxx_signature.parameters)
92                    cxx_param = cxx_signature.parameters[py_name]
93
94                    # Check parameter kinds and default values are the same.
95                    self.assertEqual(cxx_param.kind, py_param.kind)
96                    self.assertEqual(cxx_param.default, py_param.default)
97
98                    # Check parameter annotations are the same.
99                    if "TreeSpec" in str(cxx_param.annotation):
100                        self.assertIn("TreeSpec", str(py_param.annotation))
101                        self.assertEqual(
102                            re.sub(
103                                r"(?:\b)([\w\.]*)TreeSpec(?:\b)",
104                                "TreeSpec",
105                                str(cxx_param.annotation),
106                            ),
107                            re.sub(
108                                r"(?:\b)([\w\.]*)TreeSpec(?:\b)",
109                                "TreeSpec",
110                                str(py_param.annotation),
111                            ),
112                            msg=(
113                                f"C++ parameter {cxx_param} "
114                                f"does not match Python parameter {py_param} "
115                                f"for API `{name}`"
116                            ),
117                        )
118                    else:
119                        self.assertEqual(
120                            cxx_param.annotation,
121                            py_param.annotation,
122                            msg=(
123                                f"C++ parameter {cxx_param} "
124                                f"does not match Python parameter {py_param} "
125                                f"for API `{name}`"
126                            ),
127                        )
128
129    @parametrize(
130        "pytree_impl",
131        [
132            subtest(py_pytree, name="py"),
133            subtest(cxx_pytree, name="cxx"),
134        ],
135    )
136    def test_register_pytree_node(self, pytree_impl):
137        class MyDict(UserDict):
138            pass
139
140        d = MyDict(a=1, b=2, c=3)
141
142        # Custom types are leaf nodes by default
143        values, spec = pytree_impl.tree_flatten(d)
144        self.assertEqual(values, [d])
145        self.assertIs(values[0], d)
146        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
147        self.assertTrue(spec.is_leaf())
148
149        # Register MyDict as a pytree node
150        pytree_impl.register_pytree_node(
151            MyDict,
152            lambda d: (list(d.values()), list(d.keys())),
153            lambda values, keys: MyDict(zip(keys, values)),
154        )
155
156        values, spec = pytree_impl.tree_flatten(d)
157        self.assertEqual(values, [1, 2, 3])
158        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
159
160        # Do not allow registering the same type twice
161        with self.assertRaisesRegex(ValueError, "already registered"):
162            pytree_impl.register_pytree_node(
163                MyDict,
164                lambda d: (list(d.values()), list(d.keys())),
165                lambda values, keys: MyDict(zip(keys, values)),
166            )
167
168    @parametrize(
169        "pytree_impl",
170        [
171            subtest(py_pytree, name="py"),
172            subtest(cxx_pytree, name="cxx"),
173        ],
174    )
175    def test_flatten_unflatten_leaf(self, pytree_impl):
176        def run_test_with_leaf(leaf):
177            values, treespec = pytree_impl.tree_flatten(leaf)
178            self.assertEqual(values, [leaf])
179            self.assertEqual(treespec, pytree_impl.LeafSpec())
180
181            unflattened = pytree_impl.tree_unflatten(values, treespec)
182            self.assertEqual(unflattened, leaf)
183
184        run_test_with_leaf(1)
185        run_test_with_leaf(1.0)
186        run_test_with_leaf(None)
187        run_test_with_leaf(bool)
188        run_test_with_leaf(torch.randn(3, 3))
189
190    @parametrize(
191        "pytree_impl,gen_expected_fn",
192        [
193            subtest(
194                (
195                    py_pytree,
196                    lambda tup: py_pytree.TreeSpec(
197                        tuple, None, [py_pytree.LeafSpec() for _ in tup]
198                    ),
199                ),
200                name="py",
201            ),
202            subtest(
203                (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
204                name="cxx",
205            ),
206        ],
207    )
208    def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
209        def run_test(tup):
210            expected_spec = gen_expected_fn(tup)
211            values, treespec = pytree_impl.tree_flatten(tup)
212            self.assertIsInstance(values, list)
213            self.assertEqual(values, list(tup))
214            self.assertEqual(treespec, expected_spec)
215
216            unflattened = pytree_impl.tree_unflatten(values, treespec)
217            self.assertEqual(unflattened, tup)
218            self.assertIsInstance(unflattened, tuple)
219
220        run_test(())
221        run_test((1.0,))
222        run_test((1.0, 2))
223        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))
224
225    @parametrize(
226        "pytree_impl,gen_expected_fn",
227        [
228            subtest(
229                (
230                    py_pytree,
231                    lambda lst: py_pytree.TreeSpec(
232                        list, None, [py_pytree.LeafSpec() for _ in lst]
233                    ),
234                ),
235                name="py",
236            ),
237            subtest(
238                (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
239                name="cxx",
240            ),
241        ],
242    )
243    def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
244        def run_test(lst):
245            expected_spec = gen_expected_fn(lst)
246            values, treespec = pytree_impl.tree_flatten(lst)
247            self.assertIsInstance(values, list)
248            self.assertEqual(values, lst)
249            self.assertEqual(treespec, expected_spec)
250
251            unflattened = pytree_impl.tree_unflatten(values, treespec)
252            self.assertEqual(unflattened, lst)
253            self.assertIsInstance(unflattened, list)
254
255        run_test([])
256        run_test([1.0, 2])
257        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])
258
259    @parametrize(
260        "pytree_impl,gen_expected_fn",
261        [
262            subtest(
263                (
264                    py_pytree,
265                    lambda dct: py_pytree.TreeSpec(
266                        dict,
267                        list(dct.keys()),
268                        [py_pytree.LeafSpec() for _ in dct.values()],
269                    ),
270                ),
271                name="py",
272            ),
273            subtest(
274                (
275                    cxx_pytree,
276                    lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
277                ),
278                name="cxx",
279            ),
280        ],
281    )
282    def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
283        def run_test(dct):
284            expected_spec = gen_expected_fn(dct)
285            values, treespec = pytree_impl.tree_flatten(dct)
286            self.assertIsInstance(values, list)
287            self.assertEqual(values, list(dct.values()))
288            self.assertEqual(treespec, expected_spec)
289
290            unflattened = pytree_impl.tree_unflatten(values, treespec)
291            self.assertEqual(unflattened, dct)
292            self.assertIsInstance(unflattened, dict)
293
294        run_test({})
295        run_test({"a": 1})
296        run_test({"abcdefg": torch.randn(2, 3)})
297        run_test({1: torch.randn(2, 3)})
298        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})
299
300    @parametrize(
301        "pytree_impl,gen_expected_fn",
302        [
303            subtest(
304                (
305                    py_pytree,
306                    lambda odict: py_pytree.TreeSpec(
307                        OrderedDict,
308                        list(odict.keys()),
309                        [py_pytree.LeafSpec() for _ in odict.values()],
310                    ),
311                ),
312                name="py",
313            ),
314            subtest(
315                (
316                    cxx_pytree,
317                    lambda odict: cxx_pytree.tree_structure(
318                        OrderedDict.fromkeys(odict, 0)
319                    ),
320                ),
321                name="cxx",
322            ),
323        ],
324    )
325    def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
326        def run_test(odict):
327            expected_spec = gen_expected_fn(odict)
328            values, treespec = pytree_impl.tree_flatten(odict)
329            self.assertIsInstance(values, list)
330            self.assertEqual(values, list(odict.values()))
331            self.assertEqual(treespec, expected_spec)
332
333            unflattened = pytree_impl.tree_unflatten(values, treespec)
334            self.assertEqual(unflattened, odict)
335            self.assertIsInstance(unflattened, OrderedDict)
336
337        od = OrderedDict()
338        run_test(od)
339
340        od["b"] = 1
341        od["a"] = torch.tensor(3.14)
342        run_test(od)
343
344    @parametrize(
345        "pytree_impl,gen_expected_fn",
346        [
347            subtest(
348                (
349                    py_pytree,
350                    lambda ddct: py_pytree.TreeSpec(
351                        defaultdict,
352                        [ddct.default_factory, list(ddct.keys())],
353                        [py_pytree.LeafSpec() for _ in ddct.values()],
354                    ),
355                ),
356                name="py",
357            ),
358            subtest(
359                (
360                    cxx_pytree,
361                    lambda ddct: cxx_pytree.tree_structure(
362                        defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
363                    ),
364                ),
365                name="cxx",
366            ),
367        ],
368    )
369    def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
370        def run_test(ddct):
371            expected_spec = gen_expected_fn(ddct)
372            values, treespec = pytree_impl.tree_flatten(ddct)
373            self.assertIsInstance(values, list)
374            self.assertEqual(values, list(ddct.values()))
375            self.assertEqual(treespec, expected_spec)
376
377            unflattened = pytree_impl.tree_unflatten(values, treespec)
378            self.assertEqual(unflattened, ddct)
379            self.assertEqual(unflattened.default_factory, ddct.default_factory)
380            self.assertIsInstance(unflattened, defaultdict)
381
382        run_test(defaultdict(list, {}))
383        run_test(defaultdict(int, {"a": 1}))
384        run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
385        run_test(defaultdict(int, {1: torch.randn(2, 3)}))
386        run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))
387
388    @parametrize(
389        "pytree_impl,gen_expected_fn",
390        [
391            subtest(
392                (
393                    py_pytree,
394                    lambda deq: py_pytree.TreeSpec(
395                        deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
396                    ),
397                ),
398                name="py",
399            ),
400            subtest(
401                (
402                    cxx_pytree,
403                    lambda deq: cxx_pytree.tree_structure(
404                        deque(deq, maxlen=deq.maxlen)
405                    ),
406                ),
407                name="cxx",
408            ),
409        ],
410    )
411    def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn):
412        def run_test(deq):
413            expected_spec = gen_expected_fn(deq)
414            values, treespec = pytree_impl.tree_flatten(deq)
415            self.assertIsInstance(values, list)
416            self.assertEqual(values, list(deq))
417            self.assertEqual(treespec, expected_spec)
418
419            unflattened = pytree_impl.tree_unflatten(values, treespec)
420            self.assertEqual(unflattened, deq)
421            self.assertEqual(unflattened.maxlen, deq.maxlen)
422            self.assertIsInstance(unflattened, deque)
423
424        run_test(deque([]))
425        run_test(deque([1.0, 2]))
426        run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))
427
428    @parametrize(
429        "pytree_impl",
430        [
431            subtest(py_pytree, name="py"),
432            subtest(cxx_pytree, name="cxx"),
433        ],
434    )
435    def test_flatten_unflatten_namedtuple(self, pytree_impl):
436        Point = namedtuple("Point", ["x", "y"])
437
438        def run_test(tup):
439            if pytree_impl is py_pytree:
440                expected_spec = py_pytree.TreeSpec(
441                    namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
442                )
443            else:
444                expected_spec = cxx_pytree.tree_structure(Point(0, 1))
445            values, treespec = pytree_impl.tree_flatten(tup)
446            self.assertIsInstance(values, list)
447            self.assertEqual(values, list(tup))
448            self.assertEqual(treespec, expected_spec)
449
450            unflattened = pytree_impl.tree_unflatten(values, treespec)
451            self.assertEqual(unflattened, tup)
452            self.assertIsInstance(unflattened, Point)
453
454        run_test(Point(1.0, 2))
455        run_test(Point(torch.tensor(1.0), 2))
456
457    @parametrize(
458        "op",
459        [
460            subtest(torch.max, name="max"),
461            subtest(torch.min, name="min"),
462        ],
463    )
464    @parametrize(
465        "pytree_impl",
466        [
467            subtest(py_pytree, name="py"),
468            subtest(cxx_pytree, name="cxx"),
469        ],
470    )
471    def test_flatten_unflatten_return_types(self, pytree_impl, op):
472        x = torch.randn(3, 3)
473        expected = op(x, dim=0)
474
475        values, spec = pytree_impl.tree_flatten(expected)
476        # Check that values is actually List[Tensor] and not (ReturnType(...),)
477        for value in values:
478            self.assertIsInstance(value, torch.Tensor)
479        result = pytree_impl.tree_unflatten(values, spec)
480
481        self.assertEqual(type(result), type(expected))
482        self.assertEqual(result, expected)
483
484    @parametrize(
485        "pytree_impl",
486        [
487            subtest(py_pytree, name="py"),
488            subtest(cxx_pytree, name="cxx"),
489        ],
490    )
491    def test_flatten_unflatten_nested(self, pytree_impl):
492        def run_test(pytree):
493            values, treespec = pytree_impl.tree_flatten(pytree)
494            self.assertIsInstance(values, list)
495            self.assertEqual(len(values), treespec.num_leaves)
496
497            # NB: python basic data structures (dict list tuple) all have
498            # contents equality defined on them, so the following works for them.
499            unflattened = pytree_impl.tree_unflatten(values, treespec)
500            self.assertEqual(unflattened, pytree)
501
502        cases = [
503            [()],
504            ([],),
505            {"a": ()},
506            {"a": 0, "b": [{"c": 1}]},
507            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
508        ]
509        for case in cases:
510            run_test(case)
511
512    @parametrize(
513        "pytree_impl",
514        [
515            subtest(py_pytree, name="py"),
516            subtest(cxx_pytree, name="cxx"),
517        ],
518    )
519    def test_flatten_with_is_leaf(self, pytree_impl):
520        def run_test(pytree, one_level_leaves):
521            values, treespec = pytree_impl.tree_flatten(
522                pytree, is_leaf=lambda x: x is not pytree
523            )
524            self.assertIsInstance(values, list)
525            self.assertEqual(len(values), treespec.num_nodes - 1)
526            self.assertEqual(len(values), treespec.num_leaves)
527            self.assertEqual(len(values), treespec.num_children)
528            self.assertEqual(values, one_level_leaves)
529
530            self.assertEqual(
531                treespec,
532                pytree_impl.tree_structure(
533                    pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec)
534                ),
535            )
536
537            unflattened = pytree_impl.tree_unflatten(values, treespec)
538            self.assertEqual(unflattened, pytree)
539
540        cases = [
541            ([()], [()]),
542            (([],), [[]]),
543            ({"a": ()}, [()]),
544            ({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]),
545            (
546                {
547                    "a": 0,
548                    "b": [1, {"c": 2}, torch.ones(3)],
549                    "c": (torch.zeros(2, 3), 1),
550                },
551                [0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)],
552            ),
553        ]
554        for case in cases:
555            run_test(*case)
556
557    @parametrize(
558        "pytree_impl",
559        [
560            subtest(py_pytree, name="py"),
561            subtest(cxx_pytree, name="cxx"),
562        ],
563    )
564    def test_tree_map(self, pytree_impl):
565        def run_test(pytree):
566            def f(x):
567                return x * 3
568
569            sm1 = sum(map(f, pytree_impl.tree_leaves(pytree)))
570            sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree)))
571            self.assertEqual(sm1, sm2)
572
573            def invf(x):
574                return x // 3
575
576            self.assertEqual(
577                pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)),
578                pytree,
579            )
580
581        cases = [
582            [()],
583            ([],),
584            {"a": ()},
585            {"a": 1, "b": [{"c": 2}]},
586            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
587        ]
588        for case in cases:
589            run_test(case)
590
591    @parametrize(
592        "pytree_impl",
593        [
594            subtest(py_pytree, name="py"),
595            subtest(cxx_pytree, name="cxx"),
596        ],
597    )
598    def test_tree_map_multi_inputs(self, pytree_impl):
599        def run_test(pytree):
600            def f(x, y, z):
601                return x, [y, (z, 0)]
602
603            pytree_x = pytree
604            pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree)
605            pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree)
606
607            self.assertEqual(
608                pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z),
609                pytree_impl.tree_map(
610                    lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree
611                ),
612            )
613
614        cases = [
615            [()],
616            ([],),
617            {"a": ()},
618            {"a": 1, "b": [{"c": 2}]},
619            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
620        ]
621        for case in cases:
622            run_test(case)
623
624    @parametrize(
625        "pytree_impl",
626        [
627            subtest(py_pytree, name="py"),
628            subtest(cxx_pytree, name="cxx"),
629        ],
630    )
631    def test_tree_map_only(self, pytree_impl):
632        self.assertEqual(
633            pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]
634        )
635
636    @parametrize(
637        "pytree_impl",
638        [
639            subtest(py_pytree, name="py"),
640            subtest(cxx_pytree, name="cxx"),
641        ],
642    )
643    def test_tree_map_only_predicate_fn(self, pytree_impl):
644        self.assertEqual(
645            pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1]
646        )
647
648    @parametrize(
649        "pytree_impl",
650        [
651            subtest(py_pytree, name="py"),
652            subtest(cxx_pytree, name="cxx"),
653        ],
654    )
655    def test_tree_all_any(self, pytree_impl):
656        self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3]))
657        self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1]))
658        self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1]))
659        self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2]))
660        self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"]))
661        self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"]))
662        self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"]))
663        self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"]))
664
665    @parametrize(
666        "pytree_impl",
667        [
668            subtest(py_pytree, name="py"),
669            subtest(cxx_pytree, name="cxx"),
670        ],
671    )
672    def test_broadcast_to_and_flatten(self, pytree_impl):
673        cases = [
674            (1, (), []),
675            # Same (flat) structures
676            ((1,), (0,), [1]),
677            ([1], [0], [1]),
678            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
679            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
680            # Mismatched (flat) structures
681            ([1], (0,), None),
682            ([1], (0,), None),
683            ((1,), [0], None),
684            ((1, 2, 3), (0, 0), None),
685            ({"a": 1, "b": 2}, {"a": 0}, None),
686            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
687            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
688            # Same (nested) structures
689            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
690            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
691            # Mismatched (nested) structures
692            ((1, [2, 3]), (0, (0, 0)), None),
693            ((1, [2, 3]), (0, [0, 0, 0]), None),
694            # Broadcasting single value
695            (1, (0, 0, 0), [1, 1, 1]),
696            (1, [0, 0, 0], [1, 1, 1]),
697            (1, {"a": 0, "b": 0}, [1, 1]),
698            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
699            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
700            # Broadcast multiple things
701            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
702            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
703            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
704        ]
705        for pytree, to_pytree, expected in cases:
706            _, to_spec = pytree_impl.tree_flatten(to_pytree)
707            result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec)
708            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
709
710    @parametrize(
711        "pytree_impl",
712        [
713            subtest(py_pytree, name="py"),
714            subtest(cxx_pytree, name="cxx"),
715        ],
716    )
717    def test_pytree_serialize_bad_input(self, pytree_impl):
718        with self.assertRaises(TypeError):
719            pytree_impl.treespec_dumps("random_blurb")
720
721
722class TestPythonPytree(TestCase):
723    def test_deprecated_register_pytree_node(self):
724        class DummyType:
725            def __init__(self, x, y):
726                self.x = x
727                self.y = y
728
729        with self.assertWarnsRegex(
730            FutureWarning, "torch.utils._pytree._register_pytree_node"
731        ):
732            py_pytree._register_pytree_node(
733                DummyType,
734                lambda dummy: ([dummy.x, dummy.y], None),
735                lambda xs, _: DummyType(*xs),
736            )
737
738        with self.assertWarnsRegex(UserWarning, "already registered"):
739            py_pytree._register_pytree_node(
740                DummyType,
741                lambda dummy: ([dummy.x, dummy.y], None),
742                lambda xs, _: DummyType(*xs),
743            )
744
745    def test_import_pytree_doesnt_import_optree(self):
746        # importing torch.utils._pytree shouldn't import optree.
747        # only importing torch.utils._cxx_pytree should.
748        script = """
749import sys
750import torch
751import torch.utils._pytree
752assert "torch.utils._pytree" in sys.modules
753if "torch.utils._cxx_pytree" in sys.modules:
754    raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree")
755if "optree" in sys.modules:
756    raise RuntimeError("importing torch.utils._pytree should not import optree")
757"""
758        try:
759            subprocess.check_output(
760                [sys.executable, "-c", script],
761                stderr=subprocess.STDOUT,
762                # On Windows, opening the subprocess with the default CWD makes `import torch`
763                # fail, so just set CWD to this script's directory
764                cwd=os.path.dirname(os.path.realpath(__file__)),
765            )
766        except subprocess.CalledProcessError as e:
767            self.fail(
768                msg=(
769                    "Subprocess exception while attempting to run test: "
770                    + e.output.decode("utf-8")
771                )
772            )
773
774    def test_treespec_equality(self):
775        self.assertEqual(
776            py_pytree.LeafSpec(),
777            py_pytree.LeafSpec(),
778        )
779        self.assertEqual(
780            py_pytree.TreeSpec(list, None, []),
781            py_pytree.TreeSpec(list, None, []),
782        )
783        self.assertEqual(
784            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
785            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
786        )
787        self.assertFalse(
788            py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []),
789        )
790        self.assertTrue(
791            py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []),
792        )
793
794    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
795    def test_treespec_repr(self):
796        # Check that it looks sane
797        pytree = (0, [0, 0, [0]])
798        _, spec = py_pytree.tree_flatten(pytree)
799        self.assertEqual(
800            repr(spec),
801            (
802                "TreeSpec(tuple, None, [*,\n"
803                "  TreeSpec(list, None, [*,\n"
804                "    *,\n"
805                "    TreeSpec(list, None, [*])])])"
806            ),
807        )
808
    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        # Dynamo-only counterpart of test_treespec_repr: it runs only when
        # TEST_WITH_TORCHDYNAMO is set and checks the same nested spec repr.
        # Check that it looks sane
        pytree = (0, [0, 0, [0]])
        # Only the spec is needed; the flattened leaves are discarded.
        _, spec = py_pytree.tree_flatten(pytree)
        # The expected text is kept as an inline literal for assertExpectedInline.
        self.assertExpectedInline(
            repr(spec),
            """\
TreeSpec(tuple, None, [*,
  TreeSpec(list, None, [*,
    *,
    TreeSpec(list, None, [*])])])""",
        )
822
823    @parametrize(
824        "spec",
825        [
826            # py_pytree.tree_structure([])
827            py_pytree.TreeSpec(list, None, []),
828            # py_pytree.tree_structure(())
829            py_pytree.TreeSpec(tuple, None, []),
830            # py_pytree.tree_structure({})
831            py_pytree.TreeSpec(dict, [], []),
832            # py_pytree.tree_structure([0])
833            py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]),
834            # py_pytree.tree_structure([0, 1])
835            py_pytree.TreeSpec(
836                list,
837                None,
838                [
839                    py_pytree.LeafSpec(),
840                    py_pytree.LeafSpec(),
841                ],
842            ),
843            # py_pytree.tree_structure((0, 1, 2))
844            py_pytree.TreeSpec(
845                tuple,
846                None,
847                [
848                    py_pytree.LeafSpec(),
849                    py_pytree.LeafSpec(),
850                    py_pytree.LeafSpec(),
851                ],
852            ),
853            # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
854            py_pytree.TreeSpec(
855                dict,
856                ["a", "b", "c"],
857                [
858                    py_pytree.LeafSpec(),
859                    py_pytree.LeafSpec(),
860                    py_pytree.LeafSpec(),
861                ],
862            ),
863            # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
864            py_pytree.TreeSpec(
865                OrderedDict,
866                ["a", "b", "c"],
867                [
868                    py_pytree.TreeSpec(
869                        tuple,
870                        None,
871                        [
872                            py_pytree.LeafSpec(),
873                            py_pytree.LeafSpec(),
874                        ],
875                    ),
876                    py_pytree.LeafSpec(),
877                    py_pytree.TreeSpec(
878                        dict,
879                        ["a", "b", "c"],
880                        [
881                            py_pytree.LeafSpec(),
882                            py_pytree.LeafSpec(),
883                            py_pytree.LeafSpec(),
884                        ],
885                    ),
886                ],
887            ),
888            # py_pytree.tree_structure([(0, 1, [2, 3])])
889            py_pytree.TreeSpec(
890                list,
891                None,
892                [
893                    py_pytree.TreeSpec(
894                        tuple,
895                        None,
896                        [
897                            py_pytree.LeafSpec(),
898                            py_pytree.LeafSpec(),
899                            py_pytree.TreeSpec(
900                                list,
901                                None,
902                                [
903                                    py_pytree.LeafSpec(),
904                                    py_pytree.LeafSpec(),
905                                ],
906                            ),
907                        ],
908                    ),
909                ],
910            ),
911            # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}))
912            py_pytree.TreeSpec(
913                defaultdict,
914                [list, ["a", "b", "c"]],
915                [
916                    py_pytree.TreeSpec(
917                        list,
918                        None,
919                        [
920                            py_pytree.LeafSpec(),
921                            py_pytree.LeafSpec(),
922                        ],
923                    ),
924                    py_pytree.TreeSpec(
925                        list,
926                        None,
927                        [
928                            py_pytree.LeafSpec(),
929                            py_pytree.LeafSpec(),
930                        ],
931                    ),
932                    py_pytree.TreeSpec(dict, [], []),
933                ],
934            ),
935        ],
936    )
937    def test_pytree_serialize(self, spec):
938        # Ensure that the spec is valid
939        self.assertEqual(
940            spec,
941            py_pytree.tree_structure(
942                py_pytree.tree_unflatten([0] * spec.num_leaves, spec)
943            ),
944        )
945
946        serialized_spec = py_pytree.treespec_dumps(spec)
947        self.assertIsInstance(serialized_spec, str)
948        self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec))
949
950    def test_pytree_serialize_namedtuple(self):
951        Point1 = namedtuple("Point1", ["x", "y"])
952        py_pytree._register_namedtuple(
953            Point1,
954            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1",
955        )
956
957        spec = py_pytree.TreeSpec(
958            namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
959        )
960        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
961        self.assertEqual(spec, roundtrip_spec)
962
963        class Point2(NamedTuple):
964            x: int
965            y: int
966
967        py_pytree._register_namedtuple(
968            Point2,
969            serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2",
970        )
971
972        spec = py_pytree.TreeSpec(
973            namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
974        )
975        roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec))
976        self.assertEqual(spec, roundtrip_spec)
977
978    def test_pytree_serialize_namedtuple_bad(self):
979        DummyType = namedtuple("DummyType", ["x", "y"])
980
981        spec = py_pytree.TreeSpec(
982            namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
983        )
984
985        with self.assertRaisesRegex(
986            NotImplementedError, "Please register using `_register_namedtuple`"
987        ):
988            py_pytree.treespec_dumps(spec)
989
990    def test_pytree_custom_type_serialize_bad(self):
991        class DummyType:
992            def __init__(self, x, y):
993                self.x = x
994                self.y = y
995
996        py_pytree.register_pytree_node(
997            DummyType,
998            lambda dummy: ([dummy.x, dummy.y], None),
999            lambda xs, _: DummyType(*xs),
1000        )
1001
1002        spec = py_pytree.TreeSpec(
1003            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1004        )
1005        with self.assertRaisesRegex(
1006            NotImplementedError, "No registered serialization name"
1007        ):
1008            roundtrip_spec = py_pytree.treespec_dumps(spec)
1009
1010    def test_pytree_custom_type_serialize(self):
1011        class DummyType:
1012            def __init__(self, x, y):
1013                self.x = x
1014                self.y = y
1015
1016        py_pytree.register_pytree_node(
1017            DummyType,
1018            lambda dummy: ([dummy.x, dummy.y], None),
1019            lambda xs, _: DummyType(*xs),
1020            serialized_type_name="test_pytree_custom_type_serialize.DummyType",
1021            to_dumpable_context=lambda context: "moo",
1022            from_dumpable_context=lambda dumpable_context: None,
1023        )
1024        spec = py_pytree.TreeSpec(
1025            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1026        )
1027        serialized_spec = py_pytree.treespec_dumps(spec, 1)
1028        self.assertIn("moo", serialized_spec)
1029        roundtrip_spec = py_pytree.treespec_loads(serialized_spec)
1030        self.assertEqual(roundtrip_spec, spec)
1031
1032    def test_pytree_serialize_register_bad(self):
1033        class DummyType:
1034            def __init__(self, x, y):
1035                self.x = x
1036                self.y = y
1037
1038        with self.assertRaisesRegex(
1039            ValueError, "Both to_dumpable_context and from_dumpable_context"
1040        ):
1041            py_pytree.register_pytree_node(
1042                DummyType,
1043                lambda dummy: ([dummy.x, dummy.y], None),
1044                lambda xs, _: DummyType(*xs),
1045                serialized_type_name="test_pytree_serialize_register_bad.DummyType",
1046                to_dumpable_context=lambda context: "moo",
1047            )
1048
1049    def test_pytree_context_serialize_bad(self):
1050        class DummyType:
1051            def __init__(self, x, y):
1052                self.x = x
1053                self.y = y
1054
1055        py_pytree.register_pytree_node(
1056            DummyType,
1057            lambda dummy: ([dummy.x, dummy.y], None),
1058            lambda xs, _: DummyType(*xs),
1059            serialized_type_name="test_pytree_serialize_serialize_bad.DummyType",
1060            to_dumpable_context=lambda context: DummyType,
1061            from_dumpable_context=lambda dumpable_context: None,
1062        )
1063
1064        spec = py_pytree.TreeSpec(
1065            DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1066        )
1067
1068        with self.assertRaisesRegex(
1069            TypeError, "Object of type type is not JSON serializable"
1070        ):
1071            py_pytree.treespec_dumps(spec)
1072
1073    def test_pytree_serialize_bad_protocol(self):
1074        import json
1075
1076        Point = namedtuple("Point", ["x", "y"])
1077        spec = py_pytree.TreeSpec(
1078            namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1079        )
1080        py_pytree._register_namedtuple(
1081            Point,
1082            serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point",
1083        )
1084
1085        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
1086            py_pytree.treespec_dumps(spec, -1)
1087
1088        serialized_spec = py_pytree.treespec_dumps(spec)
1089        protocol, data = json.loads(serialized_spec)
1090        bad_protocol_serialized_spec = json.dumps((-1, data))
1091
1092        with self.assertRaisesRegex(ValueError, "Unknown protocol"):
1093            py_pytree.treespec_loads(bad_protocol_serialized_spec)
1094
1095    def test_saved_serialized(self):
1096        # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})]))
1097        complicated_spec = py_pytree.TreeSpec(
1098            OrderedDict,
1099            [1, 2, 3],
1100            [
1101                py_pytree.TreeSpec(
1102                    tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()]
1103                ),
1104                py_pytree.LeafSpec(),
1105                py_pytree.TreeSpec(
1106                    dict,
1107                    [4, 5, 6],
1108                    [
1109                        py_pytree.LeafSpec(),
1110                        py_pytree.LeafSpec(),
1111                        py_pytree.LeafSpec(),
1112                    ],
1113                ),
1114            ],
1115        )
1116        # Ensure that the spec is valid
1117        self.assertEqual(
1118            complicated_spec,
1119            py_pytree.tree_structure(
1120                py_pytree.tree_unflatten(
1121                    [0] * complicated_spec.num_leaves, complicated_spec
1122                )
1123            ),
1124        )
1125
1126        serialized_spec = py_pytree.treespec_dumps(complicated_spec)
1127        saved_spec = (
1128            '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", '
1129            '"children_spec": [{"type": "builtins.tuple", "context": "null", '
1130            '"children_spec": [{"type": null, "context": null, '
1131            '"children_spec": []}, {"type": null, "context": null, '
1132            '"children_spec": []}]}, {"type": null, "context": null, '
1133            '"children_spec": []}, {"type": "builtins.dict", "context": '
1134            '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, '
1135            '"children_spec": []}, {"type": null, "context": null, "children_spec": '
1136            '[]}, {"type": null, "context": null, "children_spec": []}]}]}]'
1137        )
1138        self.assertEqual(serialized_spec, saved_spec)
1139        self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec))
1140
1141    def test_tree_map_with_path(self):
1142        tree = [{i: i for i in range(10)}]
1143        all_zeros = py_pytree.tree_map_with_path(
1144            lambda kp, val: val - kp[1].key + kp[0].idx, tree
1145        )
1146        self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)])
1147
1148    def test_tree_map_with_path_multiple_trees(self):
1149        @dataclass
1150        class ACustomPytree:
1151            x: Any
1152            y: Any
1153            z: Any
1154
1155        tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5]
1156        tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2]
1157
1158        py_pytree.register_pytree_node(
1159            ACustomPytree,
1160            flatten_fn=lambda f: ([f.x, f.y], f.z),
1161            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1162            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1163        )
1164        from_two_trees = py_pytree.tree_map_with_path(
1165            lambda kp, a, b: a + b, tree1, tree2
1166        )
1167        from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1)
1168        self.assertEqual(from_two_trees, from_one_tree)
1169
1170    @skipIfTorchDynamo("dynamo pytree tracing doesn't work here")
1171    def test_tree_flatten_with_path_is_leaf(self):
1172        leaf_dict = {"foo": [(3)]}
1173        pytree = (["hello", [1, 2], leaf_dict],)
1174        key_leaves, spec = py_pytree.tree_flatten_with_path(
1175            pytree, is_leaf=lambda x: isinstance(x, dict)
1176        )
1177        self.assertTrue(key_leaves[-1][1] is leaf_dict)
1178
1179    def test_tree_flatten_with_path_roundtrip(self):
1180        class ANamedTuple(NamedTuple):
1181            x: torch.Tensor
1182            y: int
1183            z: str
1184
1185        @dataclass
1186        class ACustomPytree:
1187            x: Any
1188            y: Any
1189            z: Any
1190
1191        py_pytree.register_pytree_node(
1192            ACustomPytree,
1193            flatten_fn=lambda f: ([f.x, f.y], f.z),
1194            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1195            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1196        )
1197
1198        SOME_PYTREES = [
1199            (None,),
1200            ["hello", [1, 2], {"foo": [(3)]}],
1201            [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1202            [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1203        ]
1204        for pytree in SOME_PYTREES:
1205            key_leaves, spec = py_pytree.tree_flatten_with_path(pytree)
1206            actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec)
1207            self.assertEqual(actual, pytree)
1208
1209    def test_tree_leaves_with_path(self):
1210        class ANamedTuple(NamedTuple):
1211            x: torch.Tensor
1212            y: int
1213            z: str
1214
1215        @dataclass
1216        class ACustomPytree:
1217            x: Any
1218            y: Any
1219            z: Any
1220
1221        py_pytree.register_pytree_node(
1222            ACustomPytree,
1223            flatten_fn=lambda f: ([f.x, f.y], f.z),
1224            unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z),
1225            flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z),
1226        )
1227
1228        SOME_PYTREES = [
1229            (None,),
1230            ["hello", [1, 2], {"foo": [(3)]}],
1231            [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")],
1232            [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5],
1233        ]
1234        for pytree in SOME_PYTREES:
1235            flat_out, _ = py_pytree.tree_flatten_with_path(pytree)
1236            leaves_out = py_pytree.tree_leaves_with_path(pytree)
1237            self.assertEqual(flat_out, leaves_out)
1238
1239    def test_key_str(self):
1240        class ANamedTuple(NamedTuple):
1241            x: str
1242            y: int
1243
1244        tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1245        flat, _ = py_pytree.tree_flatten_with_path(tree)
1246        paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat]
1247        self.assertEqual(
1248            paths,
1249            [
1250                "[0][0]: hello",
1251                "[0][1][0]: 1",
1252                "[0][1][1]: 2",
1253                "[0][2]['foo'][0]: 3",
1254                "[0][2]['bar'][0].x: baz",
1255                "[0][2]['bar'][0].y: 10",
1256            ],
1257        )
1258
1259    @skipIfTorchDynamo("AssertionError in dynamo")
1260    def test_flatten_flatten_with_key_consistency(self):
1261        """Check that flatten and flatten_with_key produces consistent leaves/context."""
1262        reg = py_pytree.SUPPORTED_NODES
1263
1264        EXAMPLE_TREE = {
1265            list: [1, 2, 3],
1266            tuple: (1, 2, 3),
1267            dict: {"foo": 1, "bar": 2},
1268            namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2),
1269            OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]),
1270            defaultdict: defaultdict(int, {"foo": 1, "bar": 2}),
1271            deque: deque([1, 2, 3]),
1272            torch.Size: torch.Size([1, 2, 3]),
1273            immutable_dict: immutable_dict({"foo": 1, "bar": 2}),
1274            immutable_list: immutable_list([1, 2, 3]),
1275        }
1276
1277        for typ in reg:
1278            example = EXAMPLE_TREE.get(typ)
1279            if example is None:
1280                continue
1281            flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example)
1282            flat, spec2 = py_pytree.tree_flatten(example)
1283
1284            self.assertEqual(flat, [x[1] for x in flat_with_path])
1285            self.assertEqual(spec1, spec2)
1286
1287    def test_key_access(self):
1288        class ANamedTuple(NamedTuple):
1289            x: str
1290            y: int
1291
1292        tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],)
1293        flat, _ = py_pytree.tree_flatten_with_path(tree)
1294        for kp, val in flat:
1295            self.assertEqual(py_pytree.key_get(tree, kp), val)
1296
1297
class TestCxxPytree(TestCase):
    # Tests for the pytree implementation bound to ``cxx_pytree`` in this file.

    def setUp(self):
        # Every test in this class is skipped when IS_FBCODE is set.
        if not IS_FBCODE:
            return
        raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")

    def test_treespec_equality(self):
        # Two independently created leaf specs compare equal.
        first_leaf = cxx_pytree.LeafSpec()
        second_leaf = cxx_pytree.LeafSpec()
        self.assertEqual(first_leaf, second_leaf)

    @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.")
    def test_treespec_repr(self):
        # Eager-mode check of the repr() text of a small nested spec.
        nested = (0, [0, 0, [0]])
        spec = cxx_pytree.tree_flatten(nested)[1]
        self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)")

    @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.")
    def test_treespec_repr_dynamo(self):
        # Same check as test_treespec_repr, run only when
        # TEST_WITH_TORCHDYNAMO is set.
        nested = (0, [0, 0, [0]])
        spec = cxx_pytree.tree_flatten(nested)[1]
        self.assertExpectedInline(
            repr(spec),
            "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)",
        )

    @parametrize(
        "spec",
        [
            # Empty containers.
            cxx_pytree.tree_structure([]),
            cxx_pytree.tree_structure(()),
            cxx_pytree.tree_structure({}),
            # Flat containers.
            cxx_pytree.tree_structure([0]),
            cxx_pytree.tree_structure([0, 1]),
            cxx_pytree.tree_structure((0, 1, 2)),
            cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}),
            # Nested containers.
            cxx_pytree.tree_structure(
                OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
            ),
            cxx_pytree.tree_structure([(0, 1, [2, 3])]),
            cxx_pytree.tree_structure(
                defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})
            ),
        ],
    )
    def test_pytree_serialize(self, spec):
        # The spec must match the structure of a tree rebuilt from it ...
        placeholder_leaves = [0] * spec.num_leaves
        rebuilt_tree = cxx_pytree.tree_unflatten(placeholder_leaves, spec)
        self.assertEqual(spec, cxx_pytree.tree_structure(rebuilt_tree))

        # ... and must survive a dumps/loads round trip through a str.
        dumped = cxx_pytree.treespec_dumps(spec)
        self.assertIsInstance(dumped, str)
        restored = cxx_pytree.treespec_loads(dumped)
        self.assertEqual(spec, restored)

    def test_pytree_serialize_namedtuple(self):
        # Round-trip specs of a module-level and a function-local namedtuple;
        # only the field names of the spec's type are compared.
        LocalPoint = namedtuple("LocalPoint", ["x", "y"])
        cases = (
            (
                GlobalPoint,
                "test_pytree.test_pytree_serialize_namedtuple.GlobalPoint",
            ),
            (
                LocalPoint,
                "test_pytree.test_pytree_serialize_namedtuple.LocalPoint",
            ),
        )
        for point_type, type_name in cases:
            # Registration goes through the Python pytree module.
            py_pytree._register_namedtuple(
                point_type,
                serialized_type_name=type_name,
            )
            spec = cxx_pytree.tree_structure(point_type(0, 1))

            dumped = cxx_pytree.treespec_dumps(spec)
            roundtrip_spec = cxx_pytree.treespec_loads(dumped)
            self.assertEqual(roundtrip_spec.type._fields, spec.type._fields)

    def test_pytree_custom_type_serialize(self):
        # Round-trip specs of a module-level and a function-local custom
        # node type registered with a serialized name.
        class LocalDummyType:
            def __init__(self, x, y):
                self.x = x
                self.y = y

        def register_and_roundtrip(dummy_type, type_name):
            # Each call gets its own closure over dummy_type, so the
            # registered unflatten function keeps building the right class.
            cxx_pytree.register_pytree_node(
                dummy_type,
                lambda dummy: ([dummy.x, dummy.y], None),
                lambda xs, _: dummy_type(*xs),
                serialized_type_name=type_name,
            )
            spec = cxx_pytree.tree_structure(dummy_type(0, 1))
            dumped = cxx_pytree.treespec_dumps(spec)
            self.assertEqual(cxx_pytree.treespec_loads(dumped), spec)

        register_and_roundtrip(GlobalDummyType, "GlobalDummyType")
        register_and_roundtrip(LocalDummyType, "LocalDummyType")
1401
1402
# Hand each test class in this file to instantiate_parametrized_tests.
for _test_class in (TestGenericPytree, TestPythonPytree, TestCxxPytree):
    instantiate_parametrized_tests(_test_class)
# Remove the loop variable so the module namespace is left exactly as before.
del _test_class


if __name__ == "__main__":
    run_tests()
1410