# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from collections import namedtuple
from typing import Any, Dict

import torch

# @manual=//executorch/extension/pytree:pybindings
from executorch.extension.pytree import (
    broadcast_to_and_flatten,
    register_custom,
    tree_flatten,
    tree_map,
    tree_unflatten,
    TreeSpec,
)


# pyre-fixme[11]: Annotation `TreeSpec` is not defined as a type.
def _spec(o: Any) -> TreeSpec:
    """Return the TreeSpec produced by flattening ``o``, discarding the leaves."""
    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
    _, spec = tree_flatten(o)
    return spec


def _spec_str(type_char: str, n: int) -> str:
    """Build the spec string for a flat container holding ``n`` leaf children.

    ``type_char`` selects the container kind as used by the tests in this file:
    ``'T'`` for tuple, ``'L'`` for list, ``'N'`` for namedtuple. Each child
    contributes one ``#1`` (its leaf count) and one ``$`` (a leaf marker).

    e.g. ``('T', 3) -> 'T3#1#1#1($,$,$)'`` and ``('L', 0) -> 'L0()'``.
    """
    return f"{type_char}{n}" + "#1" * n + "(" + ",".join(["$"] * n) + ")"


# Constructs the string representation of a pytree spec for a dict whose keys are str or int and whose values are all leaves.
# e.g.: {'a': 1, 2: 2} -> D2#1#1('a':$,2:$)
def _spec_str_dict(d: Dict[Any, Any]) -> str:
    """Build the spec string for a flat dict whose values are all leaves.

    String keys are rendered single-quoted; any other key (e.g. an int) is
    rendered with ``str``. Entries follow the dict's iteration order, and an
    empty dict yields ``'D0()'``.
    """
    n = len(d)
    entries = ",".join(
        (f"'{key}'" if isinstance(key, str) else str(key)) + ":$" for key in d
    )
    return f"D{n}" + "#1" * n + f"({entries})"


class TestPytree(unittest.TestCase):
    def test(self):
        # Expected spec of `d` below; `#k` after a container header is the leaf
        # count of the corresponding child.
        SPEC = "D4#2#1#2#2('a':L2#1#1($,$),1:$,2:T2#1#1($,$),'str':D2#1#1('str':$,'str2':$))"
        d = {"a": [777, 1], 1: 4, 2: ("ta", 2), "str": {"str": 23, "str2": "47str"}}
        leaves, pytree = tree_flatten(d)
        self.assertEqual(leaves, [777, 1, 4, "ta", 2, 23, "47str"])
        self.assertEqual(pytree.to_str(), SPEC)

        # Replacement leaves alternate int / str so the round trip shows each
        # leaf lands in the right slot.
        leaves_test = [
            i + 13 if i % 2 == 0 else str(i + 13) for i in range(len(leaves))
        ]
        tree_test = pytree.tree_unflatten(leaves_test)
        self.assertEqual(
            tree_test,
            {"a": [13, "14"], 1: 15, 2: ("16", 17), "str": {"str": "18", "str2": 19}},
        )

        # Parsing the spec string and printing it back must round-trip.
        self.assertEqual(TreeSpec.from_str(SPEC).to_str(), SPEC)

    def test_extract_nested_list(self):
        _, pytree = tree_flatten((1, 2, [3, 4]))
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,L2#1#1($,$))")

    def test_extract_nested_dict(self):
        _, pytree = tree_flatten((1, 2, {3: 4, "str": 6}))
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,D2#1#1(3:$,'str':$))")

    def test_extracted_scalar(self):
        _, pytree = tree_flatten(4)
        self.assertEqual(pytree.to_str(), "$")

    def test_map(self):
        struct_map = tree_map(lambda x: 2 * x, (1, 2, [3, 4]))
        self.assertEqual(struct_map, (2, 4, [6, 8]))

    def test_treespec_equality(self):
        self.assertEqual(TreeSpec.from_str("$"), TreeSpec.from_str("$"))
        self.assertEqual(_spec([1]), TreeSpec.from_str("L1#1($)"))
        # Scalars are leaves: any two leaves share a spec, distinct from a list's.
        self.assertNotEqual(_spec(1), _spec([1]))
        self.assertEqual(_spec(1), _spec(2))
        # `(1)` is just the int 1; one-element tuples are needed to compare
        # a tuple container against a list container.
        self.assertNotEqual(_spec((1,)), _spec([1]))
        self.assertEqual(_spec((1,)), _spec((2,)))

    def test_flatten_unflatten_leaf(self):
        leaf_spec = TreeSpec.from_str("$")
        for leaf in (1, 1.0, None, bool):
            with self.subTest(leaf=leaf):
                values, treespec = tree_flatten(leaf)
                self.assertEqual(values, [leaf])
                self.assertEqual(treespec, leaf_spec)
                self.assertEqual(tree_unflatten(values, treespec), leaf)

    def test_flatten_unflatten_list(self):
        def run_test(lst):
            expected_spec = TreeSpec.from_str(_spec_str("L", len(lst)))
            values, treespec = tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    def test_flatten_unflatten_tuple(self):
        def run_test(tup):
            expected_spec = TreeSpec.from_str(_spec_str("T", len(tup)))
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    def test_flatten_unflatten_namedtuple(self):
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            expected_spec = TreeSpec.from_str(_spec_str("N", len(tup)))
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    def test_flatten_unflatten_torch_namedtuple_return_type(self):
        # torch.max(..., dim=...) returns a torch-specific namedtuple type, which
        # must survive a flatten/unflatten round trip.
        expected = torch.max(torch.randn(3, 3), dim=0)

        values, spec = tree_flatten(expected)
        result = tree_unflatten(values, spec)

        self.assertIs(type(result), type(expected))
        self.assertEqual(result, expected)

    def test_flatten_unflatten_dict(self):
        def run_test(d):
            values, treespec = tree_flatten(d)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(d.values()))
            self.assertEqual(treespec, TreeSpec.from_str(_spec_str_dict(d)))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, d)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    def test_flatten_unflatten_nested(self):
        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for tree in cases:
            with self.subTest(tree=tree):
                values, treespec = tree_flatten(tree)
                self.assertIsInstance(values, list)
                self.assertEqual(tree_unflatten(values, treespec), tree)

    def test_treemap(self):
        def f(x):
            return x * 3

        def invf(x):
            return x // 3

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for tree in cases:
            with self.subTest(tree=tree):
                leaves, _ = tree_flatten(tree)
                mapped = tree_map(f, tree)
                # Mapping then flattening must agree with flattening then mapping.
                self.assertEqual(tree_flatten(mapped)[0], [f(x) for x in leaves])
                # invf undoes f on these int leaves, so mapping it back must
                # restore the original tree (structure and values).
                self.assertEqual(tree_map(invf, mapped), tree)

    def test_treespec_repr(self):
        _, spec = tree_flatten((0, [0, 0, 0]))
        self.assertEqual(repr(spec), "T2#1#3($,L3#1#1#1($,$,$))")

    def test_custom_tree_node(self):
        class Point:
            def __init__(self, x, y, name):
                self.x = x
                self.y = y
                self.name = name

            def __repr__(self):
                return f"Point(x:{self.x}, y:{self.y}, name: {self.name})"

        def custom_flatten(p):
            # x and y are the children; the name is handed back to
            # custom_unflatten as extra data.
            return [p.x, p.y], p.name

        def custom_unflatten(children, extra_data):
            return Point(*children, extra_data)

        register_custom(Point, custom_flatten, custom_unflatten)

        point = Point((1.0, 1.0, 1), 2.0, "point_name")
        children, spec = tree_flatten(point)
        point2 = tree_unflatten(children, spec)
        self.assertEqual(str(point), str(point2))

    def test_broadcast_to_and_flatten(self):
        # Each case is (pytree, to_pytree, expected); `expected` is None when
        # `pytree` cannot be broadcast to the structure of `to_pytree`.
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            with self.subTest(pytree=pytree, to_pytree=to_pytree):
                _, to_spec = tree_flatten(to_pytree)
                result = broadcast_to_and_flatten(pytree, to_spec)
                self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))