1# Owner(s): ["module: pytree"] 2 3import collections 4import inspect 5import os 6import re 7import subprocess 8import sys 9import unittest 10from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict 11from dataclasses import dataclass 12from typing import Any, NamedTuple 13 14import torch 15import torch.utils._pytree as py_pytree 16from torch.fx.immutable_collections import immutable_dict, immutable_list 17from torch.testing._internal.common_utils import ( 18 instantiate_parametrized_tests, 19 IS_FBCODE, 20 parametrize, 21 run_tests, 22 skipIfTorchDynamo, 23 subtest, 24 TEST_WITH_TORCHDYNAMO, 25 TestCase, 26) 27 28 29if IS_FBCODE: 30 # optree is not yet enabled in fbcode, so just re-test the python implementation 31 cxx_pytree = py_pytree 32else: 33 import torch.utils._cxx_pytree as cxx_pytree 34 35GlobalPoint = namedtuple("GlobalPoint", ["x", "y"]) 36 37 38class GlobalDummyType: 39 def __init__(self, x, y): 40 self.x = x 41 self.y = y 42 43 44class TestGenericPytree(TestCase): 45 def test_aligned_public_apis(self): 46 public_apis = py_pytree.__all__ 47 48 self.assertEqual(public_apis, cxx_pytree.__all__) 49 50 for name in public_apis: 51 cxx_api = getattr(cxx_pytree, name) 52 py_api = getattr(py_pytree, name) 53 54 self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api)) 55 self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api)) 56 if inspect.isfunction(cxx_api): 57 cxx_signature = inspect.signature(cxx_api) 58 py_signature = inspect.signature(py_api) 59 60 # Check the parameter names are the same. 61 cxx_param_names = list(cxx_signature.parameters) 62 py_param_names = list(py_signature.parameters) 63 self.assertEqual(cxx_param_names, py_param_names) 64 65 # Check the positional parameters are the same. 
66 cxx_positional_param_names = [ 67 n 68 for n, p in cxx_signature.parameters.items() 69 if ( 70 p.kind 71 in { 72 inspect.Parameter.POSITIONAL_ONLY, 73 inspect.Parameter.POSITIONAL_OR_KEYWORD, 74 } 75 ) 76 ] 77 py_positional_param_names = [ 78 n 79 for n, p in py_signature.parameters.items() 80 if ( 81 p.kind 82 in { 83 inspect.Parameter.POSITIONAL_ONLY, 84 inspect.Parameter.POSITIONAL_OR_KEYWORD, 85 } 86 ) 87 ] 88 self.assertEqual(cxx_positional_param_names, py_positional_param_names) 89 90 for py_name, py_param in py_signature.parameters.items(): 91 self.assertIn(py_name, cxx_signature.parameters) 92 cxx_param = cxx_signature.parameters[py_name] 93 94 # Check parameter kinds and default values are the same. 95 self.assertEqual(cxx_param.kind, py_param.kind) 96 self.assertEqual(cxx_param.default, py_param.default) 97 98 # Check parameter annotations are the same. 99 if "TreeSpec" in str(cxx_param.annotation): 100 self.assertIn("TreeSpec", str(py_param.annotation)) 101 self.assertEqual( 102 re.sub( 103 r"(?:\b)([\w\.]*)TreeSpec(?:\b)", 104 "TreeSpec", 105 str(cxx_param.annotation), 106 ), 107 re.sub( 108 r"(?:\b)([\w\.]*)TreeSpec(?:\b)", 109 "TreeSpec", 110 str(py_param.annotation), 111 ), 112 msg=( 113 f"C++ parameter {cxx_param} " 114 f"does not match Python parameter {py_param} " 115 f"for API `{name}`" 116 ), 117 ) 118 else: 119 self.assertEqual( 120 cxx_param.annotation, 121 py_param.annotation, 122 msg=( 123 f"C++ parameter {cxx_param} " 124 f"does not match Python parameter {py_param} " 125 f"for API `{name}`" 126 ), 127 ) 128 129 @parametrize( 130 "pytree_impl", 131 [ 132 subtest(py_pytree, name="py"), 133 subtest(cxx_pytree, name="cxx"), 134 ], 135 ) 136 def test_register_pytree_node(self, pytree_impl): 137 class MyDict(UserDict): 138 pass 139 140 d = MyDict(a=1, b=2, c=3) 141 142 # Custom types are leaf nodes by default 143 values, spec = pytree_impl.tree_flatten(d) 144 self.assertEqual(values, [d]) 145 self.assertIs(values[0], d) 146 
self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) 147 self.assertTrue(spec.is_leaf()) 148 149 # Register MyDict as a pytree node 150 pytree_impl.register_pytree_node( 151 MyDict, 152 lambda d: (list(d.values()), list(d.keys())), 153 lambda values, keys: MyDict(zip(keys, values)), 154 ) 155 156 values, spec = pytree_impl.tree_flatten(d) 157 self.assertEqual(values, [1, 2, 3]) 158 self.assertEqual(d, pytree_impl.tree_unflatten(values, spec)) 159 160 # Do not allow registering the same type twice 161 with self.assertRaisesRegex(ValueError, "already registered"): 162 pytree_impl.register_pytree_node( 163 MyDict, 164 lambda d: (list(d.values()), list(d.keys())), 165 lambda values, keys: MyDict(zip(keys, values)), 166 ) 167 168 @parametrize( 169 "pytree_impl", 170 [ 171 subtest(py_pytree, name="py"), 172 subtest(cxx_pytree, name="cxx"), 173 ], 174 ) 175 def test_flatten_unflatten_leaf(self, pytree_impl): 176 def run_test_with_leaf(leaf): 177 values, treespec = pytree_impl.tree_flatten(leaf) 178 self.assertEqual(values, [leaf]) 179 self.assertEqual(treespec, pytree_impl.LeafSpec()) 180 181 unflattened = pytree_impl.tree_unflatten(values, treespec) 182 self.assertEqual(unflattened, leaf) 183 184 run_test_with_leaf(1) 185 run_test_with_leaf(1.0) 186 run_test_with_leaf(None) 187 run_test_with_leaf(bool) 188 run_test_with_leaf(torch.randn(3, 3)) 189 190 @parametrize( 191 "pytree_impl,gen_expected_fn", 192 [ 193 subtest( 194 ( 195 py_pytree, 196 lambda tup: py_pytree.TreeSpec( 197 tuple, None, [py_pytree.LeafSpec() for _ in tup] 198 ), 199 ), 200 name="py", 201 ), 202 subtest( 203 (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))), 204 name="cxx", 205 ), 206 ], 207 ) 208 def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn): 209 def run_test(tup): 210 expected_spec = gen_expected_fn(tup) 211 values, treespec = pytree_impl.tree_flatten(tup) 212 self.assertIsInstance(values, list) 213 self.assertEqual(values, list(tup)) 214 
self.assertEqual(treespec, expected_spec) 215 216 unflattened = pytree_impl.tree_unflatten(values, treespec) 217 self.assertEqual(unflattened, tup) 218 self.assertIsInstance(unflattened, tuple) 219 220 run_test(()) 221 run_test((1.0,)) 222 run_test((1.0, 2)) 223 run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11)) 224 225 @parametrize( 226 "pytree_impl,gen_expected_fn", 227 [ 228 subtest( 229 ( 230 py_pytree, 231 lambda lst: py_pytree.TreeSpec( 232 list, None, [py_pytree.LeafSpec() for _ in lst] 233 ), 234 ), 235 name="py", 236 ), 237 subtest( 238 (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))), 239 name="cxx", 240 ), 241 ], 242 ) 243 def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn): 244 def run_test(lst): 245 expected_spec = gen_expected_fn(lst) 246 values, treespec = pytree_impl.tree_flatten(lst) 247 self.assertIsInstance(values, list) 248 self.assertEqual(values, lst) 249 self.assertEqual(treespec, expected_spec) 250 251 unflattened = pytree_impl.tree_unflatten(values, treespec) 252 self.assertEqual(unflattened, lst) 253 self.assertIsInstance(unflattened, list) 254 255 run_test([]) 256 run_test([1.0, 2]) 257 run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11]) 258 259 @parametrize( 260 "pytree_impl,gen_expected_fn", 261 [ 262 subtest( 263 ( 264 py_pytree, 265 lambda dct: py_pytree.TreeSpec( 266 dict, 267 list(dct.keys()), 268 [py_pytree.LeafSpec() for _ in dct.values()], 269 ), 270 ), 271 name="py", 272 ), 273 subtest( 274 ( 275 cxx_pytree, 276 lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)), 277 ), 278 name="cxx", 279 ), 280 ], 281 ) 282 def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn): 283 def run_test(dct): 284 expected_spec = gen_expected_fn(dct) 285 values, treespec = pytree_impl.tree_flatten(dct) 286 self.assertIsInstance(values, list) 287 self.assertEqual(values, list(dct.values())) 288 self.assertEqual(treespec, expected_spec) 289 290 unflattened = pytree_impl.tree_unflatten(values, treespec) 
291 self.assertEqual(unflattened, dct) 292 self.assertIsInstance(unflattened, dict) 293 294 run_test({}) 295 run_test({"a": 1}) 296 run_test({"abcdefg": torch.randn(2, 3)}) 297 run_test({1: torch.randn(2, 3)}) 298 run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)}) 299 300 @parametrize( 301 "pytree_impl,gen_expected_fn", 302 [ 303 subtest( 304 ( 305 py_pytree, 306 lambda odict: py_pytree.TreeSpec( 307 OrderedDict, 308 list(odict.keys()), 309 [py_pytree.LeafSpec() for _ in odict.values()], 310 ), 311 ), 312 name="py", 313 ), 314 subtest( 315 ( 316 cxx_pytree, 317 lambda odict: cxx_pytree.tree_structure( 318 OrderedDict.fromkeys(odict, 0) 319 ), 320 ), 321 name="cxx", 322 ), 323 ], 324 ) 325 def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn): 326 def run_test(odict): 327 expected_spec = gen_expected_fn(odict) 328 values, treespec = pytree_impl.tree_flatten(odict) 329 self.assertIsInstance(values, list) 330 self.assertEqual(values, list(odict.values())) 331 self.assertEqual(treespec, expected_spec) 332 333 unflattened = pytree_impl.tree_unflatten(values, treespec) 334 self.assertEqual(unflattened, odict) 335 self.assertIsInstance(unflattened, OrderedDict) 336 337 od = OrderedDict() 338 run_test(od) 339 340 od["b"] = 1 341 od["a"] = torch.tensor(3.14) 342 run_test(od) 343 344 @parametrize( 345 "pytree_impl,gen_expected_fn", 346 [ 347 subtest( 348 ( 349 py_pytree, 350 lambda ddct: py_pytree.TreeSpec( 351 defaultdict, 352 [ddct.default_factory, list(ddct.keys())], 353 [py_pytree.LeafSpec() for _ in ddct.values()], 354 ), 355 ), 356 name="py", 357 ), 358 subtest( 359 ( 360 cxx_pytree, 361 lambda ddct: cxx_pytree.tree_structure( 362 defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0)) 363 ), 364 ), 365 name="cxx", 366 ), 367 ], 368 ) 369 def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn): 370 def run_test(ddct): 371 expected_spec = gen_expected_fn(ddct) 372 values, treespec = pytree_impl.tree_flatten(ddct) 373 
self.assertIsInstance(values, list) 374 self.assertEqual(values, list(ddct.values())) 375 self.assertEqual(treespec, expected_spec) 376 377 unflattened = pytree_impl.tree_unflatten(values, treespec) 378 self.assertEqual(unflattened, ddct) 379 self.assertEqual(unflattened.default_factory, ddct.default_factory) 380 self.assertIsInstance(unflattened, defaultdict) 381 382 run_test(defaultdict(list, {})) 383 run_test(defaultdict(int, {"a": 1})) 384 run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)})) 385 run_test(defaultdict(int, {1: torch.randn(2, 3)})) 386 run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)})) 387 388 @parametrize( 389 "pytree_impl,gen_expected_fn", 390 [ 391 subtest( 392 ( 393 py_pytree, 394 lambda deq: py_pytree.TreeSpec( 395 deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq] 396 ), 397 ), 398 name="py", 399 ), 400 subtest( 401 ( 402 cxx_pytree, 403 lambda deq: cxx_pytree.tree_structure( 404 deque(deq, maxlen=deq.maxlen) 405 ), 406 ), 407 name="cxx", 408 ), 409 ], 410 ) 411 def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn): 412 def run_test(deq): 413 expected_spec = gen_expected_fn(deq) 414 values, treespec = pytree_impl.tree_flatten(deq) 415 self.assertIsInstance(values, list) 416 self.assertEqual(values, list(deq)) 417 self.assertEqual(treespec, expected_spec) 418 419 unflattened = pytree_impl.tree_unflatten(values, treespec) 420 self.assertEqual(unflattened, deq) 421 self.assertEqual(unflattened.maxlen, deq.maxlen) 422 self.assertIsInstance(unflattened, deque) 423 424 run_test(deque([])) 425 run_test(deque([1.0, 2])) 426 run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8)) 427 428 @parametrize( 429 "pytree_impl", 430 [ 431 subtest(py_pytree, name="py"), 432 subtest(cxx_pytree, name="cxx"), 433 ], 434 ) 435 def test_flatten_unflatten_namedtuple(self, pytree_impl): 436 Point = namedtuple("Point", ["x", "y"]) 437 438 def run_test(tup): 439 if pytree_impl is py_pytree: 440 expected_spec = 
py_pytree.TreeSpec( 441 namedtuple, Point, [py_pytree.LeafSpec() for _ in tup] 442 ) 443 else: 444 expected_spec = cxx_pytree.tree_structure(Point(0, 1)) 445 values, treespec = pytree_impl.tree_flatten(tup) 446 self.assertIsInstance(values, list) 447 self.assertEqual(values, list(tup)) 448 self.assertEqual(treespec, expected_spec) 449 450 unflattened = pytree_impl.tree_unflatten(values, treespec) 451 self.assertEqual(unflattened, tup) 452 self.assertIsInstance(unflattened, Point) 453 454 run_test(Point(1.0, 2)) 455 run_test(Point(torch.tensor(1.0), 2)) 456 457 @parametrize( 458 "op", 459 [ 460 subtest(torch.max, name="max"), 461 subtest(torch.min, name="min"), 462 ], 463 ) 464 @parametrize( 465 "pytree_impl", 466 [ 467 subtest(py_pytree, name="py"), 468 subtest(cxx_pytree, name="cxx"), 469 ], 470 ) 471 def test_flatten_unflatten_return_types(self, pytree_impl, op): 472 x = torch.randn(3, 3) 473 expected = op(x, dim=0) 474 475 values, spec = pytree_impl.tree_flatten(expected) 476 # Check that values is actually List[Tensor] and not (ReturnType(...),) 477 for value in values: 478 self.assertIsInstance(value, torch.Tensor) 479 result = pytree_impl.tree_unflatten(values, spec) 480 481 self.assertEqual(type(result), type(expected)) 482 self.assertEqual(result, expected) 483 484 @parametrize( 485 "pytree_impl", 486 [ 487 subtest(py_pytree, name="py"), 488 subtest(cxx_pytree, name="cxx"), 489 ], 490 ) 491 def test_flatten_unflatten_nested(self, pytree_impl): 492 def run_test(pytree): 493 values, treespec = pytree_impl.tree_flatten(pytree) 494 self.assertIsInstance(values, list) 495 self.assertEqual(len(values), treespec.num_leaves) 496 497 # NB: python basic data structures (dict list tuple) all have 498 # contents equality defined on them, so the following works for them. 
499 unflattened = pytree_impl.tree_unflatten(values, treespec) 500 self.assertEqual(unflattened, pytree) 501 502 cases = [ 503 [()], 504 ([],), 505 {"a": ()}, 506 {"a": 0, "b": [{"c": 1}]}, 507 {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)}, 508 ] 509 for case in cases: 510 run_test(case) 511 512 @parametrize( 513 "pytree_impl", 514 [ 515 subtest(py_pytree, name="py"), 516 subtest(cxx_pytree, name="cxx"), 517 ], 518 ) 519 def test_flatten_with_is_leaf(self, pytree_impl): 520 def run_test(pytree, one_level_leaves): 521 values, treespec = pytree_impl.tree_flatten( 522 pytree, is_leaf=lambda x: x is not pytree 523 ) 524 self.assertIsInstance(values, list) 525 self.assertEqual(len(values), treespec.num_nodes - 1) 526 self.assertEqual(len(values), treespec.num_leaves) 527 self.assertEqual(len(values), treespec.num_children) 528 self.assertEqual(values, one_level_leaves) 529 530 self.assertEqual( 531 treespec, 532 pytree_impl.tree_structure( 533 pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec) 534 ), 535 ) 536 537 unflattened = pytree_impl.tree_unflatten(values, treespec) 538 self.assertEqual(unflattened, pytree) 539 540 cases = [ 541 ([()], [()]), 542 (([],), [[]]), 543 ({"a": ()}, [()]), 544 ({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]), 545 ( 546 { 547 "a": 0, 548 "b": [1, {"c": 2}, torch.ones(3)], 549 "c": (torch.zeros(2, 3), 1), 550 }, 551 [0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)], 552 ), 553 ] 554 for case in cases: 555 run_test(*case) 556 557 @parametrize( 558 "pytree_impl", 559 [ 560 subtest(py_pytree, name="py"), 561 subtest(cxx_pytree, name="cxx"), 562 ], 563 ) 564 def test_tree_map(self, pytree_impl): 565 def run_test(pytree): 566 def f(x): 567 return x * 3 568 569 sm1 = sum(map(f, pytree_impl.tree_leaves(pytree))) 570 sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree))) 571 self.assertEqual(sm1, sm2) 572 573 def invf(x): 574 return x // 3 575 576 self.assertEqual( 577 
pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)), 578 pytree, 579 ) 580 581 cases = [ 582 [()], 583 ([],), 584 {"a": ()}, 585 {"a": 1, "b": [{"c": 2}]}, 586 {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)}, 587 ] 588 for case in cases: 589 run_test(case) 590 591 @parametrize( 592 "pytree_impl", 593 [ 594 subtest(py_pytree, name="py"), 595 subtest(cxx_pytree, name="cxx"), 596 ], 597 ) 598 def test_tree_map_multi_inputs(self, pytree_impl): 599 def run_test(pytree): 600 def f(x, y, z): 601 return x, [y, (z, 0)] 602 603 pytree_x = pytree 604 pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree) 605 pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree) 606 607 self.assertEqual( 608 pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z), 609 pytree_impl.tree_map( 610 lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree 611 ), 612 ) 613 614 cases = [ 615 [()], 616 ([],), 617 {"a": ()}, 618 {"a": 1, "b": [{"c": 2}]}, 619 {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)}, 620 ] 621 for case in cases: 622 run_test(case) 623 624 @parametrize( 625 "pytree_impl", 626 [ 627 subtest(py_pytree, name="py"), 628 subtest(cxx_pytree, name="cxx"), 629 ], 630 ) 631 def test_tree_map_only(self, pytree_impl): 632 self.assertEqual( 633 pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"] 634 ) 635 636 @parametrize( 637 "pytree_impl", 638 [ 639 subtest(py_pytree, name="py"), 640 subtest(cxx_pytree, name="cxx"), 641 ], 642 ) 643 def test_tree_map_only_predicate_fn(self, pytree_impl): 644 self.assertEqual( 645 pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1] 646 ) 647 648 @parametrize( 649 "pytree_impl", 650 [ 651 subtest(py_pytree, name="py"), 652 subtest(cxx_pytree, name="cxx"), 653 ], 654 ) 655 def test_tree_all_any(self, pytree_impl): 656 self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3])) 657 self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1])) 658 self.assertTrue(pytree_impl.tree_any(lambda x: x 
% 2, [0, 1])) 659 self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2])) 660 self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"])) 661 self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"])) 662 self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"])) 663 self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"])) 664 665 @parametrize( 666 "pytree_impl", 667 [ 668 subtest(py_pytree, name="py"), 669 subtest(cxx_pytree, name="cxx"), 670 ], 671 ) 672 def test_broadcast_to_and_flatten(self, pytree_impl): 673 cases = [ 674 (1, (), []), 675 # Same (flat) structures 676 ((1,), (0,), [1]), 677 ([1], [0], [1]), 678 ((1, 2, 3), (0, 0, 0), [1, 2, 3]), 679 ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]), 680 # Mismatched (flat) structures 681 ([1], (0,), None), 682 ([1], (0,), None), 683 ((1,), [0], None), 684 ((1, 2, 3), (0, 0), None), 685 ({"a": 1, "b": 2}, {"a": 0}, None), 686 ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None), 687 ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None), 688 # Same (nested) structures 689 ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]), 690 ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]), 691 # Mismatched (nested) structures 692 ((1, [2, 3]), (0, (0, 0)), None), 693 ((1, [2, 3]), (0, [0, 0, 0]), None), 694 # Broadcasting single value 695 (1, (0, 0, 0), [1, 1, 1]), 696 (1, [0, 0, 0], [1, 1, 1]), 697 (1, {"a": 0, "b": 0}, [1, 1]), 698 (1, (0, [0, [0]], 0), [1, 1, 1, 1]), 699 (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]), 700 # Broadcast multiple things 701 ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]), 702 ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]), 703 (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]), 704 ] 705 for pytree, to_pytree, expected in cases: 706 _, to_spec = pytree_impl.tree_flatten(to_pytree) 707 result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec) 708 self.assertEqual(result, expected, msg=str([pytree, to_spec, 
expected])) 709 710 @parametrize( 711 "pytree_impl", 712 [ 713 subtest(py_pytree, name="py"), 714 subtest(cxx_pytree, name="cxx"), 715 ], 716 ) 717 def test_pytree_serialize_bad_input(self, pytree_impl): 718 with self.assertRaises(TypeError): 719 pytree_impl.treespec_dumps("random_blurb") 720 721 722class TestPythonPytree(TestCase): 723 def test_deprecated_register_pytree_node(self): 724 class DummyType: 725 def __init__(self, x, y): 726 self.x = x 727 self.y = y 728 729 with self.assertWarnsRegex( 730 FutureWarning, "torch.utils._pytree._register_pytree_node" 731 ): 732 py_pytree._register_pytree_node( 733 DummyType, 734 lambda dummy: ([dummy.x, dummy.y], None), 735 lambda xs, _: DummyType(*xs), 736 ) 737 738 with self.assertWarnsRegex(UserWarning, "already registered"): 739 py_pytree._register_pytree_node( 740 DummyType, 741 lambda dummy: ([dummy.x, dummy.y], None), 742 lambda xs, _: DummyType(*xs), 743 ) 744 745 def test_import_pytree_doesnt_import_optree(self): 746 # importing torch.utils._pytree shouldn't import optree. 747 # only importing torch.utils._cxx_pytree should. 
748 script = """ 749import sys 750import torch 751import torch.utils._pytree 752assert "torch.utils._pytree" in sys.modules 753if "torch.utils._cxx_pytree" in sys.modules: 754 raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree") 755if "optree" in sys.modules: 756 raise RuntimeError("importing torch.utils._pytree should not import optree") 757""" 758 try: 759 subprocess.check_output( 760 [sys.executable, "-c", script], 761 stderr=subprocess.STDOUT, 762 # On Windows, opening the subprocess with the default CWD makes `import torch` 763 # fail, so just set CWD to this script's directory 764 cwd=os.path.dirname(os.path.realpath(__file__)), 765 ) 766 except subprocess.CalledProcessError as e: 767 self.fail( 768 msg=( 769 "Subprocess exception while attempting to run test: " 770 + e.output.decode("utf-8") 771 ) 772 ) 773 774 def test_treespec_equality(self): 775 self.assertEqual( 776 py_pytree.LeafSpec(), 777 py_pytree.LeafSpec(), 778 ) 779 self.assertEqual( 780 py_pytree.TreeSpec(list, None, []), 781 py_pytree.TreeSpec(list, None, []), 782 ) 783 self.assertEqual( 784 py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), 785 py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), 786 ) 787 self.assertFalse( 788 py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []), 789 ) 790 self.assertTrue( 791 py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []), 792 ) 793 794 @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.") 795 def test_treespec_repr(self): 796 # Check that it looks sane 797 pytree = (0, [0, 0, [0]]) 798 _, spec = py_pytree.tree_flatten(pytree) 799 self.assertEqual( 800 repr(spec), 801 ( 802 "TreeSpec(tuple, None, [*,\n" 803 " TreeSpec(list, None, [*,\n" 804 " *,\n" 805 " TreeSpec(list, None, [*])])])" 806 ), 807 ) 808 809 @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") 810 def test_treespec_repr_dynamo(self): 811 # 
Check that it looks sane 812 pytree = (0, [0, 0, [0]]) 813 _, spec = py_pytree.tree_flatten(pytree) 814 self.assertExpectedInline( 815 repr(spec), 816 """\ 817TreeSpec(tuple, None, [*, 818 TreeSpec(list, None, [*, 819 *, 820 TreeSpec(list, None, [*])])])""", 821 ) 822 823 @parametrize( 824 "spec", 825 [ 826 # py_pytree.tree_structure([]) 827 py_pytree.TreeSpec(list, None, []), 828 # py_pytree.tree_structure(()) 829 py_pytree.TreeSpec(tuple, None, []), 830 # py_pytree.tree_structure({}) 831 py_pytree.TreeSpec(dict, [], []), 832 # py_pytree.tree_structure([0]) 833 py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), 834 # py_pytree.tree_structure([0, 1]) 835 py_pytree.TreeSpec( 836 list, 837 None, 838 [ 839 py_pytree.LeafSpec(), 840 py_pytree.LeafSpec(), 841 ], 842 ), 843 # py_pytree.tree_structure((0, 1, 2)) 844 py_pytree.TreeSpec( 845 tuple, 846 None, 847 [ 848 py_pytree.LeafSpec(), 849 py_pytree.LeafSpec(), 850 py_pytree.LeafSpec(), 851 ], 852 ), 853 # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) 854 py_pytree.TreeSpec( 855 dict, 856 ["a", "b", "c"], 857 [ 858 py_pytree.LeafSpec(), 859 py_pytree.LeafSpec(), 860 py_pytree.LeafSpec(), 861 ], 862 ), 863 # py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) 864 py_pytree.TreeSpec( 865 OrderedDict, 866 ["a", "b", "c"], 867 [ 868 py_pytree.TreeSpec( 869 tuple, 870 None, 871 [ 872 py_pytree.LeafSpec(), 873 py_pytree.LeafSpec(), 874 ], 875 ), 876 py_pytree.LeafSpec(), 877 py_pytree.TreeSpec( 878 dict, 879 ["a", "b", "c"], 880 [ 881 py_pytree.LeafSpec(), 882 py_pytree.LeafSpec(), 883 py_pytree.LeafSpec(), 884 ], 885 ), 886 ], 887 ), 888 # py_pytree.tree_structure([(0, 1, [2, 3])]) 889 py_pytree.TreeSpec( 890 list, 891 None, 892 [ 893 py_pytree.TreeSpec( 894 tuple, 895 None, 896 [ 897 py_pytree.LeafSpec(), 898 py_pytree.LeafSpec(), 899 py_pytree.TreeSpec( 900 list, 901 None, 902 [ 903 py_pytree.LeafSpec(), 904 py_pytree.LeafSpec(), 905 ], 906 ), 907 ], 908 ), 909 ], 910 
), 911 # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) 912 py_pytree.TreeSpec( 913 defaultdict, 914 [list, ["a", "b", "c"]], 915 [ 916 py_pytree.TreeSpec( 917 list, 918 None, 919 [ 920 py_pytree.LeafSpec(), 921 py_pytree.LeafSpec(), 922 ], 923 ), 924 py_pytree.TreeSpec( 925 list, 926 None, 927 [ 928 py_pytree.LeafSpec(), 929 py_pytree.LeafSpec(), 930 ], 931 ), 932 py_pytree.TreeSpec(dict, [], []), 933 ], 934 ), 935 ], 936 ) 937 def test_pytree_serialize(self, spec): 938 # Ensure that the spec is valid 939 self.assertEqual( 940 spec, 941 py_pytree.tree_structure( 942 py_pytree.tree_unflatten([0] * spec.num_leaves, spec) 943 ), 944 ) 945 946 serialized_spec = py_pytree.treespec_dumps(spec) 947 self.assertIsInstance(serialized_spec, str) 948 self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec)) 949 950 def test_pytree_serialize_namedtuple(self): 951 Point1 = namedtuple("Point1", ["x", "y"]) 952 py_pytree._register_namedtuple( 953 Point1, 954 serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1", 955 ) 956 957 spec = py_pytree.TreeSpec( 958 namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 959 ) 960 roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) 961 self.assertEqual(spec, roundtrip_spec) 962 963 class Point2(NamedTuple): 964 x: int 965 y: int 966 967 py_pytree._register_namedtuple( 968 Point2, 969 serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2", 970 ) 971 972 spec = py_pytree.TreeSpec( 973 namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 974 ) 975 roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) 976 self.assertEqual(spec, roundtrip_spec) 977 978 def test_pytree_serialize_namedtuple_bad(self): 979 DummyType = namedtuple("DummyType", ["x", "y"]) 980 981 spec = py_pytree.TreeSpec( 982 namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 983 ) 984 985 with self.assertRaisesRegex( 986 
NotImplementedError, "Please register using `_register_namedtuple`" 987 ): 988 py_pytree.treespec_dumps(spec) 989 990 def test_pytree_custom_type_serialize_bad(self): 991 class DummyType: 992 def __init__(self, x, y): 993 self.x = x 994 self.y = y 995 996 py_pytree.register_pytree_node( 997 DummyType, 998 lambda dummy: ([dummy.x, dummy.y], None), 999 lambda xs, _: DummyType(*xs), 1000 ) 1001 1002 spec = py_pytree.TreeSpec( 1003 DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1004 ) 1005 with self.assertRaisesRegex( 1006 NotImplementedError, "No registered serialization name" 1007 ): 1008 roundtrip_spec = py_pytree.treespec_dumps(spec) 1009 1010 def test_pytree_custom_type_serialize(self): 1011 class DummyType: 1012 def __init__(self, x, y): 1013 self.x = x 1014 self.y = y 1015 1016 py_pytree.register_pytree_node( 1017 DummyType, 1018 lambda dummy: ([dummy.x, dummy.y], None), 1019 lambda xs, _: DummyType(*xs), 1020 serialized_type_name="test_pytree_custom_type_serialize.DummyType", 1021 to_dumpable_context=lambda context: "moo", 1022 from_dumpable_context=lambda dumpable_context: None, 1023 ) 1024 spec = py_pytree.TreeSpec( 1025 DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1026 ) 1027 serialized_spec = py_pytree.treespec_dumps(spec, 1) 1028 self.assertIn("moo", serialized_spec) 1029 roundtrip_spec = py_pytree.treespec_loads(serialized_spec) 1030 self.assertEqual(roundtrip_spec, spec) 1031 1032 def test_pytree_serialize_register_bad(self): 1033 class DummyType: 1034 def __init__(self, x, y): 1035 self.x = x 1036 self.y = y 1037 1038 with self.assertRaisesRegex( 1039 ValueError, "Both to_dumpable_context and from_dumpable_context" 1040 ): 1041 py_pytree.register_pytree_node( 1042 DummyType, 1043 lambda dummy: ([dummy.x, dummy.y], None), 1044 lambda xs, _: DummyType(*xs), 1045 serialized_type_name="test_pytree_serialize_register_bad.DummyType", 1046 to_dumpable_context=lambda context: "moo", 1047 ) 1048 1049 def 
test_pytree_context_serialize_bad(self): 1050 class DummyType: 1051 def __init__(self, x, y): 1052 self.x = x 1053 self.y = y 1054 1055 py_pytree.register_pytree_node( 1056 DummyType, 1057 lambda dummy: ([dummy.x, dummy.y], None), 1058 lambda xs, _: DummyType(*xs), 1059 serialized_type_name="test_pytree_serialize_serialize_bad.DummyType", 1060 to_dumpable_context=lambda context: DummyType, 1061 from_dumpable_context=lambda dumpable_context: None, 1062 ) 1063 1064 spec = py_pytree.TreeSpec( 1065 DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1066 ) 1067 1068 with self.assertRaisesRegex( 1069 TypeError, "Object of type type is not JSON serializable" 1070 ): 1071 py_pytree.treespec_dumps(spec) 1072 1073 def test_pytree_serialize_bad_protocol(self): 1074 import json 1075 1076 Point = namedtuple("Point", ["x", "y"]) 1077 spec = py_pytree.TreeSpec( 1078 namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1079 ) 1080 py_pytree._register_namedtuple( 1081 Point, 1082 serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point", 1083 ) 1084 1085 with self.assertRaisesRegex(ValueError, "Unknown protocol"): 1086 py_pytree.treespec_dumps(spec, -1) 1087 1088 serialized_spec = py_pytree.treespec_dumps(spec) 1089 protocol, data = json.loads(serialized_spec) 1090 bad_protocol_serialized_spec = json.dumps((-1, data)) 1091 1092 with self.assertRaisesRegex(ValueError, "Unknown protocol"): 1093 py_pytree.treespec_loads(bad_protocol_serialized_spec) 1094 1095 def test_saved_serialized(self): 1096 # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 3, 5: 4, 6: 5})])) 1097 complicated_spec = py_pytree.TreeSpec( 1098 OrderedDict, 1099 [1, 2, 3], 1100 [ 1101 py_pytree.TreeSpec( 1102 tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1103 ), 1104 py_pytree.LeafSpec(), 1105 py_pytree.TreeSpec( 1106 dict, 1107 [4, 5, 6], 1108 [ 1109 py_pytree.LeafSpec(), 1110 py_pytree.LeafSpec(), 1111 py_pytree.LeafSpec(), 1112 ], 1113 ), 1114 
], 1115 ) 1116 # Ensure that the spec is valid 1117 self.assertEqual( 1118 complicated_spec, 1119 py_pytree.tree_structure( 1120 py_pytree.tree_unflatten( 1121 [0] * complicated_spec.num_leaves, complicated_spec 1122 ) 1123 ), 1124 ) 1125 1126 serialized_spec = py_pytree.treespec_dumps(complicated_spec) 1127 saved_spec = ( 1128 '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", ' 1129 '"children_spec": [{"type": "builtins.tuple", "context": "null", ' 1130 '"children_spec": [{"type": null, "context": null, ' 1131 '"children_spec": []}, {"type": null, "context": null, ' 1132 '"children_spec": []}]}, {"type": null, "context": null, ' 1133 '"children_spec": []}, {"type": "builtins.dict", "context": ' 1134 '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, ' 1135 '"children_spec": []}, {"type": null, "context": null, "children_spec": ' 1136 '[]}, {"type": null, "context": null, "children_spec": []}]}]}]' 1137 ) 1138 self.assertEqual(serialized_spec, saved_spec) 1139 self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec)) 1140 1141 def test_tree_map_with_path(self): 1142 tree = [{i: i for i in range(10)}] 1143 all_zeros = py_pytree.tree_map_with_path( 1144 lambda kp, val: val - kp[1].key + kp[0].idx, tree 1145 ) 1146 self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)]) 1147 1148 def test_tree_map_with_path_multiple_trees(self): 1149 @dataclass 1150 class ACustomPytree: 1151 x: Any 1152 y: Any 1153 z: Any 1154 1155 tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5] 1156 tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2] 1157 1158 py_pytree.register_pytree_node( 1159 ACustomPytree, 1160 flatten_fn=lambda f: ([f.x, f.y], f.z), 1161 unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), 1162 flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), 1163 ) 1164 from_two_trees = py_pytree.tree_map_with_path( 1165 lambda kp, a, b: a + b, tree1, tree2 1166 ) 1167 
from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1) 1168 self.assertEqual(from_two_trees, from_one_tree) 1169 1170 @skipIfTorchDynamo("dynamo pytree tracing doesn't work here") 1171 def test_tree_flatten_with_path_is_leaf(self): 1172 leaf_dict = {"foo": [(3)]} 1173 pytree = (["hello", [1, 2], leaf_dict],) 1174 key_leaves, spec = py_pytree.tree_flatten_with_path( 1175 pytree, is_leaf=lambda x: isinstance(x, dict) 1176 ) 1177 self.assertTrue(key_leaves[-1][1] is leaf_dict) 1178 1179 def test_tree_flatten_with_path_roundtrip(self): 1180 class ANamedTuple(NamedTuple): 1181 x: torch.Tensor 1182 y: int 1183 z: str 1184 1185 @dataclass 1186 class ACustomPytree: 1187 x: Any 1188 y: Any 1189 z: Any 1190 1191 py_pytree.register_pytree_node( 1192 ACustomPytree, 1193 flatten_fn=lambda f: ([f.x, f.y], f.z), 1194 unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), 1195 flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), 1196 ) 1197 1198 SOME_PYTREES = [ 1199 (None,), 1200 ["hello", [1, 2], {"foo": [(3)]}], 1201 [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], 1202 [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], 1203 ] 1204 for pytree in SOME_PYTREES: 1205 key_leaves, spec = py_pytree.tree_flatten_with_path(pytree) 1206 actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec) 1207 self.assertEqual(actual, pytree) 1208 1209 def test_tree_leaves_with_path(self): 1210 class ANamedTuple(NamedTuple): 1211 x: torch.Tensor 1212 y: int 1213 z: str 1214 1215 @dataclass 1216 class ACustomPytree: 1217 x: Any 1218 y: Any 1219 z: Any 1220 1221 py_pytree.register_pytree_node( 1222 ACustomPytree, 1223 flatten_fn=lambda f: ([f.x, f.y], f.z), 1224 unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), 1225 flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), 1226 ) 1227 1228 SOME_PYTREES = [ 1229 (None,), 1230 ["hello", [1, 2], {"foo": [(3)]}], 1231 [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], 1232 
[ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], 1233 ] 1234 for pytree in SOME_PYTREES: 1235 flat_out, _ = py_pytree.tree_flatten_with_path(pytree) 1236 leaves_out = py_pytree.tree_leaves_with_path(pytree) 1237 self.assertEqual(flat_out, leaves_out) 1238 1239 def test_key_str(self): 1240 class ANamedTuple(NamedTuple): 1241 x: str 1242 y: int 1243 1244 tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) 1245 flat, _ = py_pytree.tree_flatten_with_path(tree) 1246 paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat] 1247 self.assertEqual( 1248 paths, 1249 [ 1250 "[0][0]: hello", 1251 "[0][1][0]: 1", 1252 "[0][1][1]: 2", 1253 "[0][2]['foo'][0]: 3", 1254 "[0][2]['bar'][0].x: baz", 1255 "[0][2]['bar'][0].y: 10", 1256 ], 1257 ) 1258 1259 @skipIfTorchDynamo("AssertionError in dynamo") 1260 def test_flatten_flatten_with_key_consistency(self): 1261 """Check that flatten and flatten_with_key produces consistent leaves/context.""" 1262 reg = py_pytree.SUPPORTED_NODES 1263 1264 EXAMPLE_TREE = { 1265 list: [1, 2, 3], 1266 tuple: (1, 2, 3), 1267 dict: {"foo": 1, "bar": 2}, 1268 namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2), 1269 OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]), 1270 defaultdict: defaultdict(int, {"foo": 1, "bar": 2}), 1271 deque: deque([1, 2, 3]), 1272 torch.Size: torch.Size([1, 2, 3]), 1273 immutable_dict: immutable_dict({"foo": 1, "bar": 2}), 1274 immutable_list: immutable_list([1, 2, 3]), 1275 } 1276 1277 for typ in reg: 1278 example = EXAMPLE_TREE.get(typ) 1279 if example is None: 1280 continue 1281 flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example) 1282 flat, spec2 = py_pytree.tree_flatten(example) 1283 1284 self.assertEqual(flat, [x[1] for x in flat_with_path]) 1285 self.assertEqual(spec1, spec2) 1286 1287 def test_key_access(self): 1288 class ANamedTuple(NamedTuple): 1289 x: str 1290 y: int 1291 1292 tree = (["hello", [1, 2], {"foo": [(3)], "bar": 
[ANamedTuple(x="baz", y=10)]}],) 1293 flat, _ = py_pytree.tree_flatten_with_path(tree) 1294 for kp, val in flat: 1295 self.assertEqual(py_pytree.key_get(tree, kp), val) 1296 1297 1298class TestCxxPytree(TestCase): 1299 def setUp(self): 1300 if IS_FBCODE: 1301 raise unittest.SkipTest("C++ pytree tests are not supported in fbcode") 1302 1303 def test_treespec_equality(self): 1304 self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec()) 1305 1306 @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.") 1307 def test_treespec_repr(self): 1308 # Check that it looks sane 1309 pytree = (0, [0, 0, [0]]) 1310 _, spec = cxx_pytree.tree_flatten(pytree) 1311 self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)") 1312 1313 @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") 1314 def test_treespec_repr_dynamo(self): 1315 # Check that it looks sane 1316 pytree = (0, [0, 0, [0]]) 1317 _, spec = cxx_pytree.tree_flatten(pytree) 1318 self.assertExpectedInline( 1319 repr(spec), 1320 "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)", 1321 ) 1322 1323 @parametrize( 1324 "spec", 1325 [ 1326 cxx_pytree.tree_structure([]), 1327 cxx_pytree.tree_structure(()), 1328 cxx_pytree.tree_structure({}), 1329 cxx_pytree.tree_structure([0]), 1330 cxx_pytree.tree_structure([0, 1]), 1331 cxx_pytree.tree_structure((0, 1, 2)), 1332 cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}), 1333 cxx_pytree.tree_structure( 1334 OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) 1335 ), 1336 cxx_pytree.tree_structure([(0, 1, [2, 3])]), 1337 cxx_pytree.tree_structure( 1338 defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}) 1339 ), 1340 ], 1341 ) 1342 def test_pytree_serialize(self, spec): 1343 self.assertEqual( 1344 spec, 1345 cxx_pytree.tree_structure( 1346 cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec) 1347 ), 1348 ) 1349 1350 serialized_spec = cxx_pytree.treespec_dumps(spec) 1351 
self.assertIsInstance(serialized_spec, str) 1352 self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec)) 1353 1354 def test_pytree_serialize_namedtuple(self): 1355 py_pytree._register_namedtuple( 1356 GlobalPoint, 1357 serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint", 1358 ) 1359 spec = cxx_pytree.tree_structure(GlobalPoint(0, 1)) 1360 1361 roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec)) 1362 self.assertEqual(roundtrip_spec.type._fields, spec.type._fields) 1363 1364 LocalPoint = namedtuple("LocalPoint", ["x", "y"]) 1365 py_pytree._register_namedtuple( 1366 LocalPoint, 1367 serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint", 1368 ) 1369 spec = cxx_pytree.tree_structure(LocalPoint(0, 1)) 1370 1371 roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec)) 1372 self.assertEqual(roundtrip_spec.type._fields, spec.type._fields) 1373 1374 def test_pytree_custom_type_serialize(self): 1375 cxx_pytree.register_pytree_node( 1376 GlobalDummyType, 1377 lambda dummy: ([dummy.x, dummy.y], None), 1378 lambda xs, _: GlobalDummyType(*xs), 1379 serialized_type_name="GlobalDummyType", 1380 ) 1381 spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1)) 1382 serialized_spec = cxx_pytree.treespec_dumps(spec) 1383 roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec) 1384 self.assertEqual(roundtrip_spec, spec) 1385 1386 class LocalDummyType: 1387 def __init__(self, x, y): 1388 self.x = x 1389 self.y = y 1390 1391 cxx_pytree.register_pytree_node( 1392 LocalDummyType, 1393 lambda dummy: ([dummy.x, dummy.y], None), 1394 lambda xs, _: LocalDummyType(*xs), 1395 serialized_type_name="LocalDummyType", 1396 ) 1397 spec = cxx_pytree.tree_structure(LocalDummyType(0, 1)) 1398 serialized_spec = cxx_pytree.treespec_dumps(spec) 1399 roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec) 1400 self.assertEqual(roundtrip_spec, spec) 1401 1402 
# Expand every @parametrize-decorated test method of each test class into
# concrete, individually named test cases before the runner collects them.
for _pytree_test_cls in (TestGenericPytree, TestPythonPytree, TestCxxPytree):
    instantiate_parametrized_tests(_pytree_test_cls)
# Drop the loop variable: left at module level it would be a second name bound
# to a TestCase subclass, and unittest's module loader would collect it again.
del _pytree_test_cls


if __name__ == "__main__":
    run_tests()