xref: /aosp_15_r20/external/emboss/compiler/util/ir_data_utils.py (revision 99e0aae7469b87d12f0ad23e61142c2d74c1ef70)
1# Copyright 2024 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
"""Provides helpers for working with IR data elements.
16
17Historical note: At one point protocol buffers were used for IR data. The
18codebase still expects the IR data classes to behave similarly, particularly
19with respect to "autovivification" where accessing an undefined field will
20create it temporarily and add it if assigned to. Though, perhaps not fully
21following the Pythonic ethos, we provide this behavior via the `builder` and
22`reader` helpers to remain compatible with the rest of the codebase.
23
24builder
25-------
26Instead of:
27```
28def set_function_name_end(function: Function):
29  if not function.function_name:
30    function.function_name = Word()
31  if not function.function_name.source_location:
32    function.function_name.source_location = Location()
  function.function_name.source_location.end = Position(line=1, column=2)
34```
35
36We can do:
37```
38def set_function_name_end(function: Function):
  builder(function).function_name.source_location.end = Position(
      line=1, column=2)
41```
42
43reader
44------
45Instead of:
46```
47def is_leaf_synthetic(data):
48  if data:
49    if data.attribute:
50      if data.attribute.value:
51        if data.attribute.value.is_synthetic is not None:
52          return data.attribute.value.is_synthetic
53  return False
54```
55We can do:
56```
57def is_leaf_synthetic(data):
58  return reader(data).attribute.value.is_synthetic
59```
60
61IrDataSerializer
62----------------
63Provides methods for serializing and deserializing an IR data object.
64"""
65import enum
66import json
67from typing import (
68    Any,
69    Callable,
70    Generic,
71    MutableMapping,
72    MutableSequence,
73    Optional,
74    Tuple,
75    TypeVar,
76    Union,
77    cast,
78)
79
80from compiler.util import ir_data
81from compiler.util import ir_data_fields
82
83
# Type variable for IR message types. It is bound to `ir_data.Message` so the
# generic helpers in this module accept any IR data class.
MessageT = TypeVar("MessageT", bound=ir_data.Message)
85
86
def field_specs(ir: Union[MessageT, type[MessageT]]):
  """Returns the field specs of an IR data class.

  Args:
    ir: Either an IR data class itself or an instance of one.

  Returns:
    The `all_field_specs` recorded for that class by
    `ir_data_fields.IrDataclassSpecs` (queried by field name with `.get`
    elsewhere in this module).
  """
  if not isinstance(ir, type):
    ir = type(ir)
  specs = ir_data_fields.IrDataclassSpecs.get_specs(ir)
  return specs.all_field_specs
91
92
class IrDataSerializer:
  """Converts IR data objects to and from dicts and JSON.

  An instance wraps one IR object for `to_dict`, `to_json` and
  `copy_from_dict`. `from_dict` and `from_json` are static constructors that
  build new IR objects.
  """

  def __init__(self, ir: MessageT):
    assert ir is not None
    self.ir = ir

  def _to_dict(
      self,
      ir: MessageT,
      field_func: Callable[
          [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]
      ],
  ) -> MutableMapping[str, Any]:
    """Recursively converts `ir` into a dict.

    Args:
      ir: The IR object to convert.
      field_func: Returns the (spec, value) pairs to emit for an IR object.
        This is how `to_dict` chooses between all fields and only non-empty
        ones, at every nesting level.

    Returns:
      A dict keyed by field name. Dataclass-typed values (and sequences of
      them) are converted recursively; other values are stored as-is.
    """
    assert ir is not None
    values: MutableMapping[str, Any] = {}
    for spec, value in field_func(ir):
      if value is not None and spec.is_dataclass:
        if spec.is_sequence:
          value = [self._to_dict(v, field_func) for v in value]
        else:
          value = self._to_dict(value, field_func)
      values[spec.name] = value
    return values

  def to_dict(self, exclude_none: bool = False):
    """Converts the wrapped IR object to a dictionary.

    Args:
      exclude_none: If True, omit fields (at every nesting level) whose value
        is None or an empty list.

    Returns:
      A dict keyed by field name; nested IR objects become nested dicts.
    """

    def non_empty(ir):
      return fields_and_values(
          ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
      )

    def all_fields(ir):
      return fields_and_values(ir)

    # It's tempting to use `dataclasses.asdict` here, but that does a deep
    # copy which is overkill for the current usage; mainly as an intermediary
    # for `to_json` and `repr`.
    return self._to_dict(self.ir, non_empty if exclude_none else all_fields)

  def to_json(self, *args, **kwargs):
    """Converts the wrapped IR object to a JSON string.

    Unset and empty fields are omitted (see `to_dict(exclude_none=True)`).
    Extra arguments are forwarded to `json.dumps`.
    """
    return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs)

  @staticmethod
  def from_json(data_cls, data):
    """Parses the JSON string `data` and builds a `data_cls` instance from it."""
    as_dict = json.loads(data)
    return IrDataSerializer.from_dict(data_cls, as_dict)

  def copy_from_dict(self, data):
    """Deserializes `data` and overwrites the wrapped IR object with it.

    Every field of the wrapped object is replaced. Fields absent from `data`
    are therefore reset to the defaults of a freshly constructed instance.
    """
    cls = type(self.ir)
    data_copy = IrDataSerializer.from_dict(cls, data)
    for k in field_specs(cls):
      setattr(self.ir, k, getattr(data_copy, k))

  @staticmethod
  def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
    """Converts `val` to an `enum_cls` member.

    Strings are treated as member names; anything else is looked up by value.
    """
    if isinstance(val, str):
      return getattr(enum_cls, val)
    return enum_cls(val)

  @staticmethod
  def _enum_type_hook(enum_cls: type[enum.Enum]):
    """Returns a one-argument converter for `enum_cls` values."""
    return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val)

  @staticmethod
  def _from_dict(data_cls: type[MessageT], data):
    """Recursively builds a `data_cls` instance from the dict `data`.

    Keys missing from `data`, or whose value is None, are not passed to the
    constructor, so those fields keep their dataclass defaults.
    """
    class_fields: MutableMapping[str, Any] = {}
    for name, spec in ir_data_fields.field_specs(data_cls).items():
      if (value := data.get(name)) is None:
        continue
      if spec.is_dataclass:
        if spec.is_sequence:
          class_fields[name] = [
              IrDataSerializer._from_dict(spec.data_type, v) for v in value
          ]
        else:
          class_fields[name] = IrDataSerializer._from_dict(
              spec.data_type, value
          )
      elif spec.data_type in (
          ir_data.FunctionMapping,
          ir_data.AddressableUnit,
      ):
        # These enums may be serialized by member name or by value. For a
        # sequence field `spec.data_type` is the element type, so convert
        # element-wise instead of passing the whole list to the converter.
        if spec.is_sequence:
          class_fields[name] = [
              IrDataSerializer._enum_type_converter(spec.data_type, v)
              for v in value
          ]
        else:
          class_fields[name] = IrDataSerializer._enum_type_converter(
              spec.data_type, value
          )
      elif spec.is_sequence:
        # Copy so the new IR object does not share a list with `data`.
        class_fields[name] = list(value)
      else:
        class_fields[name] = spec.data_type(value)
    return data_cls(**class_fields)

  @staticmethod
  def from_dict(data_cls: type[MessageT], data):
    """Creates a new `data_cls` instance from a serialized dict."""
    return IrDataSerializer._from_dict(data_cls, data)
194
195
class _IrDataSequenceBuilder(MutableSequence[MessageT]):
  """Wrapper for a list of IR elements.

  Elements read through indexing or iteration are wrapped in
  `_IrDataBuilder`s, so chained field access keeps autovivifying. Values
  written through this wrapper are unwrapped first, so the underlying list
  only ever holds raw IR objects. Membership, `index` and `count` compare
  against the raw elements.
  """

  def __init__(self, target: MutableSequence[MessageT]):
    self._target = target

  @staticmethod
  def _unwrap(value):
    """Returns the IR object behind an `_IrDataBuilder`, else `value`."""
    if isinstance(value, _IrDataBuilder):
      return value.ir
    return value

  def __delitem__(self, key):
    del self._target[key]

  def __getitem__(self, key):
    if isinstance(key, slice):
      # Slicing yields a new list; wrap its elements, not the list itself.
      return [_IrDataBuilder(i) for i in self._target[key]]
    return _IrDataBuilder(self._target[key])

  def __setitem__(self, key, value):
    # Unwrap so that, e.g., the inherited `reverse()` (which reads wrapped
    # elements and writes them back) never stores builders in the IR list.
    if isinstance(key, slice):
      value = [self._unwrap(v) for v in value]
    else:
      value = self._unwrap(value)
    self._target[key] = value

  def __iter__(self):
    for i in self._target:
      yield _IrDataBuilder(i)

  def __contains__(self, value):
    # The inherited implementation would compare freshly created wrappers
    # against `value`, which does not work for IR elements.
    return self._unwrap(value) in self._target

  def index(self, value, start=0, stop=None):
    """Returns the first index of `value` in the underlying list.

    Raises:
      ValueError: if `value` is not present.
    """
    if start is None:
      start = 0
    if stop is None:
      stop = len(self._target)
    return self._target.index(self._unwrap(value), start, stop)

  def count(self, value):
    """Returns how many raw elements equal `value`."""
    return self._target.count(self._unwrap(value))

  def __repr__(self):
    return repr(self._target)

  def __len__(self):
    return len(self._target)

  def __eq__(self, other):
    return self._target == other

  def __ne__(self, other):
    return self._target != other

  def insert(self, index, value):
    self._target.insert(index, self._unwrap(value))

  def extend(self, values):
    # Materialize first: `values` may be this very sequence (e.g. `s += s`),
    # and extending a list from a live iterator over itself never ends.
    self._target.extend([self._unwrap(v) for v in values])
237
238
class _IrDataBuilder(Generic[MessageT]):
  """Autovivifying wrapper for an IR element.

  Reading a field whose value is None first stores a default, built by
  `ir_data_fields.build_default`, on the wrapped object. Dataclass-typed
  fields are returned wrapped in another builder (or in an
  `_IrDataSequenceBuilder` for sequences), so chained access keeps
  autovivifying. Other values are returned as-is. Attribute writes are
  forwarded to the wrapped object.
  """

  def __init__(self, ir: MessageT) -> None:
    assert ir is not None
    self.ir: MessageT = ir

  def __setattr__(self, __name: str, __value: Any) -> None:
    """Sets `ir` on the wrapper; forwards every other assignment to the IR.

    Builder wrappers on the right-hand side are unwrapped first. For example,
    `builder(a).x = builder(b).x` stores the raw IR value, not the wrapper.
    """
    if __name == "ir":
      # This is the proxied object itself.
      object.__setattr__(self, __name, __value)
      return
    if isinstance(__value, _IrDataBuilder):
      __value = __value.ir
    elif isinstance(__value, _IrDataSequenceBuilder):
      __value = __value._target  # pylint:disable=protected-access
    # Passthrough to the proxied object.
    ir: MessageT = object.__getattribute__(self, "ir")
    setattr(ir, __name, __value)

  def __getattribute__(self, name: str) -> Any:
    """Hook for attribute reads that adds missing fields.

    If the field is unset (None), a default is created and stored on the
    wrapped object. Non-dataclass fields return the raw value. Dataclass
    fields return a new builder wrapping the value, so the next access in a
    longer chain is handled the same way.

    Raises:
      AttributeError: if `name` is not a field of the wrapped IR type.
    """
    # Builder attributes and Python special ("dunder") attributes such as
    # `__class__` or `__dict__` resolve on the wrapper itself. Without this,
    # introspection (e.g. a dataclass-generated `__eq__` reading
    # `other.__class__`) would fall into the field lookup below and fail with
    # a misleading "No field" error. IR field names are assumed never to be
    # dunder names.
    if name in ("CopyFrom", "ir") or (
        name.startswith("__") and name.endswith("__")
    ):
      return object.__getattribute__(self, name)

    # Get our target object by bypassing our getattr hook.
    ir: MessageT = object.__getattribute__(self, "ir")
    if ir is None:
      return object.__getattribute__(self, name)

    if name in ("HasField", "WhichOneof"):
      return getattr(ir, name)

    field_spec = field_specs(ir).get(name)
    if field_spec is None:
      raise AttributeError(
          f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
      )

    obj = getattr(ir, name, None)
    if obj is None:
      # Create a default and store it.
      obj = ir_data_fields.build_default(field_spec)
      setattr(ir, name, obj)

    if field_spec.is_dataclass:
      obj = (
          _IrDataSequenceBuilder(obj)
          if field_spec.is_sequence
          else _IrDataBuilder(obj)
      )

    return obj

  def CopyFrom(self, template: MessageT):  # pylint:disable=invalid-name
    """Updates the wrapped object's fields with the fields set in `template`."""
    # `self` (the builder) rather than `self.ir` is passed, so any attribute
    # access performed by the update goes through this builder's hooks.
    update(cast(MessageT, self), template)
300
301
def builder(target: MessageT) -> MessageT:
  """Wraps `target` so an IR data structure can be built by chained access.

  Args:
    target: An IR data class instance, or an object that is already a
      builder wrapper.

  Returns:
    `target` unchanged if it is already a builder. Otherwise, a new
    `_IrDataBuilder` around it, cast to the target's type so type checkers
    see the wrapped class's fields.

  Raises:
    TypeError: if the class of `target` does not define `IR_DATACLASS`.
  """
  wrapper_types = (_IrDataBuilder, _IrDataSequenceBuilder)
  if isinstance(target, wrapper_types):
    return target

  target_cls = type(target)
  if hasattr(target_cls, "IR_DATACLASS"):
    return cast(MessageT, _IrDataBuilder(target))
  raise TypeError(f"Builder target {target_cls} is not an ir_data.message")
315
316
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
  """Returns a placeholder value standing in for an unset field.

  The placeholder depends on the spec:
  - Sequence fields get an empty list.
  - Dataclass-typed fields get a read-only checker that pretends to be an
    instance of the field's type.
  - All other fields get the value from `ir_data_fields.build_default`.
  """
  if not spec.is_sequence and spec.is_dataclass:
    return _ReadOnlyFieldChecker(spec)
  return [] if spec.is_sequence else ir_data_fields.build_default(spec)
324
325
def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type:
  """Returns the data type a FieldSpec describes, or the type of an IR object."""
  is_spec = isinstance(ir_or_spec, ir_data_fields.FieldSpec)
  return ir_or_spec.data_type if is_spec else type(ir_or_spec)
330
331
class _ReadOnlyFieldChecker:
  """Read-only wrapper used to chain accesses to fields that may be unset.

  Wraps either a real IR object or, for a field that is not set, the
  `ir_data_fields.FieldSpec` describing it (a "placeholder"). Reading a field
  never modifies the wrapped IR object. Unset fields yield placeholder values
  built by `_field_checker_from_spec` instead of being created. Assigning
  any attribute other than `ir_or_spec` raises AttributeError.
  """

  def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
    # Routed through __setattr__ below, which permits only this attribute.
    self.ir_or_spec = ir_or_spec

  def __setattr__(self, name: str, value: Any) -> None:
    """Allows setting only `ir_or_spec`.

    Raises:
      AttributeError: for any attribute name other than `ir_or_spec`.
    """
    if name == "ir_or_spec":
      return object.__setattr__(self, name, value)

    raise AttributeError(f"Cannot set {name} on read-only wrapper")

  def __getattribute__(self, name: str) -> Any:  # pylint:disable=too-many-return-statements
    """Resolves `name` against the wrapped IR object or placeholder spec.

    - `ir_or_spec` returns the wrapped object itself.
    - Names that are not fields of the wrapped type are looked up directly
      on the wrapped object (the IR object, or the FieldSpec for a
      placeholder). A placeholder first answers `HasField` with a function
      that always returns False and `WhichOneof` with one that always
      returns None.
    - For a placeholder, every field yields the placeholder value for that
      field's spec.
    - For a real IR object:
      - A field whose value is None yields the placeholder value.
      - A set dataclass-typed field is wrapped (element-wise for sequences)
        in another `_ReadOnlyFieldChecker`.
      - Any other value is returned as-is.
    """
    # Bypass this hook to avoid infinite recursion.
    ir_or_spec = object.__getattribute__(self, "ir_or_spec")
    if name == "ir_or_spec":
      return ir_or_spec

    field_type = _field_type(ir_or_spec)
    spec = field_specs(field_type).get(name)
    if not spec:
      # Not a field of the wrapped type.
      if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
        if name == "HasField":
          return lambda x: False
        if name == "WhichOneof":
          return lambda x: None
      return object.__getattribute__(ir_or_spec, name)

    if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
      # Just pretending: the parent field is unset, so this field is as well.
      return _field_checker_from_spec(spec)

    value = getattr(ir_or_spec, name)
    if value is None:
      return _field_checker_from_spec(spec)

    if spec.is_dataclass:
      if spec.is_sequence:
        return [_ReadOnlyFieldChecker(i) for i in value]
      return _ReadOnlyFieldChecker(value)

    return value

  # Defining __eq__ without __hash__ leaves instances unhashable.
  def __eq__(self, other):
    """Compares the wrapped object (or placeholder spec) with `other`.

    If `other` is also a checker, its wrapped object is compared instead.
    """
    if isinstance(other, _ReadOnlyFieldChecker):
      other = other.ir_or_spec
    return self.ir_or_spec == other

  def __ne__(self, other):
    return not self == other
381
382
def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT:
  """Returns a read-only view for walking chains of possibly unset fields.

  The view never alters the wrapped object; it is only meant for reading.
  Unset fields along a chain produce placeholder values instead of errors.
  If `obj` is already such a view, it is returned unchanged.

  For example, a `reader` lets you do:
  ```
  def get_function_name_end_column(function: ir_data.Function):
    return reader(function).function_name.source_location.end.column
  ```

  Instead of:
  ```
  def get_function_name_end_column(function: ir_data.Function):
    if function.function_name:
      if function.function_name.source_location:
        if function.function_name.source_location.end:
          return function.function_name.source_location.end.column
    return 0
  ```
  """
  checker = (
      obj
      if isinstance(obj, _ReadOnlyFieldChecker)
      else _ReadOnlyFieldChecker(obj)
  )
  # The cast is for type checkers only: callers see the wrapped IR type.
  return cast(MessageT, checker)
412
413
def _extract_ir(
    ir_or_wrapper: Union[MessageT, _ReadOnlyFieldChecker, _IrDataBuilder, None],
) -> Optional[ir_data_fields.IrDataclassInstance]:
  """Returns the raw IR object behind a `reader` or `builder` wrapper.

  Args:
    ir_or_wrapper: A raw IR object, a `reader`/`builder` wrapper, or None.

  Returns:
    None for a reader placeholder that stands in for an unset field. For
    readers and builders, the wrapped IR object. Otherwise `ir_or_wrapper`
    itself, which may be None.
  """
  if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
    wrapped = ir_or_wrapper.ir_or_spec
    # A FieldSpec means this reader is only a placeholder; nothing is set.
    if isinstance(wrapped, ir_data_fields.FieldSpec):
      return None
    return cast(ir_data_fields.IrDataclassInstance, wrapped)
  if isinstance(ir_or_wrapper, _IrDataBuilder):
    return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper.ir)
  return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
426
427
def fields_and_values(
    ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker],
    value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
  """Retrieves the fields and their values for a given IR data object.

  Args:
    ir_wrapper: An IR data class instance, or a `reader`/`builder` wrapper
      of one.
    value_filt: Optional predicate on field values, passed through to
      `ir_data_fields.fields_and_values` to exclude values.

  Returns:
    A list of `(FieldSpec, value)` pairs. The list is empty when
    `ir_wrapper` is None or a reader placeholder for an unset field.
  """
  if (ir := _extract_ir(ir_wrapper)) is None:
    return []

  return ir_data_fields.fields_and_values(ir, value_filt)
442
443
def get_set_fields(ir: MessageT):
  """Returns (FieldSpec, value) pairs for the fields of `ir` that are set.

  A field counts as set when its value is not None, so an empty list still
  counts as set.
  """

  def is_set(value):
    return value is not None

  return fields_and_values(ir, is_set)
450
451
def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
  """Returns a copy of an IR data object, or None if there is nothing to copy.

  `ir_wrapper` may be a raw IR object or a `reader`/`builder` wrapper. The
  copy is made by `ir_data_fields.copy` from the underlying IR object. None,
  or a reader placeholder for an unset field, yields None.
  """
  ir = _extract_ir(ir_wrapper)
  if ir is None:
    return None
  return cast(MessageT, ir_data_fields.copy(ir))
458
459
def update(ir: MessageT, template: MessageT):
  """Updates `ir`'s fields with all set fields in `template`.

  Args:
    ir: The IR object (or builder wrapping one) whose fields are updated.
    template: An IR object or a `reader`/`builder` wrapper of one. If it is
      None, or a reader placeholder for an unset field, nothing is updated.
  """
  template_ir = _extract_ir(template)
  # `_extract_ir` signals "no IR object" with None. Test for that explicitly
  # rather than relying on the truthiness of the IR object.
  if template_ir is None:
    return

  ir_data_fields.update(
      cast(ir_data_fields.IrDataclassInstance, ir), template_ir
  )
468