# xref: /aosp_15_r20/external/executorch/exir/serde/union.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport functools
10*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import fields
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Hashable, Set
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Worker
class _UnionTag(str):
    """A ``str`` subclass naming the active field of a ``_Union`` instance.

    A tag remembers the union class it belongs to (``_cls``). Comparing it
    (with ``==`` or ``!=``) against a string asserts that the string is one of
    that class's dataclass field names, so a misspelled tag in a comparison
    fails loudly instead of quietly evaluating to False.
    """

    # The union type whose dataclass field names are the valid tag values.
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        """Return a tag with string value ``t`` bound to union class ``cls``."""
        tag = _UnionTag(t)
        # ``_cls`` above is only an annotation (no class attribute), so a freshly
        # constructed tag must not have it yet.
        assert not hasattr(tag, "_cls")
        tag._cls = cls
        return tag

    def __eq__(self, cmp) -> bool:
        """Compare with a string, asserting that it names a field of ``_cls``.

        Non-string operands are not comparable: ``NotImplemented`` is returned
        so Python falls back to its default handling (``tag == None`` is False)
        instead of raising or coercing the operand with ``str()``.
        """
        if not isinstance(cmp, str):
            return NotImplemented
        other = str(cmp)
        assert other in _get_field_names(
            self._cls
        ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
        return str(self) == other

    def __ne__(self, cmp) -> bool:
        """Negation of ``__eq__``, applying the same tag validation.

        Without this override ``str.__ne__`` is inherited, so ``tag != "typo"``
        would bypass the field-name check that ``__eq__`` performs.
        """
        result = self.__eq__(cmp)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        """Hash exactly like the underlying plain string."""
        return hash(str(self))
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker
@functools.lru_cache(maxsize=None)
def _get_field_names(cls) -> Set[str]:
    """Collect the names of all dataclass fields declared on ``cls``.

    Results are memoized per class in an unbounded cache, so every call with
    the same class hands back the same set object; callers must not mutate it.
    ``dataclasses.fields`` raises TypeError if ``cls`` is not a dataclass.
    """
    names: Set[str] = set()
    for fld in fields(cls):
        names.add(fld.name)
    return names
39*523fa7a6SAndroid Build Coastguard Worker
40*523fa7a6SAndroid Build Coastguard Worker
class _Union:
    """Base class for tagged unions implemented as dataclasses.

    Subclasses must be dataclasses (``fields()`` is called on them); each field
    is one alternative of the union. Instances built with :meth:`create` have
    exactly one field set and every other field ``None``; reading one of those
    unset fields raises ``AttributeError``. ``type`` names the active field and
    ``value`` returns its contents.
    """

    # Tag naming the active field; assigned by ``create``, absent otherwise.
    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        """Build a union instance with exactly one field set.

        Args:
            **kwargs: a single ``field_name=value`` pair selecting the active
                alternative.

        Returns:
            An instance of ``cls`` whose other fields are ``None`` and whose
            ``type`` is ``field_name``.

        Raises:
            TypeError: if not exactly one keyword argument is given, or (from
                the ``cls`` constructor) if the name is not a field of ``cls``.
        """
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let two keywords through silently and turn
        # zero keywords into a bare StopIteration.
        if len(kwargs) != 1:
            raise TypeError(
                f"{cls.__name__}.create expects exactly one keyword argument, "
                f"got {len(kwargs)}: {sorted(kwargs)}"
            )
        (tag_name,) = kwargs
        # Default every field to None, then override the single chosen one.
        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})  # type: ignore[arg-type]
        obj._type = _UnionTag.create(tag_name, cls)
        return obj

    def __post_init__(self):
        """Reject subclasses whose field names shadow this base's members."""
        # A field called e.g. ``type`` would replace the ``type`` property on
        # the subclass and break tag lookups.
        reserved = ("type", "_type", "create", "value")
        assert not any(f.name in reserved for f in fields(self)), (  # type: ignore[arg-type, misc]
            f"{type(self).__name__} may not declare fields named any of {reserved}"
        )

    @property
    def type(self) -> str:
        """Name of the active field, as a ``_UnionTag``.

        Raises:
            RuntimeError: if the instance was not built via ``create`` (so no
                tag was ever assigned).
        """
        try:
            return self._type
        except AttributeError as e:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from e

    @property
    def value(self):
        """Contents of the active field."""
        return getattr(self, self.type)

    def __getattribute__(self, name):
        # Fields left at ``None`` are "not set": reading one that is not the
        # active tag raises AttributeError. A ``None`` stored in the active
        # field itself is still returned normally.
        attr = super().__getattribute__(name)
        if attr is None and name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
            raise AttributeError(f"Field {name} is not set.")
        return attr

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # Renders as ``ClassName(tag=value)``; only the active field is shown.
        return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"
78