1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-ignore-all-errors 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerimport functools 10*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import fields 11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Hashable, Set 12*523fa7a6SAndroid Build Coastguard Worker 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Workerclass _UnionTag(str): 15*523fa7a6SAndroid Build Coastguard Worker _cls: Hashable 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker @staticmethod 18*523fa7a6SAndroid Build Coastguard Worker def create(t, cls): 19*523fa7a6SAndroid Build Coastguard Worker tag = _UnionTag(t) 20*523fa7a6SAndroid Build Coastguard Worker assert not hasattr(tag, "_cls") 21*523fa7a6SAndroid Build Coastguard Worker tag._cls = cls 22*523fa7a6SAndroid Build Coastguard Worker return tag 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker def __eq__(self, cmp) -> bool: 25*523fa7a6SAndroid Build Coastguard Worker assert isinstance(cmp, str) 26*523fa7a6SAndroid Build Coastguard Worker other = str(cmp) 27*523fa7a6SAndroid Build Coastguard Worker assert other in _get_field_names( 28*523fa7a6SAndroid Build Coastguard Worker self._cls 29*523fa7a6SAndroid Build Coastguard Worker ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}" 30*523fa7a6SAndroid Build Coastguard Worker return str(self) == other 31*523fa7a6SAndroid Build Coastguard Worker 32*523fa7a6SAndroid Build Coastguard Worker def __hash__(self): 33*523fa7a6SAndroid Build Coastguard Worker return hash(str(self)) 34*523fa7a6SAndroid Build Coastguard Worker 35*523fa7a6SAndroid Build Coastguard Worker 36*523fa7a6SAndroid Build Coastguard Worker@functools.lru_cache(maxsize=None) 37*523fa7a6SAndroid Build Coastguard Workerdef _get_field_names(cls) -> Set[str]: 38*523fa7a6SAndroid Build Coastguard Worker return {f.name for f in fields(cls)} 39*523fa7a6SAndroid Build Coastguard Worker 40*523fa7a6SAndroid Build Coastguard Worker 41*523fa7a6SAndroid Build Coastguard Workerclass _Union: 42*523fa7a6SAndroid Build Coastguard Worker _type: _UnionTag 43*523fa7a6SAndroid Build Coastguard Worker 44*523fa7a6SAndroid Build Coastguard Worker @classmethod 45*523fa7a6SAndroid Build Coastguard Worker def create(cls, **kwargs): 46*523fa7a6SAndroid Build Coastguard Worker assert len(kwargs) == 1 47*523fa7a6SAndroid Build Coastguard Worker obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs}) # type: ignore[arg-type] 48*523fa7a6SAndroid Build Coastguard Worker obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls) 49*523fa7a6SAndroid Build Coastguard Worker return obj 50*523fa7a6SAndroid Build Coastguard Worker 51*523fa7a6SAndroid Build Coastguard Worker def __post_init__(self): 52*523fa7a6SAndroid Build Coastguard Worker assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)) # type: ignore[arg-type, misc] 53*523fa7a6SAndroid Build Coastguard Worker 54*523fa7a6SAndroid Build Coastguard Worker @property 55*523fa7a6SAndroid Build Coastguard Worker def type(self) -> str: 56*523fa7a6SAndroid Build Coastguard Worker try: 57*523fa7a6SAndroid Build Coastguard Worker return self._type 58*523fa7a6SAndroid Build Coastguard Worker except AttributeError as e: 59*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 60*523fa7a6SAndroid Build Coastguard Worker f"Please use {type(self).__name__}.create to instantiate the union type." 61*523fa7a6SAndroid Build Coastguard Worker ) from e 62*523fa7a6SAndroid Build Coastguard Worker 63*523fa7a6SAndroid Build Coastguard Worker @property 64*523fa7a6SAndroid Build Coastguard Worker def value(self): 65*523fa7a6SAndroid Build Coastguard Worker return getattr(self, self.type) 66*523fa7a6SAndroid Build Coastguard Worker 67*523fa7a6SAndroid Build Coastguard Worker def __getattribute__(self, name): 68*523fa7a6SAndroid Build Coastguard Worker attr = super().__getattribute__(name) 69*523fa7a6SAndroid Build Coastguard Worker if attr is None and name in _get_field_names(type(self)) and name != self.type: # type: ignore[arg-type] 70*523fa7a6SAndroid Build Coastguard Worker raise AttributeError(f"Field {name} is not set.") 71*523fa7a6SAndroid Build Coastguard Worker return attr 72*523fa7a6SAndroid Build Coastguard Worker 73*523fa7a6SAndroid Build Coastguard Worker def __str__(self): 74*523fa7a6SAndroid Build Coastguard Worker return self.__repr__() 75*523fa7a6SAndroid Build Coastguard Worker 76*523fa7a6SAndroid Build Coastguard Worker def __repr__(self): 77*523fa7a6SAndroid Build Coastguard Worker return f"{type(self).__name__}({self.type}={getattr(self, self.type)})" 78