xref: /aosp_15_r20/external/emboss/compiler/util/ir_data_fields.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Provides helpers for defining and working with the fields of IR data classes.
16
17Field specs
18-----------
19Various utilities are provided for working with `dataclasses.Field`s:
20  - Classes:
21    - `FieldSpec` - used to track data about an IR data field
22    - `FieldContainer` - used to track the field container type
23  - Functions that work with `FieldSpec`s:
24    - `make_field_spec`, `build_default`
25  - Functions for retrieving a set of `FieldSpec`s for a given class:
26    - `field_specs`
27  - Functions for retrieving fields and their values:
28    - `fields_and_values`
29  - Functions for copying and updating IR data classes
30    - `copy`, `update`
31  - Functions to help defining IR data fields
32    - `oneof_field`, `list_field`, `str_field`
33"""
34
35import dataclasses
36import enum
37import sys
38from typing import (
39    Any,
40    Callable,
41    ClassVar,
42    ForwardRef,
43    Iterable,
44    Mapping,
45    MutableMapping,
46    NamedTuple,
47    Optional,
48    Protocol,
49    SupportsIndex,
50    Tuple,
51    TypeVar,
52    Union,
53    cast,
54    get_args,
55    get_origin,
56)
57
58
class IrDataclassInstance(Protocol):
  """Type bound for an IR dataclass instance.

  Describes the class-level attributes that the helpers in this module expect
  to find on an IR dataclass.
  """

  # Field table attached by the `dataclasses` machinery.
  __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
  # Marker attribute; `_is_ir_dataclass` tests for its presence.
  IR_DATACLASS: ClassVar[object]
  # Cached field specs; assigned by `cache_message_specs`.
  field_specs: ClassVar["FilteredIrFieldSpecs"]
65
66
# Type variable for functions that accept or return IR dataclass instances.
IrDataT = TypeVar("IrDataT", bound=IrDataclassInstance)

# Type variable for the element type held by a `CopyValuesList`.
CopyValuesListT = TypeVar("CopyValuesListT", bound=type)

# Name of the attribute whose presence marks an object as an IR dataclass.
_IR_DATACLASSS_ATTR = "IR_DATACLASS"
72
73
def _is_ir_dataclass(obj):
  """Returns True if `obj` carries the IR dataclass marker attribute."""
  return hasattr(obj, _IR_DATACLASSS_ATTR)
76
77
class CopyValuesList(list[CopyValuesListT]):
  """A `list` subclass that stores copies of the values added to it.

  Values passed to `append`, `insert` and `extend` are copied before they are
  stored. Values supplied to the constructor or to `shallow_copy` are stored
  as-is.
  """

  def __init__(
      self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None
  ):
    """Creates the list.

    Args:
      value_type: Type of the elements; also used to copy non-IR values.
      iterable: Optional initial contents, stored without copying.
    """
    # A falsy `iterable` (None or an empty container) yields an empty list.
    super().__init__(iterable if iterable else ())
    self.value_type = value_type

  def _copy(self, obj: Any):
    """Returns the copy of `obj` that will be stored in this list."""
    if not _is_ir_dataclass(obj):
      # Non-IR values are copied by passing them through `value_type`.
      return self.value_type(obj)
    return copy(obj)

  def extend(self, iterable: Iterable) -> None:
    """Appends a copy of every item in `iterable`."""
    # Build all copies first so nothing is added if a copy fails.
    copies = list(map(self._copy, iterable))
    super().extend(copies)

  def shallow_copy(self, iterable: Iterable) -> None:
    """Appends the items of `iterable` as-is, without copying them."""
    super().extend(iterable)

  def append(self, obj: Any) -> None:
    """Appends a copy of `obj`."""
    super().append(self._copy(obj))

  def insert(self, index: SupportsIndex, obj: Any) -> None:
    """Inserts a copy of `obj` before position `index`."""
    super().insert(index, self._copy(obj))
107
108
class TemporaryCopyValuesList(NamedTuple):
  """Class used to temporarily hold a CopyValuesList while copying and
  constructing an IR dataclass.

  `_copy_set_fields` wraps each freshly built list in this type; `update`
  unwraps it and assigns the contained list directly.
  """

  # The already-copied list that the consumer may take ownership of.
  temp_list: CopyValuesList
115
116
class FieldContainer(enum.Enum):
  """Indicates a field's container type."""

  # The field's type hint is not wrapped in a container.
  NONE = 0
  # The field is declared as `Optional[T]`.
  OPTIONAL = 1
  # The field is declared as `list[T]`.
  LIST = 2
123
124
class FieldSpec(NamedTuple):
  """Indicates the container and type of a field.

  `FieldSpec` objects are accessed millions of times during runs so we cache as
  many operations as possible.
    - `is_dataclass`: whether `data_type` carries the IR dataclass marker
    - `is_sequence`: `container is FieldContainer.LIST`
    - `is_enum`: `issubclass(data_type, enum.Enum)`
    - `is_oneof`: `oneof is not None`

  Use `make_field_spec` to automatically fill in the cached operations.
  """

  # Name of the field on its dataclass.
  name: str
  # Element type of the field, with any `Optional`/`list` wrapper removed.
  data_type: type
  # How the field's type is wrapped.
  container: FieldContainer
  # Name of the oneof group the field belongs to, or None.
  oneof: Optional[str]
  # Cached type queries; see the class docstring.
  is_dataclass: bool
  is_sequence: bool
  is_enum: bool
  is_oneof: bool
146
147
def make_field_spec(
    name: str, data_type: type, container: FieldContainer, oneof: Optional[str]
):
  """Creates a `FieldSpec`, computing its cached type queries.

  Args:
    name: Name of the field.
    data_type: Element type of the field.
    container: How the field's type is wrapped.
    oneof: Name of the oneof group the field belongs to, or None.

  Returns:
    A `FieldSpec` whose `is_*` members are derived from the arguments.
  """
  cached_queries = {
      "is_dataclass": _is_ir_dataclass(data_type),
      "is_sequence": container is FieldContainer.LIST,
      "is_enum": issubclass(data_type, enum.Enum),
      "is_oneof": oneof is not None,
  }
  return FieldSpec(name, data_type, container, oneof, **cached_queries)
162
163
def build_default(field_spec: FieldSpec):
  """Creates the default value for the field described by `field_spec`.

  Args:
    field_spec: Spec of the field to build a default for.

  Returns:
    An empty `CopyValuesList` for sequence fields, the enum member with value
    0 for enum fields, and otherwise the result of calling the field's data
    type with no arguments.
  """
  data_type = field_spec.data_type
  if field_spec.is_sequence:
    return CopyValuesList(data_type)
  if not field_spec.is_enum:
    return data_type()
  # Enums are looked up by value rather than constructed with no arguments.
  return data_type(0)
171
172
class FilteredIrFieldSpecs:
  """Pre-computed views over the field specs of one IR dataclass.

  Attributes:
    all_field_specs: The mapping of field name to spec that was passed in.
    field_specs: Tuple of every spec, in mapping order.
    dataclass_field_specs: Name-to-spec dict of specs with `is_dataclass` set.
    oneof_field_specs: Name-to-spec dict of specs with `is_oneof` set.
    sequence_field_specs: Tuple of specs with `is_sequence` set.
    oneof_mappings: Tuple of `(field name, oneof name)` pairs for the oneof
      specs whose oneof name is truthy.
  """

  def __init__(self, specs: Mapping[str, FieldSpec]):
    dataclass_specs = {}
    oneof_specs = {}
    sequence_specs = []
    oneof_pairs = []
    # One pass over the mapping fills every filtered view.
    for field_name, spec in specs.items():
      if spec.is_dataclass:
        dataclass_specs[field_name] = spec
      if spec.is_oneof:
        oneof_specs[field_name] = spec
        if spec.oneof:
          oneof_pairs.append((field_name, spec.oneof))
      if spec.is_sequence:
        sequence_specs.append(spec)

    self.all_field_specs = specs
    self.field_specs = tuple(specs.values())
    self.dataclass_field_specs = dataclass_specs
    self.oneof_field_specs = oneof_specs
    self.sequence_field_specs = tuple(sequence_specs)
    self.oneof_mappings = tuple(oneof_pairs)
189
190
def all_ir_classes(mod):
  """Retrieves all IR dataclass definitions in the given module.

  Args:
    mod: The module whose namespace (`mod.__dict__`) is searched.

  Returns:
    A generator over the classes in `mod` that carry the IR dataclass marker.
  """
  # Only class objects qualify. `isinstance(v, type)` also accepts classes
  # built with a custom metaclass and rejects non-class instances.
  return (
      v
      for v in mod.__dict__.values()
      if isinstance(v, type) and _is_ir_dataclass(v)
  )
198
199
class IrDataclassSpecs:
  """Holds the shared cache of field specs for IR dataclasses."""

  # Shared by all callers: maps an IR dataclass to its filtered field specs.
  spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {}

  @classmethod
  def get_mod_specs(cls, mod):
    """Builds the field specs for every IR dataclass found in `mod`.

    Args:
      mod: The module to scan for IR dataclasses.

    Returns:
      A dict mapping each IR dataclass in `mod` to a `FilteredIrFieldSpecs`.
    """
    mod_specs = {}
    for ir_class in all_ir_classes(mod):
      mod_specs[ir_class] = FilteredIrFieldSpecs(_field_specs(ir_class))
    return mod_specs

  @classmethod
  def get_specs(cls, data_class):
    """Returns the cached field specs for `data_class`.

    On a cache miss, the specs of every IR dataclass in the module that
    defines `data_class` are computed and added to the cache.

    Raises:
      KeyError: If `data_class` is still not cached after scanning its module.
    """
    try:
      return cls.spec_cache[data_class]
    except KeyError:
      pass
    owning_mod = sys.modules[data_class.__module__]
    cls.spec_cache.update(cls.get_mod_specs(owning_mod))
    return cls.spec_cache[data_class]
220
221
def cache_message_specs(mod, cls):
  """Attaches a cached `field_specs` attribute to the IR dataclasses in `mod`.

  The given base `cls` itself is skipped.

  This needs to be done after the dataclass decorators run and create the
  wrapped classes.

  Args:
    mod: The module whose IR dataclasses are updated.
    cls: The base class to leave untouched.
  """
  for data_class in all_ir_classes(mod):
    if data_class is cls:
      continue
    data_class.field_specs = IrDataclassSpecs.get_specs(data_class)
232
233
def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]:
  """Gets the IR data field names and types for the given IR data class.

  Fields whose name starts with an underscore are skipped. `Optional[T]` and
  `list[T]` hints are unwrapped to `T`, and forward references are resolved
  against the classes defined in the module of `cls`.

  Args:
    cls: The IR dataclass to inspect.

  Returns:
    A mapping from field name to its `FieldSpec`, in dataclass field order.

  Raises:
    TypeError: If a field is wrapped in a container other than `Optional` or
      `list`.
    AssertionError: If a field is a `Union` that is not an `Optional[T]`.
    KeyError: If a forward reference names a class that is not found in the
      module of `cls`.
  """
  # Get the dataclass fields
  class_fields = dataclasses.fields(cast(Any, cls))

  # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute
  # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`.
  # Instead we manually substitute the type by looking the name up among the
  # classes defined in the class' module. `isinstance(v, type)` keeps every
  # class, including those built with a custom metaclass such as enums.
  mod_ns = {
      k: v
      for k, v in sys.modules[cls.__module__].__dict__.items()
      if isinstance(v, type)
  }

  # Now extract the concrete type out of optionals
  result: MutableMapping[str, FieldSpec] = {}
  for class_field in class_fields:
    if class_field.name.startswith("_"):
      continue
    container_type = FieldContainer.NONE
    type_hint = class_field.type
    oneof = class_field.metadata.get("oneof")

    # Check if this type is wrapped
    origin = get_origin(type_hint)
    # Get the wrapped types if there are any
    args = get_args(type_hint)
    if origin is not None:
      # Extract the type.
      type_hint = args[0]

      # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we
      # have to check if it's a `Union` instead of just using `Optional`.
      if origin == Union:
        # Make sure this is an `Optional` and not another `Union` type.
        assert len(args) == 2 and args[1] == type(None)
        container_type = FieldContainer.OPTIONAL
      elif origin == list:
        container_type = FieldContainer.LIST
      else:
        raise TypeError(f"Field has invalid container type: {origin}")

    # Resolve any forward references.
    if isinstance(type_hint, str):
      type_hint = mod_ns[type_hint]
    if isinstance(type_hint, ForwardRef):
      type_hint = mod_ns[type_hint.__forward_arg__]

    result[class_field.name] = make_field_spec(
        class_field.name, type_hint, container_type, oneof
    )

  return result
288
289
def field_specs(obj: Union[IrDataT, type[IrDataT]]) -> Mapping[str, FieldSpec]:
  """Retrieves the field specs for the given data type or instance.

  The results of this method are cached to reduce lookup overhead.

  Args:
    obj: An IR dataclass, or an instance of one.

  Returns:
    The mapping of field name to `FieldSpec` for the class of `obj`.

  Raises:
    TypeError: If `obj` is None or `NoneType`.
  """
  if isinstance(obj, type):
    cls = obj
  else:
    cls = type(obj)
  if cls is type(None):
    raise TypeError("field_specs called with invalid type: NoneType")
  filtered = IrDataclassSpecs.get_specs(cls)
  return filtered.all_field_specs
299
300
def fields_and_values(
    ir: IrDataT,
    value_filt: Optional[Callable[[Any], bool]] = None,
):
  """Lists each field spec of `ir` together with the field's current value.

  Args:
    ir: The IR data class or a read-only wrapper of an IR data class.
    value_filt: Optional filter used to exclude values; only values for which
      it returns a truthy result are kept.

  Returns:
    A list of `(FieldSpec, value)` tuples in field order.
  """
  specs: FilteredIrFieldSpecs = ir.field_specs
  # Lazily pair each spec with its value so reads and filtering interleave.
  pairs = ((spec, getattr(ir, spec.name)) for spec in specs.field_specs)
  if not value_filt:
    return list(pairs)
  return [(spec, value) for spec, value in pairs if value_filt(value)]
318
319
# `copy` is one of the hottest paths of embossc. We've taken steps to
# optimize this path at the expense of code clarity and modularization.
#
# 1. `FilteredIrFieldSpecs` are cached in the class definition for IR
#    dataclasses under the `ir_data.Message.field_specs` class attribute. We
#    just assume the passed-in object has that attribute.
# 2. We cache a `field_specs` entry that is just the `values()` of the
#    `all_field_specs` dict.
# 3. Copied lists are wrapped in a `TemporaryCopyValuesList`. This is used to
#    signal to consumers that they can take ownership of the contained list
#    rather than copying it again. See `ir_data.Message()` and `update()` for
#    where this is used.
# 4. `FieldSpec` checks are cached, including `is_dataclass` and `is_sequence`.
# 5. None checks are only done in `copy()`; `_copy_set_fields` only
#    references `_copy()` to avoid this step.
def _copy_set_fields(ir: IrDataT):
  """Collects copies of every field of `ir` that is not None.

  Sequence fields are rebuilt as a `CopyValuesList` wrapped in a
  `TemporaryCopyValuesList`; IR dataclass values are copied recursively; all
  other values are passed through unchanged.

  Args:
    ir: The IR dataclass instance to read; must have a `field_specs` attribute.

  Returns:
    A dict mapping field name to the copied value.
  """
  copied: MutableMapping[str, Any] = {}

  filtered: FilteredIrFieldSpecs = ir.field_specs
  for spec in filtered.field_specs:
    value = getattr(ir, spec.name)
    if value is None:
      continue
    if spec.is_sequence:
      # Only IR dataclass elements need a deep copy.
      items = (_copy(v) for v in value) if spec.is_dataclass else value
      value = TemporaryCopyValuesList(CopyValuesList(spec.data_type, items))
    elif spec.is_dataclass:
      value = _copy(value)
    copied[spec.name] = value
  return copied
353
354
def _copy(ir: IrDataT) -> IrDataT:
  """Copies `ir` without the falsy-input guard that `copy` applies.

  A new instance of the same class is constructed from the copied fields
  returned by `_copy_set_fields`.
  """
  return type(ir)(**_copy_set_fields(ir))  # type: ignore[misc]
357
358
def copy(ir: IrDataT) -> Optional[IrDataT]:
  """Creates a copy of the given IR data class.

  Args:
    ir: The IR dataclass instance to copy; may be None.

  Returns:
    A copy of `ir`, or None if `ir` is falsy.
  """
  return _copy(ir) if ir else None
364
365
def update(ir: IrDataT, template: IrDataT):
  """Copies every set (non-None) field of `template` onto `ir`.

  Args:
    ir: The IR dataclass instance to modify in place.
    template: The instance whose set fields are copied.
  """
  for field_name, value in _copy_set_fields(template).items():
    # Copied lists arrive wrapped; take the contained list directly.
    unwrapped = (
        value.temp_list if isinstance(value, TemporaryCopyValuesList) else value
    )
    setattr(ir, field_name, unwrapped)
372
373
class OneOfField:
  """Descriptor backing a field that belongs to a "oneof" group.

  The value is stored on the instance under a proxy attribute named
  `_<field name>`. Assigning a non-None value sets every other field of the
  same oneof group to None.

  Note: Decorators only work if dataclass slots aren't used.
  """

  def __init__(self, oneof: str) -> None:
    super().__init__()
    # `name`, `proxy_name` and `owner_type` are filled in by `__set_name__`.
    self.name: str = ""
    self.proxy_name: str = ""
    self.owner_type = None
    self.oneof = oneof

  def __set_name__(self, owner, name):
    proxy = f"_{name}"
    self.owner_type = owner
    self.name = name
    self.proxy_name = proxy
    # Give the class an empty proxy so reads work before the first write.
    setattr(owner, proxy, None)

  def __get__(self, obj, objtype=None):
    return getattr(obj, self.proxy_name)

  def __set__(self, obj, value):
    # When the dataclass uses the default value, the descriptor itself is
    # assigned; treat that as "unset".
    new_value = None if value is self else value

    if new_value is not None:
      group_members = IrDataclassSpecs.get_specs(self.owner_type).oneof_mappings
      for field_name, group in group_members:
        if group != self.oneof or field_name == self.name:
          continue
        # Clear the other members of this oneof group.
        setattr(obj, field_name, None)

    setattr(obj, self.proxy_name, new_value)
415
416
def oneof_field(name: str):
  """Alternative to `dataclasses.field` that sets up a oneof variable.

  Args:
    name: Name of the oneof group the field belongs to.

  Returns:
    A dataclass field whose default is a `OneOfField` descriptor for the group
    and whose metadata records the group under the "oneof" key.
  """
  descriptor = OneOfField(name)
  oneof_metadata = {"oneof": name}
  return dataclasses.field(  # pylint:disable=invalid-field-call
      default=descriptor, metadata=oneof_metadata, init=True
  )
422
423
def str_field():
  """Helper used to define a defaulted str field.

  Returns:
    A dataclass field whose default factory is `str`, so the field defaults to
    the empty string.
  """
  return dataclasses.field(default_factory=str)  # pylint:disable=invalid-field-call
427
428
def list_field(cls_or_fn):
  """Helper used to define a defaulted list field.

  A lambda can be used to defer resolution of a field type that references its
  container type, for example:
  ```
  class Foo:
    subtypes: list['Foo'] = list_field(lambda: Foo)
    names: list[str] = list_field(str)
  ```

  Args:
    cls_or_fn: The class type or a function that resolves to the class type.

  Returns:
    A dataclass field whose default factory builds an empty `CopyValuesList`
    for the resolved class type.
  """

  def make_list():
    # Resolve the element type only when a default is actually built.
    if isinstance(cls_or_fn, type):
      value_type = cls_or_fn
    else:
      value_type = cls_or_fn()
    return CopyValuesList(value_type)

  return dataclasses.field(  # pylint:disable=invalid-field-call
      default_factory=make_list
  )
450