1""" 2Contains utility functions for working with nested python data structures. 3 4A *pytree* is Python nested data structure. It is a tree in the sense that 5nodes are Python collections (e.g., list, tuple, dict) and the leaves are 6Python values. Furthermore, a pytree should not contain reference cycles. 7 8pytrees are useful for working with nested collections of Tensors. For example, 9one can use `tree_map` to map a function over all Tensors inside some nested 10collection of Tensors and `tree_leaves` to get a flat list of all Tensors 11inside some nested collection. pytrees are helpful for implementing nested 12collection support for PyTorch APIs. 13 14This pytree implementation is not very performant due to Python overhead 15To improve the performance we can move parts of the implementation to C++. 16""" 17 18import dataclasses 19import functools 20import importlib 21import json 22import sys 23import threading 24import types 25import warnings 26from collections import defaultdict, deque, namedtuple, OrderedDict 27from typing import ( 28 Any, 29 Callable, 30 cast, 31 DefaultDict, 32 Deque, 33 Dict, 34 FrozenSet, 35 Generic, 36 Hashable, 37 Iterable, 38 List, 39 Mapping, 40 NamedTuple, 41 Optional, 42 OrderedDict as GenericOrderedDict, 43 overload, 44 Protocol, 45 Sequence, 46 Tuple, 47 Type, 48 TypeVar, 49 Union, 50) 51from typing_extensions import deprecated 52 53 54__all__ = [ 55 "PyTree", 56 "Context", 57 "FlattenFunc", 58 "UnflattenFunc", 59 "DumpableContext", 60 "ToDumpableContextFn", 61 "FromDumpableContextFn", 62 "TreeSpec", 63 "LeafSpec", 64 "keystr", 65 "key_get", 66 "register_pytree_node", 67 "tree_flatten", 68 "tree_flatten_with_path", 69 "tree_unflatten", 70 "tree_iter", 71 "tree_leaves", 72 "tree_leaves_with_path", 73 "tree_structure", 74 "tree_map", 75 "tree_map_with_path", 76 "tree_map_", 77 "tree_map_only", 78 "tree_map_only_", 79 "tree_all", 80 "tree_any", 81 "tree_all_only", 82 "tree_any_only", 83 "treespec_dumps", 84 "treespec_loads", 85 "treespec_pprint", 86] 87 88 89T = TypeVar("T") 90S = TypeVar("S") 91U = TypeVar("U") 92R = TypeVar("R") 93 94 95DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 96NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" 97 98 99class KeyEntry(Protocol): 100 def __hash__(self) -> int: 101 ... 102 103 def __eq__(self, other: object) -> bool: 104 ... 105 106 def __str__(self) -> str: 107 ... 108 109 def get(self, parent: Any) -> Any: 110 ... 111 112 113Context = Any 114PyTree = Any 115FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] 116UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] 117DumpableContext = Any # Any json dumpable text 118ToDumpableContextFn = Callable[[Context], DumpableContext] 119FromDumpableContextFn = Callable[[DumpableContext], Context] 120ToStrFunc = Callable[["TreeSpec", List[str]], str] 121MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] 122KeyPath = Tuple[KeyEntry, ...] 123FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]] 124 125 126# A NodeDef holds two callables: 127# - flatten_fn should take the collection and return a flat list of values. 128# It can also return some context that is used in reconstructing the 129# collection. 130# - unflatten_fn should take a flat list of values and some context 131# (returned by flatten_fn). It returns the collection by reconstructing 132# it from the list and the context. 133# - flatten_with_keys_fn, which is a callable that takes a 134# pytree and returns a list of (keypath, value) pairs and a context. 135class NodeDef(NamedTuple): 136 type: Type[Any] 137 flatten_fn: FlattenFunc 138 unflatten_fn: UnflattenFunc 139 flatten_with_keys_fn: Optional[FlattenWithKeysFunc] 140 141 142_NODE_REGISTRY_LOCK = threading.Lock() 143SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} 144 145 146# _SerializeNodeDef holds the following: 147# - typ: the type of the node (e.g., "Dict", "List", etc) 148# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" 149# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the 150# context, and the version number 151# - from_dumpable_context takes in a string representation of the context, and the 152# version, and returns the deserialized context 153class _SerializeNodeDef(NamedTuple): 154 typ: Type[Any] 155 serialized_type_name: str 156 to_dumpable_context: Optional[ToDumpableContextFn] 157 from_dumpable_context: Optional[FromDumpableContextFn] 158 159 160SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} 161SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} 162 163# NB: we try really hard to not import _cxx_pytree (which depends on optree) 164# as much as possible. This is for isolation: a user who is not using C++ pytree 165# shouldn't pay for it, and it helps makes things like cpython upgrades easier. 166_cxx_pytree_exists = importlib.util.find_spec("optree") # type: ignore[attr-defined] 167_cxx_pytree_imported = False 168_cxx_pytree_pending_imports: List[Any] = [] 169 170 171def register_pytree_node( 172 cls: Type[Any], 173 flatten_fn: FlattenFunc, 174 unflatten_fn: UnflattenFunc, 175 *, 176 serialized_type_name: Optional[str] = None, 177 to_dumpable_context: Optional[ToDumpableContextFn] = None, 178 from_dumpable_context: Optional[FromDumpableContextFn] = None, 179 flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, 180) -> None: 181 """Register a container-like type as pytree node. 182 183 Args: 184 cls: the type to register 185 flatten_fn: A callable that takes a pytree and returns a flattened 186 representation of the pytree and additional context to represent the 187 flattened pytree. 188 unflatten_fn: A callable that takes a flattened version of the pytree, 189 additional context, and returns an unflattened pytree. 190 serialized_type_name: A keyword argument used to specify the fully qualified 191 name used when serializing the tree spec. 192 to_dumpable_context: An optional keyword argument to custom specify how 193 to convert the context of the pytree to a custom json dumpable 194 representation. This is used for json serialization, which is being 195 used in torch.export right now. 196 from_dumpable_context: An optional keyword argument to custom specify how 197 to convert the custom json dumpable representation of the context 198 back to the original context. This is used for json deserialization, 199 which is being used in torch.export right now. 200 flatten_with_keys_fn: An optional keyword argument to specify how to 201 access each pytree leaf's keypath when flattening and tree-mapping. 202 Like ``flatten_fn``, but in place of a List[leaf], it should return 203 a List[(keypath, leaf)]. 204 """ 205 with _NODE_REGISTRY_LOCK: 206 if cls in SUPPORTED_NODES: 207 raise ValueError(f"{cls} is already registered as pytree node.") 208 209 _private_register_pytree_node( 210 cls, 211 flatten_fn, 212 unflatten_fn, 213 serialized_type_name=serialized_type_name, 214 to_dumpable_context=to_dumpable_context, 215 from_dumpable_context=from_dumpable_context, 216 flatten_with_keys_fn=flatten_with_keys_fn, 217 ) 218 219 if not _cxx_pytree_exists: 220 return 221 222 if _cxx_pytree_imported: 223 from . import _cxx_pytree as cxx 224 225 cxx._private_register_pytree_node( 226 cls, 227 flatten_fn, 228 unflatten_fn, 229 serialized_type_name=serialized_type_name, 230 to_dumpable_context=to_dumpable_context, 231 from_dumpable_context=from_dumpable_context, 232 ) 233 else: 234 args = (cls, flatten_fn, unflatten_fn) 235 kwargs = { 236 "serialized_type_name": serialized_type_name, 237 "to_dumpable_context": to_dumpable_context, 238 "from_dumpable_context": from_dumpable_context, 239 } 240 _cxx_pytree_pending_imports.append((args, kwargs)) 241 242 243def _register_namedtuple( 244 cls: Type[Any], 245 *, 246 serialized_type_name: str, 247) -> None: 248 """ 249 Registers a namedtuple as a valid pytree node. By default namedtuples are 250 valid pytree nodes, but they are not serializable. This API provides the 251 argument `serialized_type_name` which allows these namedtuples to be 252 serialized. 253 254 Args: 255 cls: the dataclass type to register 256 serialized_type_name: The serialized name for the dataclass. This is 257 required if you want to serialize the pytree TreeSpec containing this 258 namedtuple. 259 """ 260 _private_register_pytree_node( 261 cls, 262 _namedtuple_flatten, 263 _namedtuple_unflatten, 264 serialized_type_name=serialized_type_name, 265 to_dumpable_context=_namedtuple_serialize, 266 from_dumpable_context=_namedtuple_deserialize, 267 flatten_with_keys_fn=_namedtuple_flatten_with_keys, 268 ) 269 270 271@deprecated( 272 "`torch.utils._pytree._register_pytree_node` is deprecated. " 273 "Please use `torch.utils._pytree.register_pytree_node` instead.", 274 category=FutureWarning, 275) 276def _register_pytree_node( 277 cls: Type[Any], 278 flatten_fn: FlattenFunc, 279 unflatten_fn: UnflattenFunc, 280 to_str_fn: Optional[ToStrFunc] = None, # deprecated 281 maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated 282 *, 283 serialized_type_name: Optional[str] = None, 284 to_dumpable_context: Optional[ToDumpableContextFn] = None, 285 from_dumpable_context: Optional[FromDumpableContextFn] = None, 286 flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, 287) -> None: 288 """Register a container-like type as pytree node for the Python pytree only. 289 290 Args: 291 cls: the type to register 292 flatten_fn: A callable that takes a pytree and returns a flattened 293 representation of the pytree and additional context to represent the 294 flattened pytree. 295 unflatten_fn: A callable that takes a flattened version of the pytree, 296 additional context, and returns an unflattened pytree. 297 serialized_type_name: A keyword argument used to specify the fully qualified 298 name used when serializing the tree spec. 299 to_dumpable_context: An optional keyword argument to custom specify how 300 to convert the context of the pytree to a custom json dumpable 301 representation. This is used for json serialization, which is being 302 used in torch.export right now. 303 from_dumpable_context: An optional keyword argument to custom specify how 304 to convert the custom json dumpable representation of the context 305 back to the original context. This is used for json deserialization, 306 which is being used in torch.export right now. 307 flatten_with_keys_fn: An optional keyword argument to specify how to 308 access each pytree leaf's keypath when flattening and tree-mapping. 309 Like ``flatten_fn``, but in place of a List[leaf], it should return 310 a List[(keypath, leaf)]. 311 """ 312 if to_str_fn is not None or maybe_from_str_fn is not None: 313 warnings.warn( 314 "`to_str_fn` and `maybe_from_str_fn` is deprecated. " 315 "Please use `to_dumpable_context` and `from_dumpable_context` instead.", 316 FutureWarning, 317 stacklevel=2, 318 ) 319 320 _private_register_pytree_node( 321 cls, 322 flatten_fn, 323 unflatten_fn, 324 serialized_type_name=serialized_type_name, 325 to_dumpable_context=to_dumpable_context, 326 from_dumpable_context=from_dumpable_context, 327 flatten_with_keys_fn=flatten_with_keys_fn, 328 ) 329 330 331def _private_register_pytree_node( 332 cls: Type[Any], 333 flatten_fn: FlattenFunc, 334 unflatten_fn: UnflattenFunc, 335 *, 336 serialized_type_name: Optional[str] = None, 337 to_dumpable_context: Optional[ToDumpableContextFn] = None, 338 from_dumpable_context: Optional[FromDumpableContextFn] = None, 339 flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, 340) -> None: 341 """This is an internal function that is used to register a pytree node type 342 for the Python pytree only. End-users should use :func:`register_pytree_node` 343 instead. 344 """ 345 with _NODE_REGISTRY_LOCK: 346 if cls in SUPPORTED_NODES: 347 # TODO: change this warning to an error after OSS/internal stabilize 348 warnings.warn( 349 f"{cls} is already registered as pytree node. " 350 "Overwriting the previous registration.", 351 ) 352 353 node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn) 354 SUPPORTED_NODES[cls] = node_def 355 356 if (to_dumpable_context is None) ^ (from_dumpable_context is None): 357 raise ValueError( 358 f"Both to_dumpable_context and from_dumpable_context for {cls} must " 359 "be None or registered." 360 ) 361 362 if serialized_type_name is None: 363 serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND 364 365 serialize_node_def = _SerializeNodeDef( 366 cls, 367 serialized_type_name, 368 to_dumpable_context, 369 from_dumpable_context, 370 ) 371 SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def 372 SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls 373 374 375@dataclasses.dataclass(frozen=True) 376class SequenceKey(Generic[T]): 377 idx: int 378 379 def __str__(self) -> str: 380 return f"[{self.idx!r}]" 381 382 def get(self, sequence: Sequence[T]) -> T: 383 return sequence[self.idx] 384 385 386K = TypeVar("K", bound=Hashable) 387 388 389@dataclasses.dataclass(frozen=True) 390class MappingKey(Generic[K, T]): 391 key: K 392 393 def __str__(self) -> str: 394 return f"[{self.key!r}]" 395 396 def get(self, mapping: Mapping[K, T]) -> T: 397 return mapping[self.key] 398 399 400@dataclasses.dataclass(frozen=True) 401class GetAttrKey: 402 name: str 403 404 def __str__(self) -> str: 405 return f".{self.name}" 406 407 def get(self, obj: Any) -> Any: 408 return getattr(obj, self.name) 409 410 411def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: 412 return list(d), None 413 414 415def _tuple_flatten_with_keys( 416 d: Tuple[Any, ...] 417) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 418 values, context = _tuple_flatten(d) 419 return [(SequenceKey(i), v) for i, v in enumerate(values)], context 420 421 422def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]: 423 return tuple(values) 424 425 426def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: 427 return d, None 428 429 430def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 431 values, context = _list_flatten(d) 432 return [(SequenceKey(i), v) for i, v in enumerate(values)], context 433 434 435def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]: 436 return list(values) 437 438 439def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: 440 return list(d.values()), list(d.keys()) 441 442 443def _dict_flatten_with_keys( 444 d: Dict[Any, Any] 445) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 446 values, context = _dict_flatten(d) 447 return [(MappingKey(k), v) for k, v in zip(context, values)], context 448 449 450def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]: 451 return dict(zip(context, values)) 452 453 454def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: 455 return list(d), type(d) 456 457 458def _namedtuple_flatten_with_keys( 459 d: NamedTuple, 460) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 461 values, context = _namedtuple_flatten(d) 462 return ( 463 [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], 464 context, 465 ) 466 467 468def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple: 469 return cast(NamedTuple, context(*values)) 470 471 472def _namedtuple_serialize(context: Context) -> DumpableContext: 473 if context not in SUPPORTED_SERIALIZED_TYPES: 474 raise NotImplementedError( 475 f"Can't serialize TreeSpec of namedtuple class {context} because we " 476 "didn't register a serializated_type_name. Please register using " 477 "`_register_namedtuple`." 478 ) 479 480 serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context] 481 serialized_type_name = serialize_node_def.serialized_type_name 482 483 if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: 484 raise NotImplementedError( 485 f"Can't serialize TreeSpec of namedtuple class {context} because we " 486 "couldn't find a serializated_type_name. Please register using " 487 "`_register_namedtuple`." 488 ) 489 return serialized_type_name 490 491 492def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: 493 if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE: 494 raise NotImplementedError( 495 f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} " 496 "because we couldn't find a serializated name." 497 ) 498 499 typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context] 500 return typ 501 502 503def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: 504 return list(d.values()), list(d.keys()) 505 506 507def _ordereddict_flatten_with_keys( 508 d: GenericOrderedDict[Any, Any] 509) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 510 values, context = _ordereddict_flatten(d) 511 return [(MappingKey(k), v) for k, v in zip(context, values)], context 512 513 514def _ordereddict_unflatten( 515 values: Iterable[Any], 516 context: Context, 517) -> GenericOrderedDict[Any, Any]: 518 return OrderedDict((key, value) for key, value in zip(context, values)) 519 520 521_odict_flatten = _ordereddict_flatten 522_odict_unflatten = _ordereddict_unflatten 523 524 525def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: 526 values, dict_context = _dict_flatten(d) 527 return values, [d.default_factory, dict_context] 528 529 530def _defaultdict_flatten_with_keys( 531 d: DefaultDict[Any, Any] 532) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 533 values, context = _defaultdict_flatten(d) 534 _, dict_context = context 535 return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context 536 537 538def _defaultdict_unflatten( 539 values: Iterable[Any], 540 context: Context, 541) -> DefaultDict[Any, Any]: 542 default_factory, dict_context = context 543 return defaultdict(default_factory, _dict_unflatten(values, dict_context)) 544 545 546def _defaultdict_serialize(context: Context) -> DumpableContext: 547 default_factory, dict_context = context 548 json_defaultdict = { 549 "default_factory_module": default_factory.__module__, 550 "default_factory_name": default_factory.__qualname__, 551 "dict_context": dict_context, 552 } 553 return json_defaultdict 554 555 556def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: 557 assert isinstance(dumpable_context, dict) 558 assert set(dumpable_context) == { 559 "default_factory_module", 560 "default_factory_name", 561 "dict_context", 562 } 563 564 default_factory_module = dumpable_context["default_factory_module"] 565 default_factory_name = dumpable_context["default_factory_name"] 566 assert isinstance(default_factory_module, str) 567 assert isinstance(default_factory_name, str) 568 module = importlib.import_module(default_factory_module) 569 default_factory = getattr(module, default_factory_name) 570 571 dict_context = dumpable_context["dict_context"] 572 return [default_factory, dict_context] 573 574 575def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]: 576 return list(d), d.maxlen 577 578 579def _deque_flatten_with_keys( 580 d: Deque[Any], 581) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: 582 values, context = _deque_flatten(d) 583 return [(SequenceKey(i), v) for i, v in enumerate(values)], context 584 585 586def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: 587 return deque(values, maxlen=context) 588 589 590_private_register_pytree_node( 591 tuple, 592 _tuple_flatten, 593 _tuple_unflatten, 594 serialized_type_name="builtins.tuple", 595 flatten_with_keys_fn=_tuple_flatten_with_keys, 596) 597_private_register_pytree_node( 598 list, 599 _list_flatten, 600 _list_unflatten, 601 serialized_type_name="builtins.list", 602 flatten_with_keys_fn=_list_flatten_with_keys, 603) 604_private_register_pytree_node( 605 dict, 606 _dict_flatten, 607 _dict_unflatten, 608 serialized_type_name="builtins.dict", 609 flatten_with_keys_fn=_dict_flatten_with_keys, 610) 611_private_register_pytree_node( 612 namedtuple, # type: ignore[arg-type] 613 _namedtuple_flatten, 614 _namedtuple_unflatten, 615 serialized_type_name="collections.namedtuple", 616 to_dumpable_context=_namedtuple_serialize, 617 from_dumpable_context=_namedtuple_deserialize, 618 flatten_with_keys_fn=_namedtuple_flatten_with_keys, 619) 620_private_register_pytree_node( 621 OrderedDict, 622 _ordereddict_flatten, 623 _ordereddict_unflatten, 624 serialized_type_name="collections.OrderedDict", 625 flatten_with_keys_fn=_ordereddict_flatten_with_keys, 626) 627_private_register_pytree_node( 628 defaultdict, 629 _defaultdict_flatten, 630 _defaultdict_unflatten, 631 serialized_type_name="collections.defaultdict", 632 to_dumpable_context=_defaultdict_serialize, 633 from_dumpable_context=_defaultdict_deserialize, 634 flatten_with_keys_fn=_defaultdict_flatten_with_keys, 635) 636_private_register_pytree_node( 637 deque, 638 _deque_flatten, 639 _deque_unflatten, 640 serialized_type_name="collections.deque", 641 flatten_with_keys_fn=_deque_flatten_with_keys, 642) 643 644 645STANDARD_DICT_TYPES: FrozenSet[type] = frozenset( 646 {dict, OrderedDict, defaultdict}, 647) 648BUILTIN_TYPES: FrozenSet[type] = frozenset( 649 {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type] 650) 651 652 653# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple 654def _is_namedtuple_instance(tree: Any) -> bool: 655 typ = type(tree) 656 bases = typ.__bases__ 657 if len(bases) != 1 or bases[0] != tuple: 658 return False 659 fields = getattr(typ, "_fields", None) 660 if not isinstance(fields, tuple): 661 return False 662 return all(type(entry) == str for entry in fields) 663 664 665def _get_node_type(tree: Any) -> Any: 666 if _is_namedtuple_instance(tree): 667 return namedtuple 668 return type(tree) 669 670 671# A leaf is defined as anything that is not a Node. 672def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: 673 return (is_leaf is not None and is_leaf(tree)) or _get_node_type( 674 tree 675 ) not in SUPPORTED_NODES 676 677 678# A TreeSpec represents the structure of a pytree. It holds: 679# "type": the type of root Node of the pytree 680# context: some context that is useful in unflattening the pytree 681# children_specs: specs for each child of the root Node 682# num_leaves: the number of leaves 683@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False) 684class TreeSpec: 685 type: Any 686 context: Context 687 children_specs: List["TreeSpec"] 688 689 num_nodes: int = dataclasses.field(init=False) 690 num_leaves: int = dataclasses.field(init=False) 691 num_children: int = dataclasses.field(init=False) 692 693 def __post_init__(self) -> None: 694 num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1) 695 num_leaves = sum(spec.num_leaves for spec in self.children_specs) 696 num_children = len(self.children_specs) 697 object.__setattr__(self, "num_nodes", num_nodes) 698 object.__setattr__(self, "num_leaves", num_leaves) 699 object.__setattr__(self, "num_children", num_children) 700 701 def __repr__(self, indent: int = 0) -> str: 702 repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" 703 children_specs_str: str = "" 704 if self.num_children > 0: 705 indent += 2 706 children_specs_str += self.children_specs[0].__repr__(indent) 707 children_specs_str += "," if self.num_children > 1 else "" 708 children_specs_str += ",".join( 709 [ 710 "\n" + " " * indent + child.__repr__(indent) 711 for child in self.children_specs[1:] 712 ] 713 ) 714 repr_suffix: str = f"{children_specs_str}])" 715 return repr_prefix + repr_suffix 716 717 def is_leaf(self) -> bool: 718 return self.num_nodes == 1 and self.num_leaves == 1 719 720 def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: 721 if self.is_leaf(): 722 subtrees.append(tree) 723 return 724 725 node_type = _get_node_type(tree) 726 if self.type not in BUILTIN_TYPES: 727 # Always require custom node types to match exactly 728 if node_type != self.type: 729 raise ValueError( 730 f"Type mismatch; " 731 f"expected {self.type!r}, but got {node_type!r}.", 732 ) 733 flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 734 child_pytrees, context = flatten_fn(tree) 735 if len(child_pytrees) != self.num_children: 736 raise ValueError( 737 f"Node arity mismatch; " 738 f"expected {self.num_children}, but got {len(child_pytrees)}.", 739 ) 740 if context != self.context: 741 raise ValueError( 742 f"Node context mismatch for custom node type {self.type!r}.", 743 ) 744 else: 745 # For builtin dictionary types, we allow some flexibility 746 # Otherwise, we require exact matches 747 both_standard_dict = ( 748 self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES 749 ) 750 if node_type != self.type and not both_standard_dict: 751 raise ValueError( 752 f"Node type mismatch; " 753 f"expected {self.type!r}, but got {node_type!r}.", 754 ) 755 if len(tree) != self.num_children: 756 raise ValueError( 757 f"Node arity mismatch; " 758 f"expected {self.num_children}, but got {len(tree)}.", 759 ) 760 761 if both_standard_dict: # dictionary types are compatible with each other 762 dict_context = ( 763 self.context 764 if self.type is not defaultdict 765 # ignore mismatch of `default_factory` for defaultdict 766 else self.context[1] 767 ) 768 expected_keys = dict_context 769 got_key_set = set(tree) 770 expected_key_set = set(expected_keys) 771 if got_key_set != expected_key_set: 772 missing_keys = expected_key_set.difference(got_key_set) 773 extra_keys = got_key_set.difference(expected_key_set) 774 message = "" 775 if missing_keys: 776 message += f"; missing key(s): {missing_keys}" 777 if extra_keys: 778 message += f"; extra key(s): {extra_keys}" 779 raise ValueError(f"Node keys mismatch{message}.") 780 child_pytrees = [tree[key] for key in expected_keys] 781 else: 782 flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 783 child_pytrees, context = flatten_fn(tree) 784 if ( 785 context != self.context 786 and self.type is not deque # ignore mismatch of `maxlen` for deque 787 ): 788 raise ValueError( 789 f"Node context mismatch for node type {self.type!r}; " 790 f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch 791 ) 792 793 for child_pytree, child_spec in zip(child_pytrees, self.children_specs): 794 child_spec._flatten_up_to_helper(child_pytree, subtrees) 795 796 def flatten_up_to(self, tree: PyTree) -> List[PyTree]: 797 subtrees: List[PyTree] = [] 798 self._flatten_up_to_helper(tree, subtrees) 799 return subtrees 800 801 def unflatten(self, leaves: Iterable[Any]) -> PyTree: 802 if not isinstance(leaves, (list, tuple)): 803 leaves = list(leaves) 804 if len(leaves) != self.num_leaves: 805 raise ValueError( 806 f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " 807 f"but the spec refers to a pytree that holds {self.num_leaves} " 808 f"items ({self}).", 809 ) 810 if self.is_leaf(): 811 return leaves[0] 812 813 unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn 814 815 # Recursively unflatten the children 816 start = 0 817 end = 0 818 child_pytrees = [] 819 for child_spec in self.children_specs: 820 end += child_spec.num_leaves 821 child_pytrees.append(child_spec.unflatten(leaves[start:end])) 822 start = end 823 824 return unflatten_fn(child_pytrees, self.context) 825 826 827class LeafSpec(TreeSpec): 828 def __init__(self) -> None: 829 super().__init__(None, None, []) 830 831 def __post_init__(self) -> None: 832 object.__setattr__(self, "num_nodes", 1) 833 object.__setattr__(self, "num_leaves", 1) 834 object.__setattr__(self, "num_children", 0) 835 836 def __repr__(self, indent: int = 0) -> str: 837 return "*" 838 839 840# All leaves are equivalent, so represent with a single object to save on 841# object construction time 842_LEAF_SPEC = LeafSpec() 843 844 845def _tree_flatten_helper( 846 tree: PyTree, 847 leaves: List[Any], 848 is_leaf: Optional[Callable[[PyTree], bool]] = None, 849) -> TreeSpec: 850 if _is_leaf(tree, is_leaf=is_leaf): 851 leaves.append(tree) 852 return _LEAF_SPEC 853 854 node_type = _get_node_type(tree) 855 flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 856 child_pytrees, context = flatten_fn(tree) 857 858 # Recursively flatten the children 859 children_specs = [ 860 _tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees 861 ] 862 863 return TreeSpec(node_type, context, children_specs) 864 865 866def tree_flatten( 867 tree: PyTree, 868 is_leaf: Optional[Callable[[PyTree], bool]] = None, 869) -> Tuple[List[Any], TreeSpec]: 870 """Flattens a pytree into a list of values and a TreeSpec that can be used 871 to reconstruct the pytree. 872 """ 873 leaves: List[Any] = [] 874 spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf) 875 return leaves, spec 876 877 878def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: 879 """Given a list of values and a TreeSpec, builds a pytree. 880 This is the inverse operation of `tree_flatten`. 881 """ 882 if not isinstance(treespec, TreeSpec): 883 raise TypeError( 884 f"tree_unflatten(leaves, treespec): Expected `treespec` to be " 885 f"instance of TreeSpec but got item of type {type(treespec)}.", 886 ) 887 return treespec.unflatten(leaves) 888 889 890def tree_iter( 891 tree: PyTree, 892 is_leaf: Optional[Callable[[PyTree], bool]] = None, 893) -> Iterable[Any]: 894 """Get an iterator over the leaves of a pytree.""" 895 if _is_leaf(tree, is_leaf=is_leaf): 896 yield tree 897 else: 898 node_type = _get_node_type(tree) 899 flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 900 child_pytrees, _ = flatten_fn(tree) 901 902 # Recursively flatten the children 903 for child in child_pytrees: 904 yield from tree_iter(child, is_leaf=is_leaf) 905 906 907def tree_leaves( 908 tree: PyTree, 909 is_leaf: Optional[Callable[[PyTree], bool]] = None, 910) -> List[Any]: 911 """Get a list of leaves of a pytree.""" 912 return list(tree_iter(tree, is_leaf=is_leaf)) 913 914 915def tree_structure( 916 tree: PyTree, 917 is_leaf: Optional[Callable[[PyTree], bool]] = None, 918) -> TreeSpec: 919 """Get the TreeSpec for a pytree.""" 920 return tree_flatten(tree, is_leaf=is_leaf)[1] 921 922 923def tree_map( 924 func: Callable[..., Any], 925 tree: PyTree, 926 *rests: PyTree, 927 is_leaf: Optional[Callable[[PyTree], bool]] = None, 928) -> PyTree: 929 """Map a multi-input function over pytree args to produce a new pytree. 930 931 See also :func:`tree_map_`. 932 933 >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) 934 {'x': 8, 'y': (43, 65)} 935 >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) 936 {'x': False, 'y': (False, False), 'z': True} 937 938 If multiple inputs are given, the structure of the tree is taken from the first input; 939 subsequent inputs need only have ``tree`` as a prefix: 940 941 >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) 942 [[5, 7, 9], [6, 1, 2]] 943 944 Args: 945 func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the 946 corresponding leaves of the pytrees. 947 tree (pytree): A pytree to be mapped over, with each leaf providing the first positional 948 argument to function ``func``. 949 rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as 950 ``tree`` or has ``tree`` as a prefix. 951 is_leaf (callable, optional): An extra leaf predicate function that will be called at each 952 flattening step. The function should have a single argument with signature 953 ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated 954 as a leaf. Otherwise, the default pytree registry will be used to determine a node is a 955 leaf or not. If the function is not specified, the default pytree registry will be used. 956 957 Returns: 958 A new pytree with the same structure as ``tree`` but with the value at each leaf given by 959 ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` 960 is the tuple of values at corresponding nodes in ``rests``. 961 """ 962 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) 963 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] 964 return treespec.unflatten(map(func, *flat_args)) 965 966 967def tree_map_( 968 func: Callable[..., Any], 969 tree: PyTree, 970 *rests: PyTree, 971 is_leaf: Optional[Callable[[PyTree], bool]] = None, 972) -> PyTree: 973 """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. 974 975 See also :func:`tree_map`. 976 977 Args: 978 func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the 979 corresponding leaves of the pytrees. 980 tree (pytree): A pytree to be mapped over, with each leaf providing the first positional 981 argument to function ``func``. 982 rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as 983 ``tree`` or has ``tree`` as a prefix. 984 is_leaf (callable, optional): An extra leaf predicate function that will be called at each 985 flattening step. The function should have a single argument with signature 986 ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated 987 as a leaf. Otherwise, the default pytree registry will be used to determine a node is a 988 leaf or not. If the function is not specified, the default pytree registry will be used. 989 990 Returns: 991 The original ``tree`` with the value at each leaf is given by the side-effect of function 992 ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf 993 in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. 994 """ 995 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) 996 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] 997 tuple(map(func, *flat_args)) # consume and exhaust the iterable 998 return tree 999 1000 1001Type2 = Tuple[Type[T], Type[S]] 1002Type3 = Tuple[Type[T], Type[S], Type[U]] 1003if sys.version_info >= (3, 10): 1004 TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType] 1005else: 1006 TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] 1007 1008Fn2 = Callable[[Union[T, S]], R] 1009Fn3 = Callable[[Union[T, S, U]], R] 1010Fn = Callable[[T], R] 1011FnAny = Callable[[Any], R] 1012 1013MapOnlyFn = Callable[[T], Callable[[Any], Any]] 1014 1015 1016# These specializations help with type inference on the lambda passed to this 1017# function 1018@overload 1019def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: 1020 ... 1021 1022 1023@overload 1024def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: 1025 ... 1026 1027 1028@overload 1029def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: 1030 ... 1031 1032 1033# This specialization is needed for the implementations below that call 1034@overload 1035def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: 1036 ... 1037 1038 1039@overload 1040def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: 1041 ... 1042 1043 1044def map_only( 1045 __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]] 1046) -> MapOnlyFn[FnAny[Any]]: 1047 """ 1048 Suppose you are writing a tree_map over tensors, leaving everything 1049 else unchanged. Ordinarily you would have to write: 1050 1051 def go(t): 1052 if isinstance(t, Tensor): 1053 return ... 1054 else: 1055 return t 1056 1057 With this function, you only need to write: 1058 1059 @map_only(Tensor) 1060 def go(t): 1061 return ... 1062 1063 You can also directly use 'tree_map_only' 1064 """ 1065 if isinstance(__type_or_types_or_pred, (type, tuple)) or ( 1066 sys.version_info >= (3, 10) 1067 and isinstance(__type_or_types_or_pred, types.UnionType) 1068 ): 1069 1070 def pred(x: Any) -> bool: 1071 return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type] 1072 1073 elif callable(__type_or_types_or_pred): 1074 pred = __type_or_types_or_pred # type: ignore[assignment] 1075 else: 1076 raise TypeError("Argument must be a type, a tuple of types, or a callable.") 1077 1078 def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: 1079 @functools.wraps(func) 1080 def wrapped(x: T) -> Any: 1081 if pred(x): 1082 return func(x) 1083 return x 1084 1085 return wrapped 1086 1087 return wrapper 1088 1089 1090@overload 1091def tree_map_only( 1092 __type_or_types_or_pred: Type[T], 1093 func: Fn[T, Any], 1094 tree: PyTree, 1095 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1096) -> PyTree: 1097 ... 1098 1099 1100@overload 1101def tree_map_only( 1102 __type_or_types_or_pred: Type2[T, S], 1103 func: Fn2[T, S, Any], 1104 tree: PyTree, 1105 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1106) -> PyTree: 1107 ... 1108 1109 1110@overload 1111def tree_map_only( 1112 __type_or_types_or_pred: Type3[T, S, U], 1113 func: Fn3[T, S, U, Any], 1114 tree: PyTree, 1115 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1116) -> PyTree: 1117 ... 1118 1119 1120@overload 1121def tree_map_only( 1122 __type_or_types_or_pred: Callable[[Any], bool], 1123 func: FnAny[Any], 1124 tree: PyTree, 1125 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1126) -> PyTree: 1127 ... 1128 1129 1130def tree_map_only( 1131 __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], 1132 func: FnAny[Any], 1133 tree: PyTree, 1134 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1135) -> PyTree: 1136 return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) 1137 1138 1139@overload 1140def tree_map_only_( 1141 __type_or_types_or_pred: Type[T], 1142 func: Fn[T, Any], 1143 tree: PyTree, 1144 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1145) -> PyTree: 1146 ... 1147 1148 1149@overload 1150def tree_map_only_( 1151 __type_or_types_or_pred: Type2[T, S], 1152 func: Fn2[T, S, Any], 1153 tree: PyTree, 1154 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1155) -> PyTree: 1156 ... 1157 1158 1159@overload 1160def tree_map_only_( 1161 __type_or_types_or_pred: Type3[T, S, U], 1162 func: Fn3[T, S, U, Any], 1163 tree: PyTree, 1164 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1165) -> PyTree: 1166 ... 1167 1168 1169@overload 1170def tree_map_only_( 1171 __type_or_types_or_pred: Callable[[Any], bool], 1172 func: FnAny[Any], 1173 tree: PyTree, 1174 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1175) -> PyTree: 1176 ... 1177 1178 1179def tree_map_only_( 1180 __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], 1181 func: FnAny[Any], 1182 tree: PyTree, 1183 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1184) -> PyTree: 1185 return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) 1186 1187 1188def tree_all( 1189 pred: Callable[[Any], bool], 1190 tree: PyTree, 1191 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1192) -> bool: 1193 flat_args = tree_iter(tree, is_leaf=is_leaf) 1194 return all(map(pred, flat_args)) 1195 1196 1197def tree_any( 1198 pred: Callable[[Any], bool], 1199 tree: PyTree, 1200 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1201) -> bool: 1202 flat_args = tree_iter(tree, is_leaf=is_leaf) 1203 return any(map(pred, flat_args)) 1204 1205 1206@overload 1207def tree_all_only( 1208 __type_or_types: Type[T], 1209 pred: Fn[T, bool], 1210 tree: PyTree, 1211 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1212) -> bool: 1213 ... 1214 1215 1216@overload 1217def tree_all_only( 1218 __type_or_types: Type2[T, S], 1219 pred: Fn2[T, S, bool], 1220 tree: PyTree, 1221 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1222) -> bool: 1223 ... 1224 1225 1226@overload 1227def tree_all_only( 1228 __type_or_types: Type3[T, S, U], 1229 pred: Fn3[T, S, U, bool], 1230 tree: PyTree, 1231 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1232) -> bool: 1233 ... 1234 1235 1236def tree_all_only( 1237 __type_or_types: TypeAny, 1238 pred: FnAny[bool], 1239 tree: PyTree, 1240 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1241) -> bool: 1242 flat_args = tree_iter(tree, is_leaf=is_leaf) 1243 return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) 1244 1245 1246@overload 1247def tree_any_only( 1248 __type_or_types: Type[T], 1249 pred: Fn[T, bool], 1250 tree: PyTree, 1251 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1252) -> bool: 1253 ... 1254 1255 1256@overload 1257def tree_any_only( 1258 __type_or_types: Type2[T, S], 1259 pred: Fn2[T, S, bool], 1260 tree: PyTree, 1261 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1262) -> bool: 1263 ... 1264 1265 1266@overload 1267def tree_any_only( 1268 __type_or_types: Type3[T, S, U], 1269 pred: Fn3[T, S, U, bool], 1270 tree: PyTree, 1271 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1272) -> bool: 1273 ... 1274 1275 1276def tree_any_only( 1277 __type_or_types: TypeAny, 1278 pred: FnAny[bool], 1279 tree: PyTree, 1280 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1281) -> bool: 1282 flat_args = tree_iter(tree, is_leaf=is_leaf) 1283 return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) 1284 1285 1286# Broadcasts a pytree to the provided TreeSpec and returns the flattened 1287# values. If this is not possible, then this function returns None. 1288# 1289# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), 1290# would return [0, 0]. This is useful for part of the vmap implementation: 1291# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be 1292# broadcastable to the tree structure of `inputs` and we use 1293# _broadcast_to_and_flatten to check this. 1294def _broadcast_to_and_flatten( 1295 tree: PyTree, 1296 treespec: TreeSpec, 1297 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1298) -> Optional[List[Any]]: 1299 assert isinstance(treespec, TreeSpec) 1300 1301 if _is_leaf(tree, is_leaf=is_leaf): 1302 return [tree] * treespec.num_leaves 1303 if treespec.is_leaf(): 1304 return None 1305 node_type = _get_node_type(tree) 1306 if node_type != treespec.type: 1307 return None 1308 1309 flatten_fn = SUPPORTED_NODES[node_type].flatten_fn 1310 child_pytrees, ctx = flatten_fn(tree) 1311 1312 # Check if the Node is different from the spec 1313 if len(child_pytrees) != treespec.num_children or ctx != treespec.context: 1314 return None 1315 1316 # Recursively flatten the children 1317 result: List[Any] = [] 1318 for child, child_spec in zip(child_pytrees, treespec.children_specs): 1319 flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) 1320 if flat is not None: 1321 result += flat 1322 else: 1323 return None 1324 1325 return result 1326 1327 1328@dataclasses.dataclass 1329class _TreeSpecSchema: 1330 """ 1331 _TreeSpecSchema is the schema used to serialize the TreeSpec 1332 It contains the following fields: 1333 - type: A string name of the type. null for the case of a LeafSpec. 1334 - context: Any format which is json dumpable 1335 - children_spec: A list of children serialized specs. 1336 """ 1337 1338 type: Optional[str] 1339 context: DumpableContext 1340 children_spec: List["_TreeSpecSchema"] 1341 1342 1343class _ProtocolFn(NamedTuple): 1344 treespec_to_json: Callable[[TreeSpec], DumpableContext] 1345 json_to_treespec: Callable[[DumpableContext], TreeSpec] 1346 1347 1348_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} 1349 1350 1351def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: 1352 if treespec.is_leaf(): 1353 return _TreeSpecSchema(None, None, []) 1354 1355 if treespec.type not in SUPPORTED_SERIALIZED_TYPES: 1356 raise NotImplementedError( 1357 f"Serializing {treespec.type} in pytree is not registered.", 1358 ) 1359 1360 serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] 1361 1362 serialized_type_name = serialize_node_def.serialized_type_name 1363 1364 if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: 1365 raise NotImplementedError( 1366 f"No registered serialization name for {treespec.type} found. " 1367 "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." 1368 ) 1369 1370 if serialize_node_def.to_dumpable_context is None: 1371 try: 1372 serialized_context = json.dumps(treespec.context) 1373 except TypeError as e: 1374 raise TypeError( 1375 "Unable to serialize context. " 1376 "Please make the context json dump-able, or register a " 1377 "custom serializer using _register_pytree_node." 1378 ) from e 1379 else: 1380 serialized_context = serialize_node_def.to_dumpable_context(treespec.context) 1381 1382 child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] 1383 1384 return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) 1385 1386 1387def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: 1388 if ( 1389 json_schema["type"] is None 1390 and json_schema["context"] is None 1391 and len(json_schema["children_spec"]) == 0 1392 ): 1393 return _LEAF_SPEC 1394 1395 if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: 1396 raise NotImplementedError( 1397 f'Deserializing {json_schema["type"]} in pytree is not registered.', 1398 ) 1399 1400 typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] 1401 serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] 1402 1403 if serialize_node_def.from_dumpable_context is None: 1404 try: 1405 context = json.loads(json_schema["context"]) 1406 except TypeError as ex: 1407 raise TypeError( 1408 "Unable to deserialize context. " 1409 "Please make the context json load-able, or register a " 1410 "custom serializer using _register_pytree_node.", 1411 ) from ex 1412 else: 1413 context = serialize_node_def.from_dumpable_context(json_schema["context"]) 1414 1415 children_specs = [] 1416 for child_string in json_schema["children_spec"]: 1417 children_specs.append(_json_to_treespec(child_string)) 1418 1419 return TreeSpec(typ, context, children_specs) 1420 1421 1422_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) 1423 1424 1425def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: 1426 if not isinstance(treespec, TreeSpec): 1427 raise TypeError( 1428 f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " 1429 f"TreeSpec but got item of type {type(treespec)}.", 1430 ) 1431 1432 if protocol is None: 1433 protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL 1434 1435 if protocol in _SUPPORTED_PROTOCOLS: 1436 json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) 1437 else: 1438 raise ValueError( 1439 f"Unknown protocol {protocol}. " 1440 f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", 1441 ) 1442 1443 str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) 1444 return str_spec 1445 1446 1447def treespec_loads(serialized: str) -> TreeSpec: 1448 protocol, json_schema = json.loads(serialized) 1449 1450 if protocol in _SUPPORTED_PROTOCOLS: 1451 return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) 1452 raise ValueError( 1453 f"Unknown protocol {protocol}. " 1454 f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", 1455 ) 1456 1457 1458class _DummyLeaf: 1459 def __repr__(self) -> str: 1460 return "*" 1461 1462 1463def treespec_pprint(treespec: TreeSpec) -> str: 1464 dummy_tree = tree_unflatten( 1465 [_DummyLeaf() for _ in range(treespec.num_leaves)], 1466 treespec, 1467 ) 1468 return repr(dummy_tree) 1469 1470 1471# TODO(angelayi): remove this function after OSS/internal stabilize 1472@deprecated( 1473 "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.", 1474 category=FutureWarning, 1475) 1476def pytree_to_str(treespec: TreeSpec) -> str: 1477 return treespec_dumps(treespec) 1478 1479 1480# TODO(angelayi): remove this function after OSS/internal stabilize 1481@deprecated( 1482 "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.", 1483 category=FutureWarning, 1484) 1485def str_to_pytree(json: str) -> TreeSpec: 1486 return treespec_loads(json) 1487 1488 1489def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: 1490 """Get a flat list of arguments to this function 1491 1492 A slightly faster version of tree_leaves((args, kwargs)) 1493 """ 1494 leaves: List[Any] = [] 1495 for a in args: 1496 leaves.extend(tree_iter(a)) 1497 for a in kwargs.values(): 1498 leaves.extend(tree_iter(a)) 1499 return leaves 1500 1501 1502def tree_flatten_with_path( 1503 tree: PyTree, 1504 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1505) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]: 1506 """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. 1507 1508 Args: 1509 tree: a pytree to flatten. If it contains a custom type, that type must be 1510 registered with an appropriate `tree_flatten_with_path_fn` when registered 1511 with :func:`register_pytree_node`. 1512 is_leaf: An extra leaf predicate function that will be called at each 1513 flattening step. The function should have a single argument with signature 1514 ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated 1515 as a leaf. Otherwise, the default pytree registry will be used to determine a node is a 1516 leaf or not. If the function is not specified, the default pytree registry will be used. 1517 Returns: 1518 A tuple where the first element is a list of (key path, leaf) pairs, and the 1519 second element is a :class:`TreeSpec` representing the structure of the flattened 1520 tree. 1521 """ 1522 _, treespec = tree_flatten(tree, is_leaf) 1523 return list(_generate_key_paths((), tree, is_leaf)), treespec 1524 1525 1526def tree_leaves_with_path( 1527 tree: PyTree, 1528 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1529) -> List[Tuple[KeyPath, Any]]: 1530 """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. 1531 1532 Args: 1533 tree: a pytree. If it contains a custom type, that type must be 1534 registered with an appropriate `tree_flatten_with_path_fn` when registered 1535 with :func:`register_pytree_node`. 1536 is_leaf: An extra leaf predicate function that will be called at each 1537 flattening step. The function should have a single argument with signature 1538 ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated 1539 as a leaf. Otherwise, the default pytree registry will be used to determine a node is a 1540 leaf or not. If the function is not specified, the default pytree registry will be used. 1541 Returns: 1542 A list of (key path, leaf) pairs. 1543 """ 1544 return list(_generate_key_paths((), tree, is_leaf)) 1545 1546 1547def _generate_key_paths( 1548 key_path: KeyPath, 1549 tree: PyTree, 1550 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1551) -> Iterable[Tuple[KeyPath, Any]]: 1552 if is_leaf and is_leaf(tree): 1553 yield key_path, tree 1554 return 1555 1556 node_type = _get_node_type(tree) 1557 handler = SUPPORTED_NODES.get(node_type) 1558 if not handler: 1559 # This is a leaf 1560 yield key_path, tree 1561 return 1562 1563 flatten_with_keys = handler.flatten_with_keys_fn 1564 if flatten_with_keys: 1565 key_children, _ = flatten_with_keys(tree) 1566 for k, c in key_children: 1567 yield from _generate_key_paths((*key_path, k), c, is_leaf) 1568 else: 1569 # We registered this pytree but didn't add a flatten_with_keys_fn, complain. 1570 raise ValueError( 1571 f"Did not find a flatten_with_keys_fn for type: {node_type}. " 1572 "Please pass a flatten_with_keys_fn argument to register_pytree_node." 1573 ) 1574 1575 1576def tree_map_with_path( 1577 func: Callable[..., Any], 1578 tree: PyTree, 1579 *rests: PyTree, 1580 is_leaf: Optional[Callable[[PyTree], bool]] = None, 1581) -> PyTree: 1582 """Like :func:`tree_map`, but the provided callable takes an additional key path argument. 1583 1584 Args: 1585 func: A function that takes ``2 + len(rests)`` arguments, to be applied at the 1586 corresponding leaves of the pytrees. The first positional argument 1587 to ``func`` is the key path of the leaf in question. The second 1588 positional argument is the value of the leaf. 1589 tree: A pytree to be mapped over, with each leaf providing the first positional 1590 argument to function ``func``. 1591 rests: A tuple of pytrees, each of which has the same structure as 1592 ``tree`` or has ``tree`` as a prefix. 1593 is_leaf: An extra leaf predicate function that will be called at each 1594 flattening step. The function should have a single argument with signature 1595 ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated 1596 as a leaf. Otherwise, the default pytree registry will be used to determine a node is a 1597 leaf or not. If the function is not specified, the default pytree registry will be used. 1598 1599 Returns 1600 A new pytree with the same structure as ``tree`` but with the value at each leaf given by 1601 ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the 1602 corresponding leaf in ``tree``, ``x`` is the value at that leaf, and 1603 ``xs`` is the tuple of values at corresponding nodes in ``rests``. 1604 """ 1605 keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) 1606 keypath_leaves = list(zip(*keypath_leaves)) 1607 all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] 1608 return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) 1609 1610 1611def keystr(kp: KeyPath) -> str: 1612 """Given a key path, return a pretty-printed representation.""" 1613 return "".join([str(k) for k in kp]) 1614 1615 1616def key_get(obj: Any, kp: KeyPath) -> Any: 1617 """Given an object and a key path, return the value at the key path.""" 1618 for k in kp: 1619 obj = k.get(obj) 1620 return obj 1621