xref: /aosp_15_r20/external/pytorch/torch/utils/_pytree.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""
Contains utility functions for working with nested Python data structures.

A *pytree* is a nested Python data structure. It is a tree in the sense that
nodes are Python collections (e.g., list, tuple, dict) and the leaves are
Python values. Furthermore, a pytree should not contain reference cycles.

pytrees are useful for working with nested collections of Tensors. For example,
one can use `tree_map` to map a function over all Tensors inside some nested
collection of Tensors and `tree_leaves` to get a flat list of all Tensors
inside some nested collection. pytrees are helpful for implementing nested
collection support for PyTorch APIs.

This pytree implementation is not very performant due to Python overhead.
To improve the performance we can move parts of the implementation to C++.
"""
17
import dataclasses
import functools
import importlib
import importlib.util
import json
import sys
import threading
import types
import warnings
from collections import defaultdict, deque, namedtuple, OrderedDict
from typing import (
    Any,
    Callable,
    cast,
    DefaultDict,
    Deque,
    Dict,
    FrozenSet,
    Generic,
    Hashable,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    OrderedDict as GenericOrderedDict,
    overload,
    Protocol,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
)
from typing_extensions import deprecated
52
53
# Names exported by a star-import of this module. Other module-level names
# (e.g. ``SUPPORTED_NODES``, ``SequenceKey``, the ``_``-prefixed helpers) are
# not listed here but can still be imported explicitly.
__all__ = [
    "PyTree",
    "Context",
    "FlattenFunc",
    "UnflattenFunc",
    "DumpableContext",
    "ToDumpableContextFn",
    "FromDumpableContextFn",
    "TreeSpec",
    "LeafSpec",
    "keystr",
    "key_get",
    "register_pytree_node",
    "tree_flatten",
    "tree_flatten_with_path",
    "tree_unflatten",
    "tree_iter",
    "tree_leaves",
    "tree_leaves_with_path",
    "tree_structure",
    "tree_map",
    "tree_map_with_path",
    "tree_map_",
    "tree_map_only",
    "tree_map_only_",
    "tree_all",
    "tree_any",
    "tree_all_only",
    "tree_any_only",
    "treespec_dumps",
    "treespec_loads",
    "treespec_pprint",
]
86]
87
88
# Generic type variables used in signatures throughout this module.
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
R = TypeVar("R")


# Serialization protocol version for TreeSpecs; presumably written by
# ``treespec_dumps`` further down the file (not visible here) -- confirm.
DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1
# Placeholder ``serialized_type_name`` for node types registered without one;
# ``_namedtuple_serialize`` refuses to serialize a type that carries it.
NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND"
97
98
class KeyEntry(Protocol):
    """Structural type for one step of a key path into a pytree.

    Concrete implementations in this module are ``SequenceKey``,
    ``MappingKey`` and ``GetAttrKey``: each is hashable, comparable, has a
    short string form and can fetch the child it names from its parent.
    """

    def __hash__(self) -> int:
        """Hash of this key entry."""

    def __eq__(self, other: object) -> bool:
        """Whether ``other`` is an equal key entry."""

    def __str__(self) -> str:
        """String fragment for this entry (e.g. ``[0]`` or ``.name``)."""

    def get(self, parent: Any) -> Any:
        """Return the child of ``parent`` addressed by this entry."""
111
112
# Type aliases for pytrees and the callables that take one node apart and put
# it back together.
Context = Any  # per-node data returned by a flatten fn and fed back to unflatten
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]  # node -> (children, context)
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]  # (children, context) -> node
DumpableContext = Any  # any JSON-dumpable value
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
ToStrFunc = Callable[["TreeSpec", List[str]], str]  # type of the deprecated `to_str_fn` argument
MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]]  # type of the deprecated `maybe_from_str_fn` argument
KeyPath = Tuple[KeyEntry, ...]  # sequence of key entries (presumably root-to-leaf order)
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]  # node -> ([(key entry, child)], context)
124
125
# A NodeDef describes how to take apart and rebuild one registered node type:
# - type: the registered Python type (the key it is stored under in
#   SUPPORTED_NODES).
# - flatten_fn takes the collection and returns a flat list of its children,
#   plus some context that is used in reconstructing the collection.
# - unflatten_fn takes a flat list of children and the context (returned by
#   flatten_fn) and rebuilds the collection from them.
# - flatten_with_keys_fn (optional, may be None) is like flatten_fn, but
#   returns a list of (key entry, child) pairs plus the context, so callers can
#   track where each child came from.
class NodeDef(NamedTuple):
    type: Type[Any]
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc]
140
141
# Held by the registration functions while they read/write the registries;
# lookups elsewhere in this file (e.g. _is_leaf) read SUPPORTED_NODES without it.
_NODE_REGISTRY_LOCK = threading.Lock()
# Maps each registered node type to the NodeDef used to (un)flatten it.
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
144
145
# _SerializeNodeDef holds the serialization metadata of a registered node type:
# - typ: the type of the node (e.g. dict, list, ...)
# - serialized_type_name: the fully qualified name written for the type, e.g.
#   "collections.OrderedDict", or NO_SERIALIZED_TYPE_NAME_FOUND when the type
#   was registered without one
# - to_dumpable_context: optional callable turning a node's context into a
#   JSON-dumpable value
# - from_dumpable_context: optional callable turning that JSON-dumpable value
#   back into the original context
# The registration code only accepts both callables set or both None.
class _SerializeNodeDef(NamedTuple):
    typ: Type[Any]
    serialized_type_name: str
    to_dumpable_context: Optional[ToDumpableContextFn]
    from_dumpable_context: Optional[FromDumpableContextFn]
158
159
# Serialization metadata, keyed by node type.
SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
# Reverse lookup from serialized type name to node type. Every type registered
# without a name is stored under NO_SERIALIZED_TYPE_NAME_FOUND, so that key only
# remembers the most recent such registration.
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}
162
# NB: we try really hard to not import _cxx_pytree (which depends on optree)
# as much as possible. This is for isolation: a user who is not using C++ pytree
# shouldn't pay for it, and it helps make things like cpython upgrades easier.
# find_spec only checks that the top-level "optree" package is installed; it
# does not import it. ``importlib.util`` is a submodule that a bare
# ``import importlib`` does not guarantee to load, so it is imported explicitly
# at the top of this file.
_cxx_pytree_exists = importlib.util.find_spec("optree")
# Whether _cxx_pytree has been imported yet (flipped elsewhere -- not visible
# here); register_pytree_node reads it to decide whether to defer.
_cxx_pytree_imported = False
# (args, kwargs) of register_pytree_node calls made before _cxx_pytree was
# imported; presumably replayed into the C++ registry on import -- confirm.
_cxx_pytree_pending_imports: List[Any] = []
169
170
def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node.

    The type is added to the Python pytree registry. If ``optree`` is
    installed, it is also registered with the C++ pytree: immediately when
    ``_cxx_pytree`` has already been imported, otherwise the call is queued in
    ``_cxx_pytree_pending_imports``.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to custom specify how
            to convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)]. Only the Python registry receives it.

    Raises:
        ValueError: if ``cls`` is already registered in the Python registry.
    """
    with _NODE_REGISTRY_LOCK:
        already_registered = cls in SUPPORTED_NODES
    if already_registered:
        raise ValueError(f"{cls} is already registered as pytree node.")

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )

    if not _cxx_pytree_exists:
        return

    # The C++ registration call does not take ``flatten_with_keys_fn``.
    cxx_args = (cls, flatten_fn, unflatten_fn)
    cxx_kwargs = {
        "serialized_type_name": serialized_type_name,
        "to_dumpable_context": to_dumpable_context,
        "from_dumpable_context": from_dumpable_context,
    }
    if not _cxx_pytree_imported:
        # Defer so that importing this module never drags in optree.
        _cxx_pytree_pending_imports.append((cxx_args, cxx_kwargs))
        return

    from . import _cxx_pytree as cxx

    cxx._private_register_pytree_node(*cxx_args, **cxx_kwargs)
241
242
def _register_namedtuple(
    cls: Type[Any],
    *,
    serialized_type_name: str,
) -> None:
    """
    Registers a namedtuple as a valid pytree node. By default namedtuples are
    valid pytree nodes, but they are not serializable. This API provides the
    argument `serialized_type_name` which allows these namedtuples to be
    serialized.

    Instances of ``cls`` are still flattened through the shared ``namedtuple``
    registration (``_get_node_type`` maps every namedtuple instance to
    ``namedtuple``). What this call adds is an entry for ``cls`` itself in
    ``SUPPORTED_SERIALIZED_TYPES``/``SERIALIZED_TYPE_TO_PYTHON_TYPE``, which
    ``_namedtuple_serialize``/``_namedtuple_deserialize`` look up. This
    registers with the Python pytree only.

    Args:
        cls: the namedtuple type to register
        serialized_type_name: The serialized name for the namedtuple. This is
            required if you want to serialize the pytree TreeSpec containing this
            namedtuple.
    """
    _private_register_pytree_node(
        cls,
        _namedtuple_flatten,
        _namedtuple_unflatten,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=_namedtuple_serialize,
        from_dumpable_context=_namedtuple_deserialize,
        flatten_with_keys_fn=_namedtuple_flatten_with_keys,
    )
269
270
@deprecated(
    "`torch.utils._pytree._register_pytree_node` is deprecated. "
    "Please use `torch.utils._pytree.register_pytree_node` instead.",
    category=FutureWarning,
)
def _register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    to_str_fn: Optional[ToStrFunc] = None,  # deprecated
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,  # deprecated
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node for the Python pytree only.

    Deprecated: every call emits a FutureWarning (via ``@deprecated``); use
    :func:`register_pytree_node` instead.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        to_str_fn: Deprecated and otherwise ignored; passing it only emits a
            FutureWarning. Use ``to_dumpable_context`` instead.
        maybe_from_str_fn: Deprecated and otherwise ignored; passing it only
            emits a FutureWarning. Use ``from_dumpable_context`` instead.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to custom specify how
            to convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)].
    """
    if to_str_fn is not None or maybe_from_str_fn is not None:
        warnings.warn(
            "`to_str_fn` and `maybe_from_str_fn` is deprecated. "
            "Please use `to_dumpable_context` and `from_dumpable_context` instead.",
            FutureWarning,
            # Stack: this function (1) -> the ``@deprecated`` wrapper (2) ->
            # the caller (3). Level 3 attributes the warning to the caller
            # rather than to the wrapper inside typing_extensions.
            stacklevel=3,
        )

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )
329
330
def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """This is an internal function that is used to register a pytree node type
    for the Python pytree only. End-users should use :func:`register_pytree_node`
    instead.

    Re-registering a type overwrites its previous entries and emits a warning.
    A type registered without ``serialized_type_name`` is recorded under
    ``NO_SERIALIZED_TYPE_NAME_FOUND``.

    Raises:
        ValueError: if exactly one of ``to_dumpable_context`` and
            ``from_dumpable_context`` is given. This is checked before any
            registry is modified, so a rejected call registers nothing.
    """
    # Validate before touching any registry. Otherwise a rejected call would
    # leave ``cls`` in SUPPORTED_NODES without matching serialization metadata.
    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    if serialized_type_name is None:
        serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND

    node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
    serialize_node_def = _SerializeNodeDef(
        cls,
        serialized_type_name,
        to_dumpable_context,
        from_dumpable_context,
    )

    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            # TODO: change this warning to an error after OSS/internal stabilize
            warnings.warn(
                f"{cls} is already registered as pytree node. "
                "Overwriting the previous registration.",
            )

        SUPPORTED_NODES[cls] = node_def
        SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
        SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls
373
374
@dataclasses.dataclass(frozen=True)
class SequenceKey(Generic[T]):
    """Key entry that addresses an element of a sequence by integer position."""

    idx: int

    def __str__(self) -> str:
        # Rendered as a subscript of the index's repr, e.g. ``[3]``.
        return "[{!r}]".format(self.idx)

    def get(self, sequence: Sequence[T]) -> T:
        """Return ``sequence[self.idx]``."""
        position = self.idx
        return sequence[position]
384
385
K = TypeVar("K", bound=Hashable)


@dataclasses.dataclass(frozen=True)
class MappingKey(Generic[K, T]):
    """Key entry that addresses a mapping value by its key."""

    key: K

    def __str__(self) -> str:
        # Rendered as a subscript of the key's repr, e.g. ``['a']``.
        return "[" + repr(self.key) + "]"

    def get(self, mapping: Mapping[K, T]) -> T:
        """Return ``mapping[self.key]``."""
        wanted = self.key
        return mapping[wanted]
398
399
@dataclasses.dataclass(frozen=True)
class GetAttrKey:
    """Key entry that addresses an attribute of an object by name."""

    name: str

    def __str__(self) -> str:
        # Rendered as attribute access, e.g. ``.weight``.
        return ".{}".format(self.name)

    def get(self, obj: Any) -> Any:
        """Return ``getattr(obj, self.name)``."""
        attr_name = self.name
        return getattr(obj, attr_name)
409
410
def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    """Split a tuple into a list of its elements; tuples need no context."""
    children = [*d]
    return children, None
413
414
def _tuple_flatten_with_keys(
    d: Tuple[Any, ...]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_tuple_flatten``, but pairs each element with its ``SequenceKey``."""
    children, context = _tuple_flatten(d)
    positions = map(SequenceKey, range(len(children)))
    return list(zip(positions, children)), context
420
421
def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]:
    """Rebuild a tuple from its children; ``context`` is ignored."""
    return (*values,)
424
425
def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    """Use the list itself as its children; lists need no context.

    The input list is returned as-is, not copied.
    """
    no_context = None
    return d, no_context
428
429
def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_list_flatten``, but pairs each element with its ``SequenceKey``."""
    children, context = _list_flatten(d)
    keyed = [(SequenceKey(pos), child) for pos, child in enumerate(children)]
    return keyed, context
433
434
def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]:
    """Rebuild a (new) list from its children; ``context`` is ignored."""
    return [*values]
437
438
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    """Split a dict into its values (children) and keys (context), in iteration order."""
    children = list(d.values())
    keys = list(d.keys())
    return children, keys
441
442
def _dict_flatten_with_keys(
    d: Dict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_dict_flatten``, but pairs each value with its ``MappingKey``."""
    children, keys = _dict_flatten(d)
    keyed = list(zip(map(MappingKey, keys), children))
    return keyed, keys
448
449
def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]:
    """Rebuild a dict by pairing the keys in ``context`` with ``values``."""
    return {key: value for key, value in zip(context, values)}
452
453
def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    """Split a namedtuple into its fields; the context is the namedtuple class."""
    children = [*d]
    return children, type(d)
456
457
def _namedtuple_flatten_with_keys(
    d: NamedTuple,
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_namedtuple_flatten``, but pairs each field value with a ``GetAttrKey``."""
    children, namedtuple_cls = _namedtuple_flatten(d)
    attr_keys = map(GetAttrKey, namedtuple_cls._fields)
    return list(zip(attr_keys, children)), namedtuple_cls
466
467
def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple:
    """Rebuild a namedtuple by calling its class (``context``) with the children."""
    rebuilt = context(*values)
    return cast(NamedTuple, rebuilt)
470
471
def _namedtuple_serialize(context: Context) -> DumpableContext:
    """Serialize a namedtuple node's context (its class) as its registered name.

    Raises:
        NotImplementedError: if the class has no serialization entry or was
            registered without a serialized type name.
    """
    serialize_node_def = SUPPORTED_SERIALIZED_TYPES.get(context)
    if serialize_node_def is None:
        raise NotImplementedError(
            f"Can't serialize TreeSpec of namedtuple class {context} because we "
            "didn't register a serializated_type_name. Please register using "
            "`_register_namedtuple`."
        )

    registered_name = serialize_node_def.serialized_type_name
    if registered_name != NO_SERIALIZED_TYPE_NAME_FOUND:
        return registered_name
    raise NotImplementedError(
        f"Can't serialize TreeSpec of namedtuple class {context} because we "
        "couldn't find a serializated_type_name. Please register using "
        "`_register_namedtuple`."
    )
490
491
def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
    """Map a serialized namedtuple name back to the registered class.

    Raises:
        NotImplementedError: if no type was registered under that name.
    """
    python_type = SERIALIZED_TYPE_TO_PYTHON_TYPE.get(dumpable_context)
    if python_type is None:
        raise NotImplementedError(
            f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
            "because we couldn't find a serializated name."
        )
    return python_type
501
502
def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]:
    """Split an OrderedDict into its values (children) and keys (context), in order."""
    ordered_values = list(d.values())
    ordered_keys = list(d.keys())
    return ordered_values, ordered_keys
505
506
def _ordereddict_flatten_with_keys(
    d: GenericOrderedDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_ordereddict_flatten``, but pairs each value with its ``MappingKey``."""
    children, keys = _ordereddict_flatten(d)
    keyed = [(MappingKey(name), child) for name, child in zip(keys, children)]
    return keyed, keys
512
513
def _ordereddict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> GenericOrderedDict[Any, Any]:
    """Rebuild an OrderedDict by pairing the keys in ``context`` with ``values``."""
    return OrderedDict(zip(context, values))
519
520
# Shorter aliases of the OrderedDict helpers; presumably kept for backward
# compatibility with code that imports the older names -- confirm.
_odict_flatten = _ordereddict_flatten
_odict_unflatten = _ordereddict_unflatten
523
524
def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]:
    """Flatten like a dict; the context is ``[default_factory, keys]``."""
    children, keys = _dict_flatten(d)
    factory = d.default_factory
    return children, [factory, keys]
528
529
def _defaultdict_flatten_with_keys(
    d: DefaultDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_defaultdict_flatten``, but pairs each value with its ``MappingKey``."""
    children, context = _defaultdict_flatten(d)
    keys = context[1]  # context is [default_factory, keys]
    keyed = list(zip(map(MappingKey, keys), children))
    return keyed, context
536
537
def _defaultdict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> DefaultDict[Any, Any]:
    """Rebuild a defaultdict from ``[default_factory, keys]`` and the values."""
    factory, keys = context
    as_plain_dict = _dict_unflatten(values, keys)
    return defaultdict(factory, as_plain_dict)
544
545
def _defaultdict_serialize(context: Context) -> DumpableContext:
    """Convert a defaultdict node context into a JSON-dumpable dict.

    Args:
        context: ``[default_factory, keys]`` as produced by ``_defaultdict_flatten``.

    Returns:
        A dict holding the factory's module and qualified name, plus the keys.
        Both name fields are ``None`` when the defaultdict has no factory.
    """
    default_factory, dict_context = context
    if default_factory is None:
        # ``defaultdict()`` without a factory is legal, but ``None`` has no
        # ``__module__``/``__qualname__``. Record the absence explicitly
        # instead of crashing with AttributeError.
        factory_module = None
        factory_name = None
    else:
        factory_module = default_factory.__module__
        factory_name = default_factory.__qualname__
    json_defaultdict = {
        "default_factory_module": factory_module,
        "default_factory_name": factory_name,
        "dict_context": dict_context,
    }
    return json_defaultdict
554
555
def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
    """Inverse of ``_defaultdict_serialize``: rebuild ``[default_factory, keys]``.

    A missing factory (both name fields ``None``) yields ``None``. A dotted
    qualified name such as ``"Outer.make"`` is resolved attribute by attribute,
    which a single ``getattr`` could not do.

    SECURITY NOTE: this imports an arbitrary module named in the input. Do not
    feed it serialized TreeSpecs from untrusted sources.
    """
    assert isinstance(dumpable_context, dict)
    assert set(dumpable_context) == {
        "default_factory_module",
        "default_factory_name",
        "dict_context",
    }

    default_factory_module = dumpable_context["default_factory_module"]
    default_factory_name = dumpable_context["default_factory_name"]
    dict_context = dumpable_context["dict_context"]

    if default_factory_module is None and default_factory_name is None:
        # Written for a defaultdict whose default_factory is None.
        return [None, dict_context]

    assert isinstance(default_factory_module, str)
    assert isinstance(default_factory_name, str)
    module = importlib.import_module(default_factory_module)
    # Walk the qualified name so nested attributes (e.g. "os" + "path.join")
    # resolve; a plain name resolves exactly as getattr(module, name).
    default_factory = functools.reduce(
        getattr, default_factory_name.split("."), module
    )

    return [default_factory, dict_context]
573
574
def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]:
    """Split a deque into its elements; the context is its ``maxlen``."""
    items = [*d]
    bound = d.maxlen
    return items, bound
577
578
def _deque_flatten_with_keys(
    d: Deque[Any],
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    """Like ``_deque_flatten``, but pairs each element with its ``SequenceKey``."""
    items, maxlen = _deque_flatten(d)
    keyed = list(zip(map(SequenceKey, range(len(items))), items))
    return keyed, maxlen
584
585
def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
    """Rebuild a deque from its elements, with ``maxlen`` taken from ``context``."""
    maxlen = context
    return deque(values, maxlen)
588
589
# Register the built-in container types at import time. Each call records how
# to flatten/unflatten the type and the name used for it when a TreeSpec is
# serialized.
_private_register_pytree_node(
    tuple,
    _tuple_flatten,
    _tuple_unflatten,
    serialized_type_name="builtins.tuple",
    flatten_with_keys_fn=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
    list,
    _list_flatten,
    _list_unflatten,
    serialized_type_name="builtins.list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
_private_register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
    serialized_type_name="builtins.dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
# All namedtuple classes share this one registration: the ``namedtuple``
# factory function is used as the registry key, and ``_get_node_type`` maps
# every namedtuple instance to it. The node context is the concrete class;
# serializing it also requires that class to be registered through
# ``_register_namedtuple``.
_private_register_pytree_node(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten,
    _namedtuple_unflatten,
    serialized_type_name="collections.namedtuple",
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
    flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
    OrderedDict,
    _ordereddict_flatten,
    _ordereddict_unflatten,
    serialized_type_name="collections.OrderedDict",
    flatten_with_keys_fn=_ordereddict_flatten_with_keys,
)
# defaultdict's context includes its default_factory (a callable), hence the
# custom (de)serializers.
_private_register_pytree_node(
    defaultdict,
    _defaultdict_flatten,
    _defaultdict_unflatten,
    serialized_type_name="collections.defaultdict",
    to_dumpable_context=_defaultdict_serialize,
    from_dumpable_context=_defaultdict_deserialize,
    flatten_with_keys_fn=_defaultdict_flatten_with_keys,
)
# deque's context is its maxlen (an int or None).
_private_register_pytree_node(
    deque,
    _deque_flatten,
    _deque_unflatten,
    serialized_type_name="collections.deque",
    flatten_with_keys_fn=_deque_flatten_with_keys,
)
643
644
# Dict-like node types that TreeSpec.flatten_up_to treats as interchangeable:
# only their key sets have to match, not their exact types.
STANDARD_DICT_TYPES: FrozenSet[type] = frozenset(
    {dict, OrderedDict, defaultdict},
)
# Node types registered by this module itself. flatten_up_to applies the
# relaxed builtin rules to these and requires an exact type and context match
# for every other (user-registered) type.
BUILTIN_TYPES: FrozenSet[type] = frozenset(
    {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque},  # type: ignore[arg-type]
)
651
652
# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(tree: Any) -> bool:
    """Heuristically decide whether ``tree`` is an instance of a namedtuple class.

    The class must derive directly and only from ``tuple`` and expose a
    ``_fields`` attribute that is a tuple whose items are exactly ``str``.
    """
    cls = type(tree)
    direct_bases = cls.__bases__
    if not (len(direct_bases) == 1 and direct_bases[0] == tuple):
        return False
    field_names = getattr(cls, "_fields", None)
    return isinstance(field_names, tuple) and all(
        type(name) == str for name in field_names
    )
663
664
def _get_node_type(tree: Any) -> Any:
    """Return the registry key for ``tree``.

    Every namedtuple instance maps to the ``namedtuple`` factory, so all
    namedtuple classes share one registration. Anything else maps to
    ``type(tree)``.
    """
    return namedtuple if _is_namedtuple_instance(tree) else type(tree)
669
670
# A leaf is defined as anything that is not a Node.
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
    """Decide whether ``tree`` is treated as a leaf.

    A truthy result from the optional ``is_leaf`` predicate is returned as-is.
    Otherwise ``tree`` is a leaf exactly when its node type is not registered
    in ``SUPPORTED_NODES``.
    """
    if is_leaf is not None:
        verdict = is_leaf(tree)
        if verdict:
            return verdict
    return _get_node_type(tree) not in SUPPORTED_NODES
676
677
# A TreeSpec represents the structure of a pytree. It holds:
# - type: the node type of the root (its SUPPORTED_NODES key; ``namedtuple``
#   for any namedtuple instance)
# - context: the context returned by the root's flatten_fn, handed back to its
#   unflatten_fn
# - children_specs: specs for each child of the root Node
# - num_nodes / num_leaves / num_children: counts derived in __post_init__
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
class TreeSpec:
    type: Any
    context: Context
    children_specs: List["TreeSpec"]

    # Derived from children_specs in __post_init__; not constructor arguments.
    num_nodes: int = dataclasses.field(init=False)
    num_leaves: int = dataclasses.field(init=False)
    num_children: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        # num_nodes: this node plus every node below it.
        # num_leaves: total leaves below; a spec with no children gets 0 here
        # (LeafSpec overrides __post_init__ to report exactly one leaf).
        num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
        num_leaves = sum(spec.num_leaves for spec in self.children_specs)
        num_children = len(self.children_specs)
        # The dataclass is frozen, so ordinary attribute assignment would raise.
        object.__setattr__(self, "num_nodes", num_nodes)
        object.__setattr__(self, "num_leaves", num_leaves)
        object.__setattr__(self, "num_children", num_children)

    def __repr__(self, indent: int = 0) -> str:
        # Recursive pretty-printer. The first child goes on the same line as
        # the opening bracket; each further child starts on a new line indented
        # by ``indent`` spaces, which grows by 2 per nesting level. E.g. a
        # two-leaf list renders as "TreeSpec(list, None, [*,\n  *])".
        repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
        children_specs_str: str = ""
        if self.num_children > 0:
            indent += 2
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += "," if self.num_children > 1 else ""
            children_specs_str += ",".join(
                [
                    "\n" + " " * indent + child.__repr__(indent)
                    for child in self.children_specs[1:]
                ]
            )
        repr_suffix: str = f"{children_specs_str}])"
        return repr_prefix + repr_suffix

    def is_leaf(self) -> bool:
        # Exactly one node that is itself the single leaf (e.g. LeafSpec). An
        # empty container spec has num_leaves == 0, so it is not a leaf.
        return self.num_nodes == 1 and self.num_leaves == 1

    def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
        """Walk ``tree`` alongside this spec and append to ``subtrees`` the part
        of ``tree`` found at each leaf position of the spec.

        Raises ValueError when ``tree`` does not match this spec's node type,
        arity, keys or context. The order of the checks below decides which
        mismatch gets reported.
        """
        if self.is_leaf():
            # The spec stops here: whatever ``tree`` holds (leaf or
            # container) is taken as one subtree.
            subtrees.append(tree)
            return

        node_type = _get_node_type(tree)
        if self.type not in BUILTIN_TYPES:
            # Always require custom node types to match exactly
            if node_type != self.type:
                raise ValueError(
                    f"Type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
            child_pytrees, context = flatten_fn(tree)
            if len(child_pytrees) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(child_pytrees)}.",
                )
            if context != self.context:
                raise ValueError(
                    f"Node context mismatch for custom node type {self.type!r}.",
                )
        else:
            # For builtin dictionary types, we allow some flexibility
            # Otherwise, we require exact matches
            both_standard_dict = (
                self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
            )
            if node_type != self.type and not both_standard_dict:
                raise ValueError(
                    f"Node type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            if len(tree) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(tree)}.",
                )

            if both_standard_dict:  # dictionary types are compatible with each other
                dict_context = (
                    self.context
                    if self.type is not defaultdict
                    # ignore mismatch of `default_factory` for defaultdict
                    else self.context[1]
                )
                expected_keys = dict_context
                # Compare key *sets*; children are then taken in the spec's
                # key order, not the tree's.
                got_key_set = set(tree)
                expected_key_set = set(expected_keys)
                if got_key_set != expected_key_set:
                    missing_keys = expected_key_set.difference(got_key_set)
                    extra_keys = got_key_set.difference(expected_key_set)
                    message = ""
                    if missing_keys:
                        message += f"; missing key(s): {missing_keys}"
                    if extra_keys:
                        message += f"; extra key(s): {extra_keys}"
                    raise ValueError(f"Node keys mismatch{message}.")
                child_pytrees = [tree[key] for key in expected_keys]
            else:
                flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                child_pytrees, context = flatten_fn(tree)
                if (
                    context != self.context
                    and self.type is not deque  # ignore mismatch of `maxlen` for deque
                ):
                    raise ValueError(
                        f"Node context mismatch for node type {self.type!r}; "
                        f"expected {self.context!r}, but got {context!r}.",  # namedtuple type mismatch
                    )

        for child_pytree, child_spec in zip(child_pytrees, self.children_specs):
            child_spec._flatten_up_to_helper(child_pytree, subtrees)

    def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
        """Flatten ``tree`` only as deep as this spec goes.

        Returns one entry per leaf of the spec: the (possibly non-leaf) subtree
        of ``tree`` at that position. Raises ValueError on a structure mismatch.
        """
        subtrees: List[PyTree] = []
        self._flatten_up_to_helper(tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this structure from ``leaves`` (in flatten order).

        Raises ValueError if the number of leaves differs from ``num_leaves``.
        """
        # Materialize arbitrary iterables so len() and slicing work below.
        if not isinstance(leaves, (list, tuple)):
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn

        # Recursively unflatten the children: each child consumes the next
        # ``child_spec.num_leaves`` leaves (each slice is a copy).
        start = 0
        end = 0
        child_pytrees = []
        for child_spec in self.children_specs:
            end += child_spec.num_leaves
            child_pytrees.append(child_spec.unflatten(leaves[start:end]))
            start = end

        return unflatten_fn(child_pytrees, self.context)
825
826
class LeafSpec(TreeSpec):
    """TreeSpec of a single leaf: no type, no context, no children.

    It counts as one node and one leaf, which is exactly what
    ``TreeSpec.is_leaf`` tests for.
    """

    def __init__(self) -> None:
        # The inherited dataclass __init__ stores the fields and then calls
        # the __post_init__ override below.
        super().__init__(None, None, [])

    def __post_init__(self) -> None:
        # Frozen dataclass: bypass the generated __setattr__ guard.
        for attr, value in (("num_nodes", 1), ("num_leaves", 1), ("num_children", 0)):
            object.__setattr__(self, attr, value)

    def __repr__(self, indent: int = 0) -> str:
        # Leaves print as a bare star inside the parent's TreeSpec repr.
        return "*"
838
839
# All leaves are equivalent, so represent with a single object to save on
# object construction time. _tree_flatten_helper returns this shared instance
# for every leaf it encounters.
_LEAF_SPEC = LeafSpec()
843
844
def _tree_flatten_helper(
    tree: PyTree,
    leaves: List[Any],
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Append the leaves of ``tree`` to ``leaves`` (depth-first, children in
    flatten_fn order) and return the TreeSpec describing its structure."""
    if not _is_leaf(tree, is_leaf=is_leaf):
        node_type = _get_node_type(tree)
        children, context = SUPPORTED_NODES[node_type].flatten_fn(tree)
        # Recurse child by child so ``leaves`` is filled left to right.
        specs = [
            _tree_flatten_helper(subtree, leaves, is_leaf=is_leaf)
            for subtree in children
        ]
        return TreeSpec(node_type, context, specs)

    leaves.append(tree)
    return _LEAF_SPEC
864
865
def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.

    Returns:
        A ``(leaves, treespec)`` pair; ``tree_unflatten(leaves, treespec)`` is the
        inverse operation.
    """
    collected: List[Any] = []
    treespec = _tree_flatten_helper(tree, collected, is_leaf=is_leaf)
    return collected, treespec
876
877
def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.

    Raises:
        TypeError: if ``treespec`` is not a :class:`TreeSpec` instance.
    """
    if isinstance(treespec, TreeSpec):
        return treespec.unflatten(leaves)
    raise TypeError(
        f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
        f"instance of TreeSpec but got item of type {type(treespec)}.",
    )
888
889
def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
    """Get an iterator over the leaves of a pytree.

    Leaves are produced lazily, in the order each node's ``flatten_fn`` returns
    its children; no TreeSpec is built.
    """
    if not _is_leaf(tree, is_leaf=is_leaf):
        node_type = _get_node_type(tree)
        children, _ = SUPPORTED_NODES[node_type].flatten_fn(tree)
        # Descend into each child and re-yield its leaves.
        for subtree in children:
            yield from tree_iter(subtree, is_leaf=is_leaf)
        return
    yield tree
905
906
def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Get a list of leaves of a pytree (an eager version of :func:`tree_iter`)."""
    leaf_iter = tree_iter(tree, is_leaf=is_leaf)
    return list(leaf_iter)
913
914
def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Get the TreeSpec for a pytree, discarding the flattened leaves."""
    _, treespec = tree_flatten(tree, is_leaf=is_leaf)
    return treespec
921
922
def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf. If the function is not specified, the default pytree registry will be
            used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    # Cut every extra tree at the leaf positions of ``tree`` so they line up.
    rest_leaves = [treespec.flatten_up_to(other) for other in rests]
    return treespec.unflatten(map(func, leaves, *rest_leaves))
965
966
def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. It is called only for its side effects; its
            return value is discarded.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf. If the function is not specified, the default pytree registry will be
            used.

    Returns:
        The original ``tree`` object itself (not a copy). ``func(x, *xs)`` is called once per
        leaf, in leaf order, where ``x`` is the value at the corresponding leaf in ``tree`` and
        ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # Drive the lazy ``map`` to completion purely for its side effects. A
    # zero-length deque discards each result as it is produced instead of
    # accumulating all of them in a throwaway tuple.
    deque(map(func, *flat_args), maxlen=0)
    return tree
998    return tree
999
1000
# Aliases for a 2- or 3-element tuple of classes, used by the typed overloads
# of map_only / tree_map_only / tree_all_only / tree_any_only below.
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
# Anything that can be passed as the class argument of isinstance(): a class,
# a tuple of classes, or (Python 3.10+ only) a PEP 604 union like ``int | str``.
# The version check is kept as an if/else statement so static type checkers
# can evaluate it.
if sys.version_info >= (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

# Callables taking a value of one of the listed types and returning R.
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

# Type of the decorator returned by map_only(...): it accepts a function of
# type T and returns a wrapper that accepts any value.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
1014
1015
# These specializations help with type inference on the lambda passed to this
# function.
@overload
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


# This specialization is needed for the generic implementations below (e.g.
# tree_map_only), which forward an arbitrary TypeAny to map_only.
@overload
def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


@overload
def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]
) -> MapOnlyFn[FnAny[Any]]:
    """
    Suppose you are writing a tree_map over tensors, leaving everything
    else unchanged.  Ordinarily you would have to write:

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    With this function, you only need to write:

        @map_only(Tensor)
        def go(t):
            return ...

    You can also directly use 'tree_map_only'

    The selector may be a type, a tuple of types, (on Python 3.10+) an
    ``X | Y`` union, or a predicate callable; anything else raises TypeError.
    Values the selector rejects are returned unchanged by the wrapped function.
    """
    selector = __type_or_types_or_pred
    by_type = isinstance(selector, (type, tuple))
    # types.UnionType only exists on 3.10+, so probe it behind the version check.
    if not by_type and sys.version_info >= (3, 10):
        by_type = isinstance(selector, types.UnionType)

    if by_type:

        def pred(x: Any) -> bool:
            return isinstance(x, selector)  # type: ignore[arg-type]

    elif callable(selector):
        pred = selector  # type: ignore[assignment]
    else:
        raise TypeError("Argument must be a type, a tuple of types, or a callable.")

    def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(func)
        def wrapped(x: T) -> Any:
            # Only selected values reach func; everything else passes through.
            return func(x) if pred(x) else x

        return wrapped

    return wrapper
1088
1089
@overload
def tree_map_only(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but ``func`` only touches leaves picked by the selector.

    The selector (a type, tuple of types, union, or predicate) is interpreted
    by :func:`map_only`; unselected leaves are kept as they are in the result.
    """
    selective_func = map_only(__type_or_types_or_pred)(func)
    return tree_map(selective_func, tree, is_leaf=is_leaf)
1137
1138
@overload
def tree_map_only_(
    __type_or_types_or_pred: Type[T],
    func: Fn[T, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type2[T, S],
    func: Fn2[T, S, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Type3[T, S, U],
    func: Fn3[T, S, U, Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


@overload
def tree_map_only_(
    __type_or_types_or_pred: Callable[[Any], bool],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    ...


def tree_map_only_(
    __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]],
    func: FnAny[Any],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map_`, but ``func`` is only called on leaves picked by the selector.

    The selector is interpreted by :func:`map_only`. Returns ``tree`` itself,
    as :func:`tree_map_` does.
    """
    selective_func = map_only(__type_or_types_or_pred)(func)
    return tree_map_(selective_func, tree, is_leaf=is_leaf)
1186
1187
def tree_all(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for every leaf of ``tree``.

    Stops at the first failing leaf; a tree with no leaves gives True.
    """
    return all(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
1195
1196
def tree_any(
    pred: Callable[[Any], bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` is truthy for at least one leaf of ``tree``.

    Stops at the first passing leaf; a tree with no leaves gives False.
    """
    return any(pred(leaf) for leaf in tree_iter(tree, is_leaf=is_leaf))
1204
1205
@overload
def tree_all_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_all_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_all_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` holds for every leaf that is an instance of the given type(s).

    Leaves of other types are ignored; returns True if no leaf matches.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, __type_or_types) and not pred(leaf):
            return False
    return True
1244
1245
@overload
def tree_any_only(
    __type_or_types: Type[T],
    pred: Fn[T, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type2[T, S],
    pred: Fn2[T, S, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


@overload
def tree_any_only(
    __type_or_types: Type3[T, S, U],
    pred: Fn3[T, S, U, bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    ...


def tree_any_only(
    __type_or_types: TypeAny,
    pred: FnAny[bool],
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> bool:
    """Return whether ``pred`` holds for some leaf that is an instance of the given type(s).

    Leaves of other types are ignored; returns False if no leaf matches.
    """
    for leaf in tree_iter(tree, is_leaf=is_leaf):
        if isinstance(leaf, __type_or_types) and pred(leaf):
            return True
    return False
1284
1285
def _broadcast_to_and_flatten(
    tree: PyTree,
    treespec: TreeSpec,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
    """Broadcast ``tree`` to ``treespec`` and return the flattened values, or None.

    For example, given ``tree=0`` and
    ``treespec=TreeSpec(list, None, [LeafSpec(), LeafSpec()])``, this function
    returns ``[0, 0]``. This is useful for part of the vmap implementation: a
    user can pass in ``vmap(fn, in_dims)(*inputs)``; ``in_dims`` should be
    broadcastable to the tree structure of ``inputs``, and this function is
    used to check that.

    A leaf of ``tree`` is repeated once per leaf of the matching subtree of
    ``treespec``. A non-leaf node must match the spec node's type, number of
    children and context exactly; otherwise None is returned.
    """
    assert isinstance(treespec, TreeSpec)

    if _is_leaf(tree, is_leaf=is_leaf):
        # A leaf broadcasts over every leaf position of this subtree.
        return [tree] * treespec.num_leaves
    if treespec.is_leaf():
        # ``tree`` has more structure than the spec allows.
        return None

    node_type = _get_node_type(tree)
    if node_type != treespec.type:
        return None

    children, context = SUPPORTED_NODES[node_type].flatten_fn(tree)
    if len(children) != treespec.num_children or context != treespec.context:
        return None

    flattened: List[Any] = []
    for subtree, subspec in zip(children, treespec.children_specs):
        piece = _broadcast_to_and_flatten(subtree, subspec, is_leaf=is_leaf)
        if piece is None:
            return None
        flattened.extend(piece)
    return flattened
1326
1327
@dataclasses.dataclass
class _TreeSpecSchema:
    """
    _TreeSpecSchema is the schema used to serialize the TreeSpec (protocol 1).
    It contains the following fields:
    - type: The registered serialized type name of the node; None (null in
      JSON) for a LeafSpec.
    - context: The node's context in a json-dumpable form.
    - children_spec: A list of the children's serialized specs.

    The field names are the JSON keys that _json_to_treespec reads back, and
    _treespec_to_json constructs instances positionally, so keep both the
    names and the order stable.
    """

    type: Optional[str]
    context: DumpableContext
    children_spec: List["_TreeSpecSchema"]
1341
1342
class _ProtocolFn(NamedTuple):
    """Serializer/deserializer pair registered for one treespec serialization protocol."""

    # TreeSpec -> serializable schema. For protocol 1 this is _treespec_to_json,
    # which returns a _TreeSpecSchema dataclass; treespec_dumps then passes the
    # result through dataclasses.asdict.
    treespec_to_json: Callable[[TreeSpec], DumpableContext]
    # Decoded JSON payload -> TreeSpec (the inverse of treespec_to_json).
    json_to_treespec: Callable[[DumpableContext], TreeSpec]


# Registry of protocol number -> _ProtocolFn, consulted by treespec_dumps and
# treespec_loads.
_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {}
1349
1350
def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema:
    """Convert ``treespec`` (recursively) into a protocol-1 ``_TreeSpecSchema``.

    Raises:
        NotImplementedError: if the node type is not registered for serialization,
            or was registered without a serialized type name.
        TypeError: if no ``to_dumpable_context`` is registered and the context
            cannot be passed to ``json.dumps``.
    """
    if treespec.is_leaf():
        return _TreeSpecSchema(None, None, [])

    if treespec.type not in SUPPORTED_SERIALIZED_TYPES:
        raise NotImplementedError(
            f"Serializing {treespec.type} in pytree is not registered.",
        )
    node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type]

    type_name = node_def.serialized_type_name
    if type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
        raise NotImplementedError(
            f"No registered serialization name for {treespec.type} found. "
            "Please update your _register_pytree_node call with a `serialized_type_name` kwarg."
        )

    to_dumpable = node_def.to_dumpable_context
    if to_dumpable is not None:
        context = to_dumpable(treespec.context)
    else:
        # Without a custom hook, the context must already be JSON-serializable.
        try:
            context = json.dumps(treespec.context)
        except TypeError as e:
            raise TypeError(
                "Unable to serialize context. "
                "Please make the context json dump-able, or register a "
                "custom serializer using _register_pytree_node."
            ) from e

    return _TreeSpecSchema(
        type_name,
        context,
        [_treespec_to_json(child) for child in treespec.children_specs],
    )
1385
1386
def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec:
    """Rebuild a TreeSpec from its protocol-1 JSON form (inverse of ``_treespec_to_json``).

    ``json_schema`` is indexed with the ``"type"``, ``"context"`` and
    ``"children_spec"`` keys of a serialized ``_TreeSpecSchema``.

    Raises:
        NotImplementedError: if the serialized type name is not registered.
        TypeError: if no ``from_dumpable_context`` is registered and
            ``json.loads`` rejects the stored context with a TypeError.
    """
    type_name = json_schema["type"]
    looks_like_leaf = (
        type_name is None
        and json_schema["context"] is None
        and len(json_schema["children_spec"]) == 0
    )
    if looks_like_leaf:
        return _LEAF_SPEC

    if type_name not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(
            f'Deserializing {json_schema["type"]} in pytree is not registered.',
        )

    python_type = SERIALIZED_TYPE_TO_PYTHON_TYPE[type_name]
    node_def = SUPPORTED_SERIALIZED_TYPES[python_type]

    from_dumpable = node_def.from_dumpable_context
    if from_dumpable is not None:
        context = from_dumpable(json_schema["context"])
    else:
        try:
            context = json.loads(json_schema["context"])
        except TypeError as ex:
            raise TypeError(
                "Unable to deserialize context. "
                "Please make the context json load-able, or register a "
                "custom serializer using _register_pytree_node.",
            ) from ex

    children_specs = [
        _json_to_treespec(child_schema)
        for child_schema in json_schema["children_spec"]
    ]
    return TreeSpec(python_type, context, children_specs)
1420
1421
# Protocol 1: TreeSpec <-> nested _TreeSpecSchema dictionaries.
_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(
    treespec_to_json=_treespec_to_json,
    json_to_treespec=_json_to_treespec,
)
1423
1424
def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str:
    """Serialize ``treespec`` to a JSON string readable by :func:`treespec_loads`.

    The output is the JSON array ``[protocol, schema]``. ``protocol`` defaults to
    ``DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL``.

    Raises:
        TypeError: if ``treespec`` is not a :class:`TreeSpec`.
        ValueError: if ``protocol`` is not a registered protocol.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of "
            f"TreeSpec but got item of type {type(treespec)}.",
        )

    protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL if protocol is None else protocol

    if protocol not in _SUPPORTED_PROTOCOLS:
        raise ValueError(
            f"Unknown protocol {protocol}. "
            f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
        )

    schema = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec)
    return json.dumps((protocol, dataclasses.asdict(schema)))
1445
1446
def treespec_loads(serialized: str) -> TreeSpec:
    """Parse a string produced by :func:`treespec_dumps` back into a TreeSpec.

    Raises:
        ValueError: if the protocol number embedded in ``serialized`` is not registered.
    """
    protocol, json_schema = json.loads(serialized)

    if protocol not in _SUPPORTED_PROTOCOLS:
        raise ValueError(
            f"Unknown protocol {protocol}. "
            f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}",
        )
    return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema)
1456
1457
1458class _DummyLeaf:
1459    def __repr__(self) -> str:
1460        return "*"
1461
1462
def treespec_pprint(treespec: TreeSpec) -> str:
    """Render ``treespec`` as the repr of a tree whose leaves all print as ``*``."""
    placeholders = [_DummyLeaf() for _ in range(treespec.num_leaves)]
    rendered_tree = tree_unflatten(placeholders, treespec)
    return repr(rendered_tree)
1469
1470
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`pytree_to_str` is deprecated. Please use `treespec_dumps` instead.",
    category=FutureWarning,
)
def pytree_to_str(treespec: TreeSpec) -> str:
    """Deprecated alias of :func:`treespec_dumps` using the default protocol."""
    return treespec_dumps(treespec, protocol=None)
1478
1479
# TODO(angelayi): remove this function after OSS/internal stabilize
@deprecated(
    "`str_to_pytree` is deprecated. Please use `treespec_loads` instead.",
    category=FutureWarning,
)
def str_to_pytree(json: str) -> TreeSpec:
    """Deprecated alias of :func:`treespec_loads`.

    The parameter name ``json`` is kept for backward compatibility even though
    it shadows the ``json`` module inside this function.
    """
    return treespec_loads(serialized=json)
1487
1488
def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]:
    """Get a flat list of arguments to this function

    A slightly faster version of tree_leaves((args, kwargs)). Leaves of the
    positional arguments come first, followed by those of the keyword argument
    values in their passed order.
    """
    leaves: List[Any] = []
    for subtree in (*args, *kwargs.values()):
        leaves.extend(tree_iter(subtree))
    return leaves
1500
1501
def tree_flatten_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
    """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.

    Args:
        tree: a pytree to flatten. If it contains a custom type, that type must be
            registered with an appropriate `tree_flatten_with_path_fn` when registered
            with :func:`register_pytree_node`.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether
            a node is a leaf. If the function is not specified, the default pytree registry will
            be used.
    Returns:
        A tuple where the first element is a list of (key path, leaf) pairs, and the
        second element is a :class:`TreeSpec` representing the structure of the flattened
        tree.
    """
    treespec = tree_flatten(tree, is_leaf)[1]
    path_leaf_pairs = list(_generate_key_paths((), tree, is_leaf))
    return path_leaf_pairs, treespec
1524
1525
def tree_leaves_with_path(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Tuple[KeyPath, Any]]:
    """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.

    Args:
        tree: a pytree. If it contains a custom type, that type must be
            registered with an appropriate `tree_flatten_with_path_fn` when registered
            with :func:`register_pytree_node`.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether
            a node is a leaf. If the function is not specified, the default pytree registry will
            be used.
    Returns:
        A list of (key path, leaf) pairs.
    """
    pairs = _generate_key_paths((), tree, is_leaf)
    return list(pairs)
1545
1546
def _generate_key_paths(
    key_path: KeyPath,
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Tuple[KeyPath, Any]]:
    """Lazily yield ``(key_path, leaf)`` pairs for every leaf under ``tree``.

    ``key_path`` is the path to ``tree`` itself; each child's key is appended as
    the traversal descends. A node is a leaf when ``is_leaf`` accepts it or when
    ``SUPPORTED_NODES`` has no handler for its node type.

    Raises:
        ValueError: if a registered node type has no ``flatten_with_keys_fn``.
    """
    if is_leaf and is_leaf(tree):
        yield key_path, tree
        return

    node_type = _get_node_type(tree)
    handler = SUPPORTED_NODES.get(node_type)
    if not handler:
        # Unregistered types are leaves.
        yield key_path, tree
        return

    flatten_with_keys = handler.flatten_with_keys_fn
    if not flatten_with_keys:
        # The type is registered but was given no flatten_with_keys_fn.
        raise ValueError(
            f"Did not find a flatten_with_keys_fn for type: {node_type}. "
            "Please pass a flatten_with_keys_fn argument to register_pytree_node."
        )

    keyed_children, _ = flatten_with_keys(tree)
    for key, child in keyed_children:
        yield from _generate_key_paths((*key_path, key), child, is_leaf)
1574
1575
def tree_map_with_path(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but the provided callable takes an additional key path argument.

    Args:
        func: A function that takes ``2 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees. The first positional argument
            to ``func`` is the key path of the leaf in question. The second
            positional argument is the value of the leaf.
        tree: A pytree to be mapped over, with each leaf providing the second positional
            argument to function ``func`` (the first is that leaf's key path).
        rests: A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf: An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether
            a node is a leaf. If the function is not specified, the default pytree registry will
            be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the
        corresponding leaf in ``tree``, ``x`` is the value at that leaf, and
        ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
    # Split the (key path, leaf) pairs into two parallel lists. Unlike
    # ``list(zip(*pairs))``, this always produces exactly two lists, even for a
    # tree with no leaves, and keeps each name bound to a single type.
    keypaths = [keypath for keypath, _ in keypath_leaves]
    leaves = [leaf for _, leaf in keypath_leaves]
    flat_rests = [treespec.flatten_up_to(r) for r in rests]
    return treespec.unflatten(map(func, keypaths, leaves, *flat_rests))
1609
1610
def keystr(kp: KeyPath) -> str:
    """Given a key path, return a pretty-printed representation.

    The result is the ``str()`` of each key concatenated with no separator.
    """
    return "".join(map(str, kp))
1614
1615
def key_get(obj: Any, kp: KeyPath) -> Any:
    """Given an object and a key path, return the value at the key path.

    Each key entry's ``get`` method is applied in turn to the current value,
    starting from ``obj``.
    """
    current = obj
    for entry in kp:
        current = entry.get(current)
    return current
1621