# Owner(s): ["module: pytree"]

import collections
import inspect
import os
import re
import subprocess
import sys
import unittest
from collections import defaultdict, deque, namedtuple, OrderedDict, UserDict
from dataclasses import dataclass
from typing import Any, NamedTuple

import torch
import torch.utils._pytree as py_pytree
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_FBCODE,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_TORCHDYNAMO,
    TestCase,
)


if IS_FBCODE:
    # optree is not yet enabled in fbcode, so just re-test the python implementation
    cxx_pytree = py_pytree
else:
    import torch.utils._cxx_pytree as cxx_pytree

# A namedtuple type defined at module (global) scope.
GlobalPoint = namedtuple("GlobalPoint", ["x", "y"])


class GlobalDummyType:
    """Minimal class holding two attributes, ``x`` and ``y``, defined at module scope."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


# Parameter kinds that can be supplied positionally.
_POSITIONAL_KINDS = frozenset(
    {inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD}
)

# Matches a (possibly dotted/qualified or prefixed) ``...TreeSpec`` type name
# inside the string form of an annotation, e.g. ``torch.utils._pytree.TreeSpec``.
_TREESPEC_NAME_RE = re.compile(r"\b[\w.]*TreeSpec\b")


def _positional_param_names(signature):
    """Return, in order, the names of *signature*'s positionally-passable parameters."""
    return [
        param_name
        for param_name, param in signature.parameters.items()
        if param.kind in _POSITIONAL_KINDS
    ]


def _normalize_treespec_annotation(annotation):
    """Return ``str(annotation)`` with every ``...TreeSpec`` name replaced by ``TreeSpec``.

    This lets annotations that differ only in which TreeSpec class they name
    compare equal.
    """
    return _TREESPEC_NAME_RE.sub("TreeSpec", str(annotation))


class TestGenericPytree(TestCase):
    def test_aligned_public_apis(self):
        """Check that the C++ and Python pytree modules expose matching public APIs.

        Both modules must export the same ``__all__``. Every exported name must
        be a class in both modules or a function in both. Functions must agree
        on parameter names and order, positional parameters, parameter kinds,
        defaults and annotations; annotations are compared modulo the TreeSpec
        class name.
        """
        public_apis = py_pytree.__all__

        self.assertEqual(public_apis, cxx_pytree.__all__)

        for name in public_apis:
            cxx_api = getattr(cxx_pytree, name)
            py_api = getattr(py_pytree, name)

            self.assertEqual(inspect.isclass(cxx_api), inspect.isclass(py_api))
            self.assertEqual(inspect.isfunction(cxx_api), inspect.isfunction(py_api))
            if not inspect.isfunction(cxx_api):
                continue

            cxx_signature = inspect.signature(cxx_api)
            py_signature = inspect.signature(py_api)

            # Check the parameter names are the same.
            self.assertEqual(
                list(cxx_signature.parameters), list(py_signature.parameters)
            )

            # Check the positional parameters are the same.
            self.assertEqual(
                _positional_param_names(cxx_signature),
                _positional_param_names(py_signature),
            )

            for py_name, py_param in py_signature.parameters.items():
                self.assertIn(py_name, cxx_signature.parameters)
                cxx_param = cxx_signature.parameters[py_name]

                # Check parameter kinds and default values are the same.
                self.assertEqual(cxx_param.kind, py_param.kind)
                self.assertEqual(cxx_param.default, py_param.default)

                mismatch_msg = (
                    f"C++ parameter {cxx_param} "
                    f"does not match Python parameter {py_param} "
                    f"for API `{name}`"
                )

                # Check parameter annotations are the same.
                if "TreeSpec" in str(cxx_param.annotation):
                    self.assertIn("TreeSpec", str(py_param.annotation))
                    self.assertEqual(
                        _normalize_treespec_annotation(cxx_param.annotation),
                        _normalize_treespec_annotation(py_param.annotation),
                        msg=mismatch_msg,
                    )
                else:
                    self.assertEqual(
                        cxx_param.annotation,
                        py_param.annotation,
                        msg=mismatch_msg,
                    )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_register_pytree_node(self, pytree_impl):
        """Check pytree node registration for a custom type.

        An unregistered custom type is a leaf. Once registered, it is flattened
        and unflattened with the supplied functions. Registering the same type
        again raises ``ValueError``.
        """

        class MyDict(UserDict):
            pass

        d = MyDict(a=1, b=2, c=3)

        # Custom types are leaf nodes by default
        values, spec = pytree_impl.tree_flatten(d)
        self.assertEqual(values, [d])
        self.assertIs(values[0], d)
        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))
        self.assertTrue(spec.is_leaf())

        # Register MyDict as a pytree node
        pytree_impl.register_pytree_node(
            MyDict,
            lambda d: (list(d.values()), list(d.keys())),
            lambda values, keys: MyDict(zip(keys, values)),
        )

        values, spec = pytree_impl.tree_flatten(d)
        self.assertEqual(values, [1, 2, 3])
        self.assertEqual(d, pytree_impl.tree_unflatten(values, spec))

        # Do not allow registering the same type twice
        with self.assertRaisesRegex(ValueError, "already registered"):
            pytree_impl.register_pytree_node(
                MyDict,
                lambda d: (list(d.values()), list(d.keys())),
                lambda values, keys: MyDict(zip(keys, values)),
            )

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_leaf(self, pytree_impl):
        """Flattening a leaf yields ``[leaf]`` and a ``LeafSpec``; unflattening restores it."""

        def run_test_with_leaf(leaf):
            values, treespec = pytree_impl.tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, pytree_impl.LeafSpec())

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        run_test_with_leaf(1)
        run_test_with_leaf(1.0)
        run_test_with_leaf(None)
        run_test_with_leaf(bool)
        run_test_with_leaf(torch.randn(3, 3))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda tup: py_pytree.TreeSpec(
                        tuple, None, [py_pytree.LeafSpec() for _ in tup]
                    ),
                ),
                name="py",
            ),
            # The C++ expected spec is derived by flattening a same-shaped
            # structure filled with dummy leaves.
            subtest(
                (cxx_pytree, lambda tup: cxx_pytree.tree_structure((0,) * len(tup))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_tuple(self, pytree_impl, gen_expected_fn):
        """Tuples flatten to a list of their elements and round-trip back to a tuple."""

        def run_test(tup):
            expected_spec = gen_expected_fn(tup)
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda lst: py_pytree.TreeSpec(
                        list, None, [py_pytree.LeafSpec() for _ in lst]
                    ),
                ),
                name="py",
            ),
            subtest(
                (cxx_pytree, lambda lst: cxx_pytree.tree_structure([0] * len(lst))),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_list(self, pytree_impl, gen_expected_fn):
        """Lists flatten to their elements and round-trip back to a list."""

        def run_test(lst):
            expected_spec = gen_expected_fn(lst)
            values, treespec = pytree_impl.tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda dct: py_pytree.TreeSpec(
                        dict,
                        list(dct.keys()),
                        [py_pytree.LeafSpec() for _ in dct.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda dct: cxx_pytree.tree_structure(dict.fromkeys(dct, 0)),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_dict(self, pytree_impl, gen_expected_fn):
        """Dicts flatten to their values (keys in the spec) and round-trip back to a dict."""

        def run_test(dct):
            expected_spec = gen_expected_fn(dct)
            values, treespec = pytree_impl.tree_flatten(dct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(dct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, dct)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda odict: py_pytree.TreeSpec(
                        OrderedDict,
                        list(odict.keys()),
                        [py_pytree.LeafSpec() for _ in odict.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda odict: cxx_pytree.tree_structure(
                        OrderedDict.fromkeys(odict, 0)
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_ordereddict(self, pytree_impl, gen_expected_fn):
        """OrderedDicts flatten in insertion order and round-trip back to an OrderedDict."""

        def run_test(odict):
            expected_spec = gen_expected_fn(odict)
            values, treespec = pytree_impl.tree_flatten(odict)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(odict.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, odict)
            self.assertIsInstance(unflattened, OrderedDict)

        od = OrderedDict()
        run_test(od)

        # Keys deliberately inserted out of alphabetical order.
        od["b"] = 1
        od["a"] = torch.tensor(3.14)
        run_test(od)

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda ddct: py_pytree.TreeSpec(
                        defaultdict,
                        [ddct.default_factory, list(ddct.keys())],
                        [py_pytree.LeafSpec() for _ in ddct.values()],
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda ddct: cxx_pytree.tree_structure(
                        defaultdict(ddct.default_factory, dict.fromkeys(ddct, 0))
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_defaultdict(self, pytree_impl, gen_expected_fn):
        """Defaultdicts round-trip with their contents and ``default_factory`` preserved."""

        def run_test(ddct):
            expected_spec = gen_expected_fn(ddct)
            values, treespec = pytree_impl.tree_flatten(ddct)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(ddct.values()))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, ddct)
            self.assertEqual(unflattened.default_factory, ddct.default_factory)
            self.assertIsInstance(unflattened, defaultdict)

        run_test(defaultdict(list, {}))
        run_test(defaultdict(int, {"a": 1}))
        run_test(defaultdict(int, {"abcdefg": torch.randn(2, 3)}))
        run_test(defaultdict(int, {1: torch.randn(2, 3)}))
        run_test(defaultdict(int, {"a": 1, "b": 2, "c": torch.randn(2, 3)}))

    @parametrize(
        "pytree_impl,gen_expected_fn",
        [
            subtest(
                (
                    py_pytree,
                    lambda deq: py_pytree.TreeSpec(
                        deque, deq.maxlen, [py_pytree.LeafSpec() for _ in deq]
                    ),
                ),
                name="py",
            ),
            subtest(
                (
                    cxx_pytree,
                    lambda deq: cxx_pytree.tree_structure(
                        deque(deq, maxlen=deq.maxlen)
                    ),
                ),
                name="cxx",
            ),
        ],
    )
    def test_flatten_unflatten_deque(self, pytree_impl, gen_expected_fn):
        """Deques round-trip with their elements and ``maxlen`` preserved."""

        def run_test(deq):
            expected_spec = gen_expected_fn(deq)
            values, treespec = pytree_impl.tree_flatten(deq)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(deq))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, deq)
            self.assertEqual(unflattened.maxlen, deq.maxlen)
            self.assertIsInstance(unflattened, deque)

        run_test(deque([]))
        run_test(deque([1.0, 2]))
        run_test(deque([torch.tensor([1.0, 2]), 2, 10, 9, 11], maxlen=8))

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_namedtuple(self, pytree_impl):
        """Namedtuples flatten to their fields and round-trip back to the same namedtuple type."""
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            if pytree_impl is py_pytree:
                # The Python spec records the concrete namedtuple type as context.
                expected_spec = py_pytree.TreeSpec(
                    namedtuple, Point, [py_pytree.LeafSpec() for _ in tup]
                )
            else:
                expected_spec = cxx_pytree.tree_structure(Point(0, 1))
            values, treespec = pytree_impl.tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, Point)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    @parametrize(
        "op",
        [
            subtest(torch.max, name="max"),
            subtest(torch.min, name="min"),
        ],
    )
    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_return_types(self, pytree_impl, op):
        """Results of ``op(x, dim=0)`` flatten to tensors and unflatten to the same result type."""
        x = torch.randn(3, 3)
        expected = op(x, dim=0)

        values, spec = pytree_impl.tree_flatten(expected)
        # Check that values is actually List[Tensor] and not (ReturnType(...),)
        for value in values:
            self.assertIsInstance(value, torch.Tensor)
        result = pytree_impl.tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_unflatten_nested(self, pytree_impl):
        """Nested dict/list/tuple structures round-trip and report the right leaf count."""

        def run_test(pytree):
            values, treespec = pytree_impl.tree_flatten(pytree)
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_leaves)

            # NB: python basic data structures (dict list tuple) all have
            # contents equality defined on them, so the following works for them.
            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            run_test(case)

    @parametrize(
        "pytree_impl",
        [
            subtest(py_pytree, name="py"),
            subtest(cxx_pytree, name="cxx"),
        ],
    )
    def test_flatten_with_is_leaf(self, pytree_impl):
        """``is_leaf`` marks every node except the root as a leaf.

        Flattening therefore stops after one level: the leaves are exactly the
        root's direct children, and the spec has one node per child plus the
        root.
        """

        def run_test(pytree, one_level_leaves):
            values, treespec = pytree_impl.tree_flatten(
                pytree, is_leaf=lambda x: x is not pytree
            )
            self.assertIsInstance(values, list)
            self.assertEqual(len(values), treespec.num_nodes - 1)
            self.assertEqual(len(values), treespec.num_leaves)
            self.assertEqual(len(values), treespec.num_children)
            self.assertEqual(values, one_level_leaves)

            # Rebuilding from dummy leaves (without is_leaf) must reproduce the spec.
            self.assertEqual(
                treespec,
                pytree_impl.tree_structure(
                    pytree_impl.tree_unflatten([0] * treespec.num_leaves, treespec)
                ),
            )

            unflattened = pytree_impl.tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            ([()], [()]),
            (([],), [[]]),
            ({"a": ()}, [()]),
            ({"a": 0, "b": [{"c": 1}]}, [0, [{"c": 1}]]),
            (
                {
                    "a": 0,
                    "b": [1, {"c": 2}, torch.ones(3)],
                    "c": (torch.zeros(2, 3), 1),
                },
                [0, [1, {"c": 2}, torch.ones(3)], (torch.zeros(2, 3), 1)],
            ),
        ]
        for case in cases:
            run_test(*case)

@parametrize( 558*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 559*da0073e9SAndroid Build Coastguard Worker [ 560*da0073e9SAndroid Build Coastguard Worker subtest(py_pytree, name="py"), 561*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 562*da0073e9SAndroid Build Coastguard Worker ], 563*da0073e9SAndroid Build Coastguard Worker ) 564*da0073e9SAndroid Build Coastguard Worker def test_tree_map(self, pytree_impl): 565*da0073e9SAndroid Build Coastguard Worker def run_test(pytree): 566*da0073e9SAndroid Build Coastguard Worker def f(x): 567*da0073e9SAndroid Build Coastguard Worker return x * 3 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker sm1 = sum(map(f, pytree_impl.tree_leaves(pytree))) 570*da0073e9SAndroid Build Coastguard Worker sm2 = sum(pytree_impl.tree_leaves(pytree_impl.tree_map(f, pytree))) 571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sm1, sm2) 572*da0073e9SAndroid Build Coastguard Worker 573*da0073e9SAndroid Build Coastguard Worker def invf(x): 574*da0073e9SAndroid Build Coastguard Worker return x // 3 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 577*da0073e9SAndroid Build Coastguard Worker pytree_impl.tree_map(invf, pytree_impl.tree_map(f, pytree)), 578*da0073e9SAndroid Build Coastguard Worker pytree, 579*da0073e9SAndroid Build Coastguard Worker ) 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker cases = [ 582*da0073e9SAndroid Build Coastguard Worker [()], 583*da0073e9SAndroid Build Coastguard Worker ([],), 584*da0073e9SAndroid Build Coastguard Worker {"a": ()}, 585*da0073e9SAndroid Build Coastguard Worker {"a": 1, "b": [{"c": 2}]}, 586*da0073e9SAndroid Build Coastguard Worker {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)}, 587*da0073e9SAndroid Build Coastguard Worker ] 588*da0073e9SAndroid Build Coastguard Worker for case in cases: 589*da0073e9SAndroid Build 
Coastguard Worker run_test(case) 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker @parametrize( 592*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 593*da0073e9SAndroid Build Coastguard Worker [ 594*da0073e9SAndroid Build Coastguard Worker subtest(py_pytree, name="py"), 595*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 596*da0073e9SAndroid Build Coastguard Worker ], 597*da0073e9SAndroid Build Coastguard Worker ) 598*da0073e9SAndroid Build Coastguard Worker def test_tree_map_multi_inputs(self, pytree_impl): 599*da0073e9SAndroid Build Coastguard Worker def run_test(pytree): 600*da0073e9SAndroid Build Coastguard Worker def f(x, y, z): 601*da0073e9SAndroid Build Coastguard Worker return x, [y, (z, 0)] 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker pytree_x = pytree 604*da0073e9SAndroid Build Coastguard Worker pytree_y = pytree_impl.tree_map(lambda x: (x + 1,), pytree) 605*da0073e9SAndroid Build Coastguard Worker pytree_z = pytree_impl.tree_map(lambda x: {"a": x * 2, "b": 2}, pytree) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 608*da0073e9SAndroid Build Coastguard Worker pytree_impl.tree_map(f, pytree_x, pytree_y, pytree_z), 609*da0073e9SAndroid Build Coastguard Worker pytree_impl.tree_map( 610*da0073e9SAndroid Build Coastguard Worker lambda x: f(x, (x + 1,), {"a": x * 2, "b": 2}), pytree 611*da0073e9SAndroid Build Coastguard Worker ), 612*da0073e9SAndroid Build Coastguard Worker ) 613*da0073e9SAndroid Build Coastguard Worker 614*da0073e9SAndroid Build Coastguard Worker cases = [ 615*da0073e9SAndroid Build Coastguard Worker [()], 616*da0073e9SAndroid Build Coastguard Worker ([],), 617*da0073e9SAndroid Build Coastguard Worker {"a": ()}, 618*da0073e9SAndroid Build Coastguard Worker {"a": 1, "b": [{"c": 2}]}, 619*da0073e9SAndroid Build Coastguard Worker {"a": 0, "b": [2, {"c": 3}, 4], "c": 
(5, 6)}, 620*da0073e9SAndroid Build Coastguard Worker ] 621*da0073e9SAndroid Build Coastguard Worker for case in cases: 622*da0073e9SAndroid Build Coastguard Worker run_test(case) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker @parametrize( 625*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 626*da0073e9SAndroid Build Coastguard Worker [ 627*da0073e9SAndroid Build Coastguard Worker subtest(py_pytree, name="py"), 628*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 629*da0073e9SAndroid Build Coastguard Worker ], 630*da0073e9SAndroid Build Coastguard Worker ) 631*da0073e9SAndroid Build Coastguard Worker def test_tree_map_only(self, pytree_impl): 632*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 633*da0073e9SAndroid Build Coastguard Worker pytree_impl.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"] 634*da0073e9SAndroid Build Coastguard Worker ) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker @parametrize( 637*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 638*da0073e9SAndroid Build Coastguard Worker [ 639*da0073e9SAndroid Build Coastguard Worker subtest(py_pytree, name="py"), 640*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 641*da0073e9SAndroid Build Coastguard Worker ], 642*da0073e9SAndroid Build Coastguard Worker ) 643*da0073e9SAndroid Build Coastguard Worker def test_tree_map_only_predicate_fn(self, pytree_impl): 644*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 645*da0073e9SAndroid Build Coastguard Worker pytree_impl.tree_map_only(lambda x: x == 0, lambda x: x + 2, [0, 1]), [2, 1] 646*da0073e9SAndroid Build Coastguard Worker ) 647*da0073e9SAndroid Build Coastguard Worker 648*da0073e9SAndroid Build Coastguard Worker @parametrize( 649*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 650*da0073e9SAndroid Build Coastguard Worker [ 651*da0073e9SAndroid Build Coastguard Worker 
subtest(py_pytree, name="py"), 652*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 653*da0073e9SAndroid Build Coastguard Worker ], 654*da0073e9SAndroid Build Coastguard Worker ) 655*da0073e9SAndroid Build Coastguard Worker def test_tree_all_any(self, pytree_impl): 656*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pytree_impl.tree_all(lambda x: x % 2, [1, 3])) 657*da0073e9SAndroid Build Coastguard Worker self.assertFalse(pytree_impl.tree_all(lambda x: x % 2, [0, 1])) 658*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pytree_impl.tree_any(lambda x: x % 2, [0, 1])) 659*da0073e9SAndroid Build Coastguard Worker self.assertFalse(pytree_impl.tree_any(lambda x: x % 2, [0, 2])) 660*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pytree_impl.tree_all_only(int, lambda x: x % 2, [1, 3, "a"])) 661*da0073e9SAndroid Build Coastguard Worker self.assertFalse(pytree_impl.tree_all_only(int, lambda x: x % 2, [0, 1, "a"])) 662*da0073e9SAndroid Build Coastguard Worker self.assertTrue(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 1, "a"])) 663*da0073e9SAndroid Build Coastguard Worker self.assertFalse(pytree_impl.tree_any_only(int, lambda x: x % 2, [0, 2, "a"])) 664*da0073e9SAndroid Build Coastguard Worker 665*da0073e9SAndroid Build Coastguard Worker @parametrize( 666*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 667*da0073e9SAndroid Build Coastguard Worker [ 668*da0073e9SAndroid Build Coastguard Worker subtest(py_pytree, name="py"), 669*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 670*da0073e9SAndroid Build Coastguard Worker ], 671*da0073e9SAndroid Build Coastguard Worker ) 672*da0073e9SAndroid Build Coastguard Worker def test_broadcast_to_and_flatten(self, pytree_impl): 673*da0073e9SAndroid Build Coastguard Worker cases = [ 674*da0073e9SAndroid Build Coastguard Worker (1, (), []), 675*da0073e9SAndroid Build Coastguard Worker # Same (flat) structures 676*da0073e9SAndroid Build Coastguard 
Worker ((1,), (0,), [1]), 677*da0073e9SAndroid Build Coastguard Worker ([1], [0], [1]), 678*da0073e9SAndroid Build Coastguard Worker ((1, 2, 3), (0, 0, 0), [1, 2, 3]), 679*da0073e9SAndroid Build Coastguard Worker ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]), 680*da0073e9SAndroid Build Coastguard Worker # Mismatched (flat) structures 681*da0073e9SAndroid Build Coastguard Worker ([1], (0,), None), 682*da0073e9SAndroid Build Coastguard Worker ([1], (0,), None), 683*da0073e9SAndroid Build Coastguard Worker ((1,), [0], None), 684*da0073e9SAndroid Build Coastguard Worker ((1, 2, 3), (0, 0), None), 685*da0073e9SAndroid Build Coastguard Worker ({"a": 1, "b": 2}, {"a": 0}, None), 686*da0073e9SAndroid Build Coastguard Worker ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None), 687*da0073e9SAndroid Build Coastguard Worker ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None), 688*da0073e9SAndroid Build Coastguard Worker # Same (nested) structures 689*da0073e9SAndroid Build Coastguard Worker ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]), 690*da0073e9SAndroid Build Coastguard Worker ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]), 691*da0073e9SAndroid Build Coastguard Worker # Mismatched (nested) structures 692*da0073e9SAndroid Build Coastguard Worker ((1, [2, 3]), (0, (0, 0)), None), 693*da0073e9SAndroid Build Coastguard Worker ((1, [2, 3]), (0, [0, 0, 0]), None), 694*da0073e9SAndroid Build Coastguard Worker # Broadcasting single value 695*da0073e9SAndroid Build Coastguard Worker (1, (0, 0, 0), [1, 1, 1]), 696*da0073e9SAndroid Build Coastguard Worker (1, [0, 0, 0], [1, 1, 1]), 697*da0073e9SAndroid Build Coastguard Worker (1, {"a": 0, "b": 0}, [1, 1]), 698*da0073e9SAndroid Build Coastguard Worker (1, (0, [0, [0]], 0), [1, 1, 1, 1]), 699*da0073e9SAndroid Build Coastguard Worker (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]), 700*da0073e9SAndroid Build Coastguard Worker # Broadcast multiple things 701*da0073e9SAndroid Build Coastguard Worker ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]), 
702*da0073e9SAndroid Build Coastguard Worker ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]), 703*da0073e9SAndroid Build Coastguard Worker (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]), 704*da0073e9SAndroid Build Coastguard Worker ] 705*da0073e9SAndroid Build Coastguard Worker for pytree, to_pytree, expected in cases: 706*da0073e9SAndroid Build Coastguard Worker _, to_spec = pytree_impl.tree_flatten(to_pytree) 707*da0073e9SAndroid Build Coastguard Worker result = pytree_impl._broadcast_to_and_flatten(pytree, to_spec) 708*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected, msg=str([pytree, to_spec, expected])) 709*da0073e9SAndroid Build Coastguard Worker 710*da0073e9SAndroid Build Coastguard Worker @parametrize( 711*da0073e9SAndroid Build Coastguard Worker "pytree_impl", 712*da0073e9SAndroid Build Coastguard Worker [ 713*da0073e9SAndroid Build Coastguard Worker subtest(py_pytree, name="py"), 714*da0073e9SAndroid Build Coastguard Worker subtest(cxx_pytree, name="cxx"), 715*da0073e9SAndroid Build Coastguard Worker ], 716*da0073e9SAndroid Build Coastguard Worker ) 717*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize_bad_input(self, pytree_impl): 718*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(TypeError): 719*da0073e9SAndroid Build Coastguard Worker pytree_impl.treespec_dumps("random_blurb") 720*da0073e9SAndroid Build Coastguard Worker 721*da0073e9SAndroid Build Coastguard Worker 722*da0073e9SAndroid Build Coastguard Workerclass TestPythonPytree(TestCase): 723*da0073e9SAndroid Build Coastguard Worker def test_deprecated_register_pytree_node(self): 724*da0073e9SAndroid Build Coastguard Worker class DummyType: 725*da0073e9SAndroid Build Coastguard Worker def __init__(self, x, y): 726*da0073e9SAndroid Build Coastguard Worker self.x = x 727*da0073e9SAndroid Build Coastguard Worker self.y = y 728*da0073e9SAndroid Build Coastguard Worker 729*da0073e9SAndroid Build Coastguard Worker with 
self.assertWarnsRegex( 730*da0073e9SAndroid Build Coastguard Worker FutureWarning, "torch.utils._pytree._register_pytree_node" 731*da0073e9SAndroid Build Coastguard Worker ): 732*da0073e9SAndroid Build Coastguard Worker py_pytree._register_pytree_node( 733*da0073e9SAndroid Build Coastguard Worker DummyType, 734*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 735*da0073e9SAndroid Build Coastguard Worker lambda xs, _: DummyType(*xs), 736*da0073e9SAndroid Build Coastguard Worker ) 737*da0073e9SAndroid Build Coastguard Worker 738*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex(UserWarning, "already registered"): 739*da0073e9SAndroid Build Coastguard Worker py_pytree._register_pytree_node( 740*da0073e9SAndroid Build Coastguard Worker DummyType, 741*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 742*da0073e9SAndroid Build Coastguard Worker lambda xs, _: DummyType(*xs), 743*da0073e9SAndroid Build Coastguard Worker ) 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker def test_import_pytree_doesnt_import_optree(self): 746*da0073e9SAndroid Build Coastguard Worker # importing torch.utils._pytree shouldn't import optree. 747*da0073e9SAndroid Build Coastguard Worker # only importing torch.utils._cxx_pytree should. 
748*da0073e9SAndroid Build Coastguard Worker script = """ 749*da0073e9SAndroid Build Coastguard Workerimport sys 750*da0073e9SAndroid Build Coastguard Workerimport torch 751*da0073e9SAndroid Build Coastguard Workerimport torch.utils._pytree 752*da0073e9SAndroid Build Coastguard Workerassert "torch.utils._pytree" in sys.modules 753*da0073e9SAndroid Build Coastguard Workerif "torch.utils._cxx_pytree" in sys.modules: 754*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("importing torch.utils._pytree should not import torch.utils._cxx_pytree") 755*da0073e9SAndroid Build Coastguard Workerif "optree" in sys.modules: 756*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("importing torch.utils._pytree should not import optree") 757*da0073e9SAndroid Build Coastguard Worker""" 758*da0073e9SAndroid Build Coastguard Worker try: 759*da0073e9SAndroid Build Coastguard Worker subprocess.check_output( 760*da0073e9SAndroid Build Coastguard Worker [sys.executable, "-c", script], 761*da0073e9SAndroid Build Coastguard Worker stderr=subprocess.STDOUT, 762*da0073e9SAndroid Build Coastguard Worker # On Windows, opening the subprocess with the default CWD makes `import torch` 763*da0073e9SAndroid Build Coastguard Worker # fail, so just set CWD to this script's directory 764*da0073e9SAndroid Build Coastguard Worker cwd=os.path.dirname(os.path.realpath(__file__)), 765*da0073e9SAndroid Build Coastguard Worker ) 766*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 767*da0073e9SAndroid Build Coastguard Worker self.fail( 768*da0073e9SAndroid Build Coastguard Worker msg=( 769*da0073e9SAndroid Build Coastguard Worker "Subprocess exception while attempting to run test: " 770*da0073e9SAndroid Build Coastguard Worker + e.output.decode("utf-8") 771*da0073e9SAndroid Build Coastguard Worker ) 772*da0073e9SAndroid Build Coastguard Worker ) 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker def 
test_treespec_equality(self): 775*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 776*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 777*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 778*da0073e9SAndroid Build Coastguard Worker ) 779*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 780*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(list, None, []), 781*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(list, None, []), 782*da0073e9SAndroid Build Coastguard Worker ) 783*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 784*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), 785*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), 786*da0073e9SAndroid Build Coastguard Worker ) 787*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 788*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(tuple, None, []) == py_pytree.TreeSpec(list, None, []), 789*da0073e9SAndroid Build Coastguard Worker ) 790*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 791*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(tuple, None, []) != py_pytree.TreeSpec(list, None, []), 792*da0073e9SAndroid Build Coastguard Worker ) 793*da0073e9SAndroid Build Coastguard Worker 794*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.") 795*da0073e9SAndroid Build Coastguard Worker def test_treespec_repr(self): 796*da0073e9SAndroid Build Coastguard Worker # Check that it looks sane 797*da0073e9SAndroid Build Coastguard Worker pytree = (0, [0, 0, [0]]) 798*da0073e9SAndroid Build Coastguard Worker _, spec = py_pytree.tree_flatten(pytree) 799*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 800*da0073e9SAndroid Build Coastguard Worker repr(spec), 801*da0073e9SAndroid Build Coastguard Worker ( 802*da0073e9SAndroid Build 
Coastguard Worker "TreeSpec(tuple, None, [*,\n" 803*da0073e9SAndroid Build Coastguard Worker " TreeSpec(list, None, [*,\n" 804*da0073e9SAndroid Build Coastguard Worker " *,\n" 805*da0073e9SAndroid Build Coastguard Worker " TreeSpec(list, None, [*])])])" 806*da0073e9SAndroid Build Coastguard Worker ), 807*da0073e9SAndroid Build Coastguard Worker ) 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") 810*da0073e9SAndroid Build Coastguard Worker def test_treespec_repr_dynamo(self): 811*da0073e9SAndroid Build Coastguard Worker # Check that it looks sane 812*da0073e9SAndroid Build Coastguard Worker pytree = (0, [0, 0, [0]]) 813*da0073e9SAndroid Build Coastguard Worker _, spec = py_pytree.tree_flatten(pytree) 814*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 815*da0073e9SAndroid Build Coastguard Worker repr(spec), 816*da0073e9SAndroid Build Coastguard Worker """\ 817*da0073e9SAndroid Build Coastguard WorkerTreeSpec(tuple, None, [*, 818*da0073e9SAndroid Build Coastguard Worker TreeSpec(list, None, [*, 819*da0073e9SAndroid Build Coastguard Worker *, 820*da0073e9SAndroid Build Coastguard Worker TreeSpec(list, None, [*])])])""", 821*da0073e9SAndroid Build Coastguard Worker ) 822*da0073e9SAndroid Build Coastguard Worker 823*da0073e9SAndroid Build Coastguard Worker @parametrize( 824*da0073e9SAndroid Build Coastguard Worker "spec", 825*da0073e9SAndroid Build Coastguard Worker [ 826*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure([]) 827*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(list, None, []), 828*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure(()) 829*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(tuple, None, []), 830*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure({}) 831*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(dict, [], []), 
832*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure([0]) 833*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(list, None, [py_pytree.LeafSpec()]), 834*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure([0, 1]) 835*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 836*da0073e9SAndroid Build Coastguard Worker list, 837*da0073e9SAndroid Build Coastguard Worker None, 838*da0073e9SAndroid Build Coastguard Worker [ 839*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 840*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 841*da0073e9SAndroid Build Coastguard Worker ], 842*da0073e9SAndroid Build Coastguard Worker ), 843*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure((0, 1, 2)) 844*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 845*da0073e9SAndroid Build Coastguard Worker tuple, 846*da0073e9SAndroid Build Coastguard Worker None, 847*da0073e9SAndroid Build Coastguard Worker [ 848*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 849*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 850*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 851*da0073e9SAndroid Build Coastguard Worker ], 852*da0073e9SAndroid Build Coastguard Worker ), 853*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure({"a": 0, "b": 1, "c": 2}) 854*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 855*da0073e9SAndroid Build Coastguard Worker dict, 856*da0073e9SAndroid Build Coastguard Worker ["a", "b", "c"], 857*da0073e9SAndroid Build Coastguard Worker [ 858*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 859*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 860*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 861*da0073e9SAndroid Build Coastguard Worker ], 862*da0073e9SAndroid Build Coastguard Worker ), 863*da0073e9SAndroid Build Coastguard Worker # 
py_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) 864*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 865*da0073e9SAndroid Build Coastguard Worker OrderedDict, 866*da0073e9SAndroid Build Coastguard Worker ["a", "b", "c"], 867*da0073e9SAndroid Build Coastguard Worker [ 868*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 869*da0073e9SAndroid Build Coastguard Worker tuple, 870*da0073e9SAndroid Build Coastguard Worker None, 871*da0073e9SAndroid Build Coastguard Worker [ 872*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 873*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 874*da0073e9SAndroid Build Coastguard Worker ], 875*da0073e9SAndroid Build Coastguard Worker ), 876*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 877*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 878*da0073e9SAndroid Build Coastguard Worker dict, 879*da0073e9SAndroid Build Coastguard Worker ["a", "b", "c"], 880*da0073e9SAndroid Build Coastguard Worker [ 881*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 882*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 883*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 884*da0073e9SAndroid Build Coastguard Worker ], 885*da0073e9SAndroid Build Coastguard Worker ), 886*da0073e9SAndroid Build Coastguard Worker ], 887*da0073e9SAndroid Build Coastguard Worker ), 888*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure([(0, 1, [2, 3])]) 889*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 890*da0073e9SAndroid Build Coastguard Worker list, 891*da0073e9SAndroid Build Coastguard Worker None, 892*da0073e9SAndroid Build Coastguard Worker [ 893*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 894*da0073e9SAndroid Build Coastguard Worker tuple, 895*da0073e9SAndroid Build Coastguard Worker None, 896*da0073e9SAndroid Build Coastguard Worker [ 897*da0073e9SAndroid 
Build Coastguard Worker py_pytree.LeafSpec(), 898*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 899*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 900*da0073e9SAndroid Build Coastguard Worker list, 901*da0073e9SAndroid Build Coastguard Worker None, 902*da0073e9SAndroid Build Coastguard Worker [ 903*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 904*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 905*da0073e9SAndroid Build Coastguard Worker ], 906*da0073e9SAndroid Build Coastguard Worker ), 907*da0073e9SAndroid Build Coastguard Worker ], 908*da0073e9SAndroid Build Coastguard Worker ), 909*da0073e9SAndroid Build Coastguard Worker ], 910*da0073e9SAndroid Build Coastguard Worker ), 911*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure(defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}})) 912*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 913*da0073e9SAndroid Build Coastguard Worker defaultdict, 914*da0073e9SAndroid Build Coastguard Worker [list, ["a", "b", "c"]], 915*da0073e9SAndroid Build Coastguard Worker [ 916*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 917*da0073e9SAndroid Build Coastguard Worker list, 918*da0073e9SAndroid Build Coastguard Worker None, 919*da0073e9SAndroid Build Coastguard Worker [ 920*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 921*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 922*da0073e9SAndroid Build Coastguard Worker ], 923*da0073e9SAndroid Build Coastguard Worker ), 924*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 925*da0073e9SAndroid Build Coastguard Worker list, 926*da0073e9SAndroid Build Coastguard Worker None, 927*da0073e9SAndroid Build Coastguard Worker [ 928*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 929*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 930*da0073e9SAndroid Build Coastguard Worker ], 931*da0073e9SAndroid Build Coastguard 
Worker ), 932*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec(dict, [], []), 933*da0073e9SAndroid Build Coastguard Worker ], 934*da0073e9SAndroid Build Coastguard Worker ), 935*da0073e9SAndroid Build Coastguard Worker ], 936*da0073e9SAndroid Build Coastguard Worker ) 937*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize(self, spec): 938*da0073e9SAndroid Build Coastguard Worker # Ensure that the spec is valid 939*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 940*da0073e9SAndroid Build Coastguard Worker spec, 941*da0073e9SAndroid Build Coastguard Worker py_pytree.tree_structure( 942*da0073e9SAndroid Build Coastguard Worker py_pytree.tree_unflatten([0] * spec.num_leaves, spec) 943*da0073e9SAndroid Build Coastguard Worker ), 944*da0073e9SAndroid Build Coastguard Worker ) 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Worker serialized_spec = py_pytree.treespec_dumps(spec) 947*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(serialized_spec, str) 948*da0073e9SAndroid Build Coastguard Worker self.assertEqual(spec, py_pytree.treespec_loads(serialized_spec)) 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize_namedtuple(self): 951*da0073e9SAndroid Build Coastguard Worker Point1 = namedtuple("Point1", ["x", "y"]) 952*da0073e9SAndroid Build Coastguard Worker py_pytree._register_namedtuple( 953*da0073e9SAndroid Build Coastguard Worker Point1, 954*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point1", 955*da0073e9SAndroid Build Coastguard Worker ) 956*da0073e9SAndroid Build Coastguard Worker 957*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 958*da0073e9SAndroid Build Coastguard Worker namedtuple, Point1, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 959*da0073e9SAndroid Build Coastguard Worker ) 960*da0073e9SAndroid Build Coastguard Worker 
roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) 961*da0073e9SAndroid Build Coastguard Worker self.assertEqual(spec, roundtrip_spec) 962*da0073e9SAndroid Build Coastguard Worker 963*da0073e9SAndroid Build Coastguard Worker class Point2(NamedTuple): 964*da0073e9SAndroid Build Coastguard Worker x: int 965*da0073e9SAndroid Build Coastguard Worker y: int 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker py_pytree._register_namedtuple( 968*da0073e9SAndroid Build Coastguard Worker Point2, 969*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.Point2", 970*da0073e9SAndroid Build Coastguard Worker ) 971*da0073e9SAndroid Build Coastguard Worker 972*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 973*da0073e9SAndroid Build Coastguard Worker namedtuple, Point2, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 974*da0073e9SAndroid Build Coastguard Worker ) 975*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = py_pytree.treespec_loads(py_pytree.treespec_dumps(spec)) 976*da0073e9SAndroid Build Coastguard Worker self.assertEqual(spec, roundtrip_spec) 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize_namedtuple_bad(self): 979*da0073e9SAndroid Build Coastguard Worker DummyType = namedtuple("DummyType", ["x", "y"]) 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 982*da0073e9SAndroid Build Coastguard Worker namedtuple, DummyType, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 983*da0073e9SAndroid Build Coastguard Worker ) 984*da0073e9SAndroid Build Coastguard Worker 985*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 986*da0073e9SAndroid Build Coastguard Worker NotImplementedError, "Please register using `_register_namedtuple`" 987*da0073e9SAndroid Build Coastguard Worker ): 
988*da0073e9SAndroid Build Coastguard Worker py_pytree.treespec_dumps(spec) 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker def test_pytree_custom_type_serialize_bad(self): 991*da0073e9SAndroid Build Coastguard Worker class DummyType: 992*da0073e9SAndroid Build Coastguard Worker def __init__(self, x, y): 993*da0073e9SAndroid Build Coastguard Worker self.x = x 994*da0073e9SAndroid Build Coastguard Worker self.y = y 995*da0073e9SAndroid Build Coastguard Worker 996*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 997*da0073e9SAndroid Build Coastguard Worker DummyType, 998*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 999*da0073e9SAndroid Build Coastguard Worker lambda xs, _: DummyType(*xs), 1000*da0073e9SAndroid Build Coastguard Worker ) 1001*da0073e9SAndroid Build Coastguard Worker 1002*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 1003*da0073e9SAndroid Build Coastguard Worker DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1004*da0073e9SAndroid Build Coastguard Worker ) 1005*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1006*da0073e9SAndroid Build Coastguard Worker NotImplementedError, "No registered serialization name" 1007*da0073e9SAndroid Build Coastguard Worker ): 1008*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = py_pytree.treespec_dumps(spec) 1009*da0073e9SAndroid Build Coastguard Worker 1010*da0073e9SAndroid Build Coastguard Worker def test_pytree_custom_type_serialize(self): 1011*da0073e9SAndroid Build Coastguard Worker class DummyType: 1012*da0073e9SAndroid Build Coastguard Worker def __init__(self, x, y): 1013*da0073e9SAndroid Build Coastguard Worker self.x = x 1014*da0073e9SAndroid Build Coastguard Worker self.y = y 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 1017*da0073e9SAndroid Build Coastguard Worker 
DummyType, 1018*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 1019*da0073e9SAndroid Build Coastguard Worker lambda xs, _: DummyType(*xs), 1020*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree_custom_type_serialize.DummyType", 1021*da0073e9SAndroid Build Coastguard Worker to_dumpable_context=lambda context: "moo", 1022*da0073e9SAndroid Build Coastguard Worker from_dumpable_context=lambda dumpable_context: None, 1023*da0073e9SAndroid Build Coastguard Worker ) 1024*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 1025*da0073e9SAndroid Build Coastguard Worker DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1026*da0073e9SAndroid Build Coastguard Worker ) 1027*da0073e9SAndroid Build Coastguard Worker serialized_spec = py_pytree.treespec_dumps(spec, 1) 1028*da0073e9SAndroid Build Coastguard Worker self.assertIn("moo", serialized_spec) 1029*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = py_pytree.treespec_loads(serialized_spec) 1030*da0073e9SAndroid Build Coastguard Worker self.assertEqual(roundtrip_spec, spec) 1031*da0073e9SAndroid Build Coastguard Worker 1032*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize_register_bad(self): 1033*da0073e9SAndroid Build Coastguard Worker class DummyType: 1034*da0073e9SAndroid Build Coastguard Worker def __init__(self, x, y): 1035*da0073e9SAndroid Build Coastguard Worker self.x = x 1036*da0073e9SAndroid Build Coastguard Worker self.y = y 1037*da0073e9SAndroid Build Coastguard Worker 1038*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1039*da0073e9SAndroid Build Coastguard Worker ValueError, "Both to_dumpable_context and from_dumpable_context" 1040*da0073e9SAndroid Build Coastguard Worker ): 1041*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 1042*da0073e9SAndroid Build Coastguard Worker DummyType, 1043*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, 
dummy.y], None), 1044*da0073e9SAndroid Build Coastguard Worker lambda xs, _: DummyType(*xs), 1045*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree_serialize_register_bad.DummyType", 1046*da0073e9SAndroid Build Coastguard Worker to_dumpable_context=lambda context: "moo", 1047*da0073e9SAndroid Build Coastguard Worker ) 1048*da0073e9SAndroid Build Coastguard Worker 1049*da0073e9SAndroid Build Coastguard Worker def test_pytree_context_serialize_bad(self): 1050*da0073e9SAndroid Build Coastguard Worker class DummyType: 1051*da0073e9SAndroid Build Coastguard Worker def __init__(self, x, y): 1052*da0073e9SAndroid Build Coastguard Worker self.x = x 1053*da0073e9SAndroid Build Coastguard Worker self.y = y 1054*da0073e9SAndroid Build Coastguard Worker 1055*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 1056*da0073e9SAndroid Build Coastguard Worker DummyType, 1057*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 1058*da0073e9SAndroid Build Coastguard Worker lambda xs, _: DummyType(*xs), 1059*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree_serialize_serialize_bad.DummyType", 1060*da0073e9SAndroid Build Coastguard Worker to_dumpable_context=lambda context: DummyType, 1061*da0073e9SAndroid Build Coastguard Worker from_dumpable_context=lambda dumpable_context: None, 1062*da0073e9SAndroid Build Coastguard Worker ) 1063*da0073e9SAndroid Build Coastguard Worker 1064*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 1065*da0073e9SAndroid Build Coastguard Worker DummyType, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1066*da0073e9SAndroid Build Coastguard Worker ) 1067*da0073e9SAndroid Build Coastguard Worker 1068*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1069*da0073e9SAndroid Build Coastguard Worker TypeError, "Object of type type is not JSON serializable" 1070*da0073e9SAndroid Build Coastguard Worker ): 
1071*da0073e9SAndroid Build Coastguard Worker py_pytree.treespec_dumps(spec) 1072*da0073e9SAndroid Build Coastguard Worker 1073*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize_bad_protocol(self): 1074*da0073e9SAndroid Build Coastguard Worker import json 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker Point = namedtuple("Point", ["x", "y"]) 1077*da0073e9SAndroid Build Coastguard Worker spec = py_pytree.TreeSpec( 1078*da0073e9SAndroid Build Coastguard Worker namedtuple, Point, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1079*da0073e9SAndroid Build Coastguard Worker ) 1080*da0073e9SAndroid Build Coastguard Worker py_pytree._register_namedtuple( 1081*da0073e9SAndroid Build Coastguard Worker Point, 1082*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree.test_pytree_serialize_bad_protocol.Point", 1083*da0073e9SAndroid Build Coastguard Worker ) 1084*da0073e9SAndroid Build Coastguard Worker 1085*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Unknown protocol"): 1086*da0073e9SAndroid Build Coastguard Worker py_pytree.treespec_dumps(spec, -1) 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker serialized_spec = py_pytree.treespec_dumps(spec) 1089*da0073e9SAndroid Build Coastguard Worker protocol, data = json.loads(serialized_spec) 1090*da0073e9SAndroid Build Coastguard Worker bad_protocol_serialized_spec = json.dumps((-1, data)) 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Unknown protocol"): 1093*da0073e9SAndroid Build Coastguard Worker py_pytree.treespec_loads(bad_protocol_serialized_spec) 1094*da0073e9SAndroid Build Coastguard Worker 1095*da0073e9SAndroid Build Coastguard Worker def test_saved_serialized(self): 1096*da0073e9SAndroid Build Coastguard Worker # py_pytree.tree_structure(OrderedDict([(1, (0, 1)), (2, 2), (3, {4: 
3, 5: 4, 6: 5})])) 1097*da0073e9SAndroid Build Coastguard Worker complicated_spec = py_pytree.TreeSpec( 1098*da0073e9SAndroid Build Coastguard Worker OrderedDict, 1099*da0073e9SAndroid Build Coastguard Worker [1, 2, 3], 1100*da0073e9SAndroid Build Coastguard Worker [ 1101*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 1102*da0073e9SAndroid Build Coastguard Worker tuple, None, [py_pytree.LeafSpec(), py_pytree.LeafSpec()] 1103*da0073e9SAndroid Build Coastguard Worker ), 1104*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 1105*da0073e9SAndroid Build Coastguard Worker py_pytree.TreeSpec( 1106*da0073e9SAndroid Build Coastguard Worker dict, 1107*da0073e9SAndroid Build Coastguard Worker [4, 5, 6], 1108*da0073e9SAndroid Build Coastguard Worker [ 1109*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 1110*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 1111*da0073e9SAndroid Build Coastguard Worker py_pytree.LeafSpec(), 1112*da0073e9SAndroid Build Coastguard Worker ], 1113*da0073e9SAndroid Build Coastguard Worker ), 1114*da0073e9SAndroid Build Coastguard Worker ], 1115*da0073e9SAndroid Build Coastguard Worker ) 1116*da0073e9SAndroid Build Coastguard Worker # Ensure that the spec is valid 1117*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1118*da0073e9SAndroid Build Coastguard Worker complicated_spec, 1119*da0073e9SAndroid Build Coastguard Worker py_pytree.tree_structure( 1120*da0073e9SAndroid Build Coastguard Worker py_pytree.tree_unflatten( 1121*da0073e9SAndroid Build Coastguard Worker [0] * complicated_spec.num_leaves, complicated_spec 1122*da0073e9SAndroid Build Coastguard Worker ) 1123*da0073e9SAndroid Build Coastguard Worker ), 1124*da0073e9SAndroid Build Coastguard Worker ) 1125*da0073e9SAndroid Build Coastguard Worker 1126*da0073e9SAndroid Build Coastguard Worker serialized_spec = py_pytree.treespec_dumps(complicated_spec) 1127*da0073e9SAndroid Build Coastguard Worker saved_spec = ( 
1128*da0073e9SAndroid Build Coastguard Worker '[1, {"type": "collections.OrderedDict", "context": "[1, 2, 3]", ' 1129*da0073e9SAndroid Build Coastguard Worker '"children_spec": [{"type": "builtins.tuple", "context": "null", ' 1130*da0073e9SAndroid Build Coastguard Worker '"children_spec": [{"type": null, "context": null, ' 1131*da0073e9SAndroid Build Coastguard Worker '"children_spec": []}, {"type": null, "context": null, ' 1132*da0073e9SAndroid Build Coastguard Worker '"children_spec": []}]}, {"type": null, "context": null, ' 1133*da0073e9SAndroid Build Coastguard Worker '"children_spec": []}, {"type": "builtins.dict", "context": ' 1134*da0073e9SAndroid Build Coastguard Worker '"[4, 5, 6]", "children_spec": [{"type": null, "context": null, ' 1135*da0073e9SAndroid Build Coastguard Worker '"children_spec": []}, {"type": null, "context": null, "children_spec": ' 1136*da0073e9SAndroid Build Coastguard Worker '[]}, {"type": null, "context": null, "children_spec": []}]}]}]' 1137*da0073e9SAndroid Build Coastguard Worker ) 1138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(serialized_spec, saved_spec) 1139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(complicated_spec, py_pytree.treespec_loads(saved_spec)) 1140*da0073e9SAndroid Build Coastguard Worker 1141*da0073e9SAndroid Build Coastguard Worker def test_tree_map_with_path(self): 1142*da0073e9SAndroid Build Coastguard Worker tree = [{i: i for i in range(10)}] 1143*da0073e9SAndroid Build Coastguard Worker all_zeros = py_pytree.tree_map_with_path( 1144*da0073e9SAndroid Build Coastguard Worker lambda kp, val: val - kp[1].key + kp[0].idx, tree 1145*da0073e9SAndroid Build Coastguard Worker ) 1146*da0073e9SAndroid Build Coastguard Worker self.assertEqual(all_zeros, [dict.fromkeys(range(10), 0)]) 1147*da0073e9SAndroid Build Coastguard Worker 1148*da0073e9SAndroid Build Coastguard Worker def test_tree_map_with_path_multiple_trees(self): 1149*da0073e9SAndroid Build Coastguard Worker @dataclass 
1150*da0073e9SAndroid Build Coastguard Worker class ACustomPytree: 1151*da0073e9SAndroid Build Coastguard Worker x: Any 1152*da0073e9SAndroid Build Coastguard Worker y: Any 1153*da0073e9SAndroid Build Coastguard Worker z: Any 1154*da0073e9SAndroid Build Coastguard Worker 1155*da0073e9SAndroid Build Coastguard Worker tree1 = [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5] 1156*da0073e9SAndroid Build Coastguard Worker tree2 = [ACustomPytree(x=2, y={"cin": [2, 2, 2], "bar": 2}, z="leaf"), 2] 1157*da0073e9SAndroid Build Coastguard Worker 1158*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 1159*da0073e9SAndroid Build Coastguard Worker ACustomPytree, 1160*da0073e9SAndroid Build Coastguard Worker flatten_fn=lambda f: ([f.x, f.y], f.z), 1161*da0073e9SAndroid Build Coastguard Worker unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), 1162*da0073e9SAndroid Build Coastguard Worker flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), 1163*da0073e9SAndroid Build Coastguard Worker ) 1164*da0073e9SAndroid Build Coastguard Worker from_two_trees = py_pytree.tree_map_with_path( 1165*da0073e9SAndroid Build Coastguard Worker lambda kp, a, b: a + b, tree1, tree2 1166*da0073e9SAndroid Build Coastguard Worker ) 1167*da0073e9SAndroid Build Coastguard Worker from_one_tree = py_pytree.tree_map(lambda a: a + 2, tree1) 1168*da0073e9SAndroid Build Coastguard Worker self.assertEqual(from_two_trees, from_one_tree) 1169*da0073e9SAndroid Build Coastguard Worker 1170*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("dynamo pytree tracing doesn't work here") 1171*da0073e9SAndroid Build Coastguard Worker def test_tree_flatten_with_path_is_leaf(self): 1172*da0073e9SAndroid Build Coastguard Worker leaf_dict = {"foo": [(3)]} 1173*da0073e9SAndroid Build Coastguard Worker pytree = (["hello", [1, 2], leaf_dict],) 1174*da0073e9SAndroid Build Coastguard Worker key_leaves, spec = py_pytree.tree_flatten_with_path( 1175*da0073e9SAndroid 
Build Coastguard Worker pytree, is_leaf=lambda x: isinstance(x, dict) 1176*da0073e9SAndroid Build Coastguard Worker ) 1177*da0073e9SAndroid Build Coastguard Worker self.assertTrue(key_leaves[-1][1] is leaf_dict) 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker def test_tree_flatten_with_path_roundtrip(self): 1180*da0073e9SAndroid Build Coastguard Worker class ANamedTuple(NamedTuple): 1181*da0073e9SAndroid Build Coastguard Worker x: torch.Tensor 1182*da0073e9SAndroid Build Coastguard Worker y: int 1183*da0073e9SAndroid Build Coastguard Worker z: str 1184*da0073e9SAndroid Build Coastguard Worker 1185*da0073e9SAndroid Build Coastguard Worker @dataclass 1186*da0073e9SAndroid Build Coastguard Worker class ACustomPytree: 1187*da0073e9SAndroid Build Coastguard Worker x: Any 1188*da0073e9SAndroid Build Coastguard Worker y: Any 1189*da0073e9SAndroid Build Coastguard Worker z: Any 1190*da0073e9SAndroid Build Coastguard Worker 1191*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 1192*da0073e9SAndroid Build Coastguard Worker ACustomPytree, 1193*da0073e9SAndroid Build Coastguard Worker flatten_fn=lambda f: ([f.x, f.y], f.z), 1194*da0073e9SAndroid Build Coastguard Worker unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), 1195*da0073e9SAndroid Build Coastguard Worker flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), 1196*da0073e9SAndroid Build Coastguard Worker ) 1197*da0073e9SAndroid Build Coastguard Worker 1198*da0073e9SAndroid Build Coastguard Worker SOME_PYTREES = [ 1199*da0073e9SAndroid Build Coastguard Worker (None,), 1200*da0073e9SAndroid Build Coastguard Worker ["hello", [1, 2], {"foo": [(3)]}], 1201*da0073e9SAndroid Build Coastguard Worker [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], 1202*da0073e9SAndroid Build Coastguard Worker [ACustomPytree(x=12, y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], 1203*da0073e9SAndroid Build Coastguard Worker ] 1204*da0073e9SAndroid Build 
Coastguard Worker for pytree in SOME_PYTREES: 1205*da0073e9SAndroid Build Coastguard Worker key_leaves, spec = py_pytree.tree_flatten_with_path(pytree) 1206*da0073e9SAndroid Build Coastguard Worker actual = py_pytree.tree_unflatten([leaf for _, leaf in key_leaves], spec) 1207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(actual, pytree) 1208*da0073e9SAndroid Build Coastguard Worker 1209*da0073e9SAndroid Build Coastguard Worker def test_tree_leaves_with_path(self): 1210*da0073e9SAndroid Build Coastguard Worker class ANamedTuple(NamedTuple): 1211*da0073e9SAndroid Build Coastguard Worker x: torch.Tensor 1212*da0073e9SAndroid Build Coastguard Worker y: int 1213*da0073e9SAndroid Build Coastguard Worker z: str 1214*da0073e9SAndroid Build Coastguard Worker 1215*da0073e9SAndroid Build Coastguard Worker @dataclass 1216*da0073e9SAndroid Build Coastguard Worker class ACustomPytree: 1217*da0073e9SAndroid Build Coastguard Worker x: Any 1218*da0073e9SAndroid Build Coastguard Worker y: Any 1219*da0073e9SAndroid Build Coastguard Worker z: Any 1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker py_pytree.register_pytree_node( 1222*da0073e9SAndroid Build Coastguard Worker ACustomPytree, 1223*da0073e9SAndroid Build Coastguard Worker flatten_fn=lambda f: ([f.x, f.y], f.z), 1224*da0073e9SAndroid Build Coastguard Worker unflatten_fn=lambda xy, z: ACustomPytree(xy[0], xy[1], z), 1225*da0073e9SAndroid Build Coastguard Worker flatten_with_keys_fn=lambda f: ((("x", f.x), ("y", f.y)), f.z), 1226*da0073e9SAndroid Build Coastguard Worker ) 1227*da0073e9SAndroid Build Coastguard Worker 1228*da0073e9SAndroid Build Coastguard Worker SOME_PYTREES = [ 1229*da0073e9SAndroid Build Coastguard Worker (None,), 1230*da0073e9SAndroid Build Coastguard Worker ["hello", [1, 2], {"foo": [(3)]}], 1231*da0073e9SAndroid Build Coastguard Worker [ANamedTuple(x=torch.rand(2, 3), y=1, z="foo")], 1232*da0073e9SAndroid Build Coastguard Worker [ACustomPytree(x=12, 
y={"cin": [1, 4, 10], "bar": 18}, z="leaf"), 5], 1233*da0073e9SAndroid Build Coastguard Worker ] 1234*da0073e9SAndroid Build Coastguard Worker for pytree in SOME_PYTREES: 1235*da0073e9SAndroid Build Coastguard Worker flat_out, _ = py_pytree.tree_flatten_with_path(pytree) 1236*da0073e9SAndroid Build Coastguard Worker leaves_out = py_pytree.tree_leaves_with_path(pytree) 1237*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat_out, leaves_out) 1238*da0073e9SAndroid Build Coastguard Worker 1239*da0073e9SAndroid Build Coastguard Worker def test_key_str(self): 1240*da0073e9SAndroid Build Coastguard Worker class ANamedTuple(NamedTuple): 1241*da0073e9SAndroid Build Coastguard Worker x: str 1242*da0073e9SAndroid Build Coastguard Worker y: int 1243*da0073e9SAndroid Build Coastguard Worker 1244*da0073e9SAndroid Build Coastguard Worker tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) 1245*da0073e9SAndroid Build Coastguard Worker flat, _ = py_pytree.tree_flatten_with_path(tree) 1246*da0073e9SAndroid Build Coastguard Worker paths = [f"{py_pytree.keystr(kp)}: {val}" for kp, val in flat] 1247*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1248*da0073e9SAndroid Build Coastguard Worker paths, 1249*da0073e9SAndroid Build Coastguard Worker [ 1250*da0073e9SAndroid Build Coastguard Worker "[0][0]: hello", 1251*da0073e9SAndroid Build Coastguard Worker "[0][1][0]: 1", 1252*da0073e9SAndroid Build Coastguard Worker "[0][1][1]: 2", 1253*da0073e9SAndroid Build Coastguard Worker "[0][2]['foo'][0]: 3", 1254*da0073e9SAndroid Build Coastguard Worker "[0][2]['bar'][0].x: baz", 1255*da0073e9SAndroid Build Coastguard Worker "[0][2]['bar'][0].y: 10", 1256*da0073e9SAndroid Build Coastguard Worker ], 1257*da0073e9SAndroid Build Coastguard Worker ) 1258*da0073e9SAndroid Build Coastguard Worker 1259*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("AssertionError in dynamo") 1260*da0073e9SAndroid Build Coastguard Worker def 
test_flatten_flatten_with_key_consistency(self): 1261*da0073e9SAndroid Build Coastguard Worker """Check that flatten and flatten_with_key produces consistent leaves/context.""" 1262*da0073e9SAndroid Build Coastguard Worker reg = py_pytree.SUPPORTED_NODES 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker EXAMPLE_TREE = { 1265*da0073e9SAndroid Build Coastguard Worker list: [1, 2, 3], 1266*da0073e9SAndroid Build Coastguard Worker tuple: (1, 2, 3), 1267*da0073e9SAndroid Build Coastguard Worker dict: {"foo": 1, "bar": 2}, 1268*da0073e9SAndroid Build Coastguard Worker namedtuple: collections.namedtuple("ANamedTuple", ["x", "y"])(1, 2), 1269*da0073e9SAndroid Build Coastguard Worker OrderedDict: OrderedDict([("foo", 1), ("bar", 2)]), 1270*da0073e9SAndroid Build Coastguard Worker defaultdict: defaultdict(int, {"foo": 1, "bar": 2}), 1271*da0073e9SAndroid Build Coastguard Worker deque: deque([1, 2, 3]), 1272*da0073e9SAndroid Build Coastguard Worker torch.Size: torch.Size([1, 2, 3]), 1273*da0073e9SAndroid Build Coastguard Worker immutable_dict: immutable_dict({"foo": 1, "bar": 2}), 1274*da0073e9SAndroid Build Coastguard Worker immutable_list: immutable_list([1, 2, 3]), 1275*da0073e9SAndroid Build Coastguard Worker } 1276*da0073e9SAndroid Build Coastguard Worker 1277*da0073e9SAndroid Build Coastguard Worker for typ in reg: 1278*da0073e9SAndroid Build Coastguard Worker example = EXAMPLE_TREE.get(typ) 1279*da0073e9SAndroid Build Coastguard Worker if example is None: 1280*da0073e9SAndroid Build Coastguard Worker continue 1281*da0073e9SAndroid Build Coastguard Worker flat_with_path, spec1 = py_pytree.tree_flatten_with_path(example) 1282*da0073e9SAndroid Build Coastguard Worker flat, spec2 = py_pytree.tree_flatten(example) 1283*da0073e9SAndroid Build Coastguard Worker 1284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(flat, [x[1] for x in flat_with_path]) 1285*da0073e9SAndroid Build Coastguard Worker self.assertEqual(spec1, 
spec2) 1286*da0073e9SAndroid Build Coastguard Worker 1287*da0073e9SAndroid Build Coastguard Worker def test_key_access(self): 1288*da0073e9SAndroid Build Coastguard Worker class ANamedTuple(NamedTuple): 1289*da0073e9SAndroid Build Coastguard Worker x: str 1290*da0073e9SAndroid Build Coastguard Worker y: int 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker tree = (["hello", [1, 2], {"foo": [(3)], "bar": [ANamedTuple(x="baz", y=10)]}],) 1293*da0073e9SAndroid Build Coastguard Worker flat, _ = py_pytree.tree_flatten_with_path(tree) 1294*da0073e9SAndroid Build Coastguard Worker for kp, val in flat: 1295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(py_pytree.key_get(tree, kp), val) 1296*da0073e9SAndroid Build Coastguard Worker 1297*da0073e9SAndroid Build Coastguard Worker 1298*da0073e9SAndroid Build Coastguard Workerclass TestCxxPytree(TestCase): 1299*da0073e9SAndroid Build Coastguard Worker def setUp(self): 1300*da0073e9SAndroid Build Coastguard Worker if IS_FBCODE: 1301*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("C++ pytree tests are not supported in fbcode") 1302*da0073e9SAndroid Build Coastguard Worker 1303*da0073e9SAndroid Build Coastguard Worker def test_treespec_equality(self): 1304*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec()) 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(TEST_WITH_TORCHDYNAMO, "Dynamo test in test_treespec_repr_dynamo.") 1307*da0073e9SAndroid Build Coastguard Worker def test_treespec_repr(self): 1308*da0073e9SAndroid Build Coastguard Worker # Check that it looks sane 1309*da0073e9SAndroid Build Coastguard Worker pytree = (0, [0, 0, [0]]) 1310*da0073e9SAndroid Build Coastguard Worker _, spec = cxx_pytree.tree_flatten(pytree) 1311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)") 
1312*da0073e9SAndroid Build Coastguard Worker 1313*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") 1314*da0073e9SAndroid Build Coastguard Worker def test_treespec_repr_dynamo(self): 1315*da0073e9SAndroid Build Coastguard Worker # Check that it looks sane 1316*da0073e9SAndroid Build Coastguard Worker pytree = (0, [0, 0, [0]]) 1317*da0073e9SAndroid Build Coastguard Worker _, spec = cxx_pytree.tree_flatten(pytree) 1318*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1319*da0073e9SAndroid Build Coastguard Worker repr(spec), 1320*da0073e9SAndroid Build Coastguard Worker "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)", 1321*da0073e9SAndroid Build Coastguard Worker ) 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker @parametrize( 1324*da0073e9SAndroid Build Coastguard Worker "spec", 1325*da0073e9SAndroid Build Coastguard Worker [ 1326*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure([]), 1327*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure(()), 1328*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure({}), 1329*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure([0]), 1330*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure([0, 1]), 1331*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure((0, 1, 2)), 1332*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure({"a": 0, "b": 1, "c": 2}), 1333*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure( 1334*da0073e9SAndroid Build Coastguard Worker OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})]) 1335*da0073e9SAndroid Build Coastguard Worker ), 1336*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure([(0, 1, [2, 3])]), 1337*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure( 1338*da0073e9SAndroid Build Coastguard Worker 
defaultdict(list, {"a": [0, 1], "b": [1, 2], "c": {}}) 1339*da0073e9SAndroid Build Coastguard Worker ), 1340*da0073e9SAndroid Build Coastguard Worker ], 1341*da0073e9SAndroid Build Coastguard Worker ) 1342*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize(self, spec): 1343*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 1344*da0073e9SAndroid Build Coastguard Worker spec, 1345*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_structure( 1346*da0073e9SAndroid Build Coastguard Worker cxx_pytree.tree_unflatten([0] * spec.num_leaves, spec) 1347*da0073e9SAndroid Build Coastguard Worker ), 1348*da0073e9SAndroid Build Coastguard Worker ) 1349*da0073e9SAndroid Build Coastguard Worker 1350*da0073e9SAndroid Build Coastguard Worker serialized_spec = cxx_pytree.treespec_dumps(spec) 1351*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(serialized_spec, str) 1352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(spec, cxx_pytree.treespec_loads(serialized_spec)) 1353*da0073e9SAndroid Build Coastguard Worker 1354*da0073e9SAndroid Build Coastguard Worker def test_pytree_serialize_namedtuple(self): 1355*da0073e9SAndroid Build Coastguard Worker py_pytree._register_namedtuple( 1356*da0073e9SAndroid Build Coastguard Worker GlobalPoint, 1357*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.GlobalPoint", 1358*da0073e9SAndroid Build Coastguard Worker ) 1359*da0073e9SAndroid Build Coastguard Worker spec = cxx_pytree.tree_structure(GlobalPoint(0, 1)) 1360*da0073e9SAndroid Build Coastguard Worker 1361*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec)) 1362*da0073e9SAndroid Build Coastguard Worker self.assertEqual(roundtrip_spec.type._fields, spec.type._fields) 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker LocalPoint = namedtuple("LocalPoint", ["x", "y"]) 
1365*da0073e9SAndroid Build Coastguard Worker py_pytree._register_namedtuple( 1366*da0073e9SAndroid Build Coastguard Worker LocalPoint, 1367*da0073e9SAndroid Build Coastguard Worker serialized_type_name="test_pytree.test_pytree_serialize_namedtuple.LocalPoint", 1368*da0073e9SAndroid Build Coastguard Worker ) 1369*da0073e9SAndroid Build Coastguard Worker spec = cxx_pytree.tree_structure(LocalPoint(0, 1)) 1370*da0073e9SAndroid Build Coastguard Worker 1371*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = cxx_pytree.treespec_loads(cxx_pytree.treespec_dumps(spec)) 1372*da0073e9SAndroid Build Coastguard Worker self.assertEqual(roundtrip_spec.type._fields, spec.type._fields) 1373*da0073e9SAndroid Build Coastguard Worker 1374*da0073e9SAndroid Build Coastguard Worker def test_pytree_custom_type_serialize(self): 1375*da0073e9SAndroid Build Coastguard Worker cxx_pytree.register_pytree_node( 1376*da0073e9SAndroid Build Coastguard Worker GlobalDummyType, 1377*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 1378*da0073e9SAndroid Build Coastguard Worker lambda xs, _: GlobalDummyType(*xs), 1379*da0073e9SAndroid Build Coastguard Worker serialized_type_name="GlobalDummyType", 1380*da0073e9SAndroid Build Coastguard Worker ) 1381*da0073e9SAndroid Build Coastguard Worker spec = cxx_pytree.tree_structure(GlobalDummyType(0, 1)) 1382*da0073e9SAndroid Build Coastguard Worker serialized_spec = cxx_pytree.treespec_dumps(spec) 1383*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec) 1384*da0073e9SAndroid Build Coastguard Worker self.assertEqual(roundtrip_spec, spec) 1385*da0073e9SAndroid Build Coastguard Worker 1386*da0073e9SAndroid Build Coastguard Worker class LocalDummyType: 1387*da0073e9SAndroid Build Coastguard Worker def __init__(self, x, y): 1388*da0073e9SAndroid Build Coastguard Worker self.x = x 1389*da0073e9SAndroid Build Coastguard Worker self.y = y 1390*da0073e9SAndroid Build Coastguard 
Worker 1391*da0073e9SAndroid Build Coastguard Worker cxx_pytree.register_pytree_node( 1392*da0073e9SAndroid Build Coastguard Worker LocalDummyType, 1393*da0073e9SAndroid Build Coastguard Worker lambda dummy: ([dummy.x, dummy.y], None), 1394*da0073e9SAndroid Build Coastguard Worker lambda xs, _: LocalDummyType(*xs), 1395*da0073e9SAndroid Build Coastguard Worker serialized_type_name="LocalDummyType", 1396*da0073e9SAndroid Build Coastguard Worker ) 1397*da0073e9SAndroid Build Coastguard Worker spec = cxx_pytree.tree_structure(LocalDummyType(0, 1)) 1398*da0073e9SAndroid Build Coastguard Worker serialized_spec = cxx_pytree.treespec_dumps(spec) 1399*da0073e9SAndroid Build Coastguard Worker roundtrip_spec = cxx_pytree.treespec_loads(serialized_spec) 1400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(roundtrip_spec, spec) 1401*da0073e9SAndroid Build Coastguard Worker 1402*da0073e9SAndroid Build Coastguard Worker 1403*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestGenericPytree) 1404*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestPythonPytree) 1405*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestCxxPytree) 1406*da0073e9SAndroid Build Coastguard Worker 1407*da0073e9SAndroid Build Coastguard Worker 1408*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 1409*da0073e9SAndroid Build Coastguard Worker run_tests() 1410