# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides helpers for working with IR data elements.

Historical note: At one point protocol buffers were used for IR data. The
codebase still expects the IR data classes to behave similarly, particularly
with respect to "autovivification", where accessing an undefined field will
create it temporarily and add it if assigned to. Though perhaps not fully
following the Pythonic ethos, we provide this behavior via the `builder` and
`reader` helpers to remain compatible with the rest of the codebase.

builder
-------
Instead of:
```
def set_function_name_end(function: Function):
    if not function.function_name:
        function.function_name = Word()
    if not function.function_name.source_location:
        function.function_name.source_location = Location()
    function.function_name.source_location.end = Position(line=1, column=2)
```

We can do:
```
def set_function_name_end(function: Function):
    builder(function).function_name.source_location.end = Position(line=1,
                                                                   column=2)
```

reader
------
Instead of:
```
def is_leaf_synthetic(data):
    if data:
        if data.attribute:
            if data.attribute.value:
                if data.attribute.value.is_synthetic is not None:
                    return data.attribute.value.is_synthetic
    return False
```
We can do:
```
def is_leaf_synthetic(data):
    return reader(data).attribute.value.is_synthetic
```

IrDataSerializer
----------------
Provides methods for serializing and deserializing an IR data object.
"""
import enum
import json
from typing import (
    Any,
    Callable,
    Generic,
    MutableMapping,
    MutableSequence,
    Optional,
    Tuple,
    TypeVar,
    Union,
    cast,
)

from compiler.util import ir_data
from compiler.util import ir_data_fields


MessageT = TypeVar("MessageT", bound=ir_data.Message)


def field_specs(ir: Union[MessageT, type[MessageT]]):
    """Retrieves the field specs for the IR data class.

    Args:
      ir: Either an IR data class or an instance of one.

    Returns:
      The `all_field_specs` mapping (field name -> `FieldSpec`) reported by
      `ir_data_fields.IrDataclassSpecs` for the class.
    """
    data_type = ir if isinstance(ir, type) else type(ir)
    return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs


class IrDataSerializer:
    """Provides methods for serializing IR data objects."""

    def __init__(self, ir: MessageT):
        assert ir is not None
        self.ir = ir

    def _to_dict(
        self,
        ir: MessageT,
        field_func: Callable[
            [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]
        ],
    ) -> MutableMapping[str, Any]:
        """Recursively converts `ir` into a dict keyed by field name.

        Args:
          ir: The IR data object to convert.
          field_func: Selects the (spec, value) pairs to emit for an object. The
            same selector is applied to nested IR data objects, including the
            elements of sequence fields.

        Returns:
          A dict mapping field names to plain values, nested dicts, or lists of
          nested dicts.
        """
        assert ir is not None
        values: MutableMapping[str, Any] = {}
        for spec, value in field_func(ir):
            if value is not None and spec.is_dataclass:
                if spec.is_sequence:
                    value = [self._to_dict(v, field_func) for v in value]
                else:
                    value = self._to_dict(value, field_func)
            values[spec.name] = value
        return values

    def to_dict(self, exclude_none: bool = False):
        """Converts the IR data class to a dictionary.

        Args:
          exclude_none: If True, fields whose value is None or an empty list are
            omitted. Otherwise every field is emitted.
        """

        def non_empty(ir):
            return fields_and_values(
                ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
            )

        def all_fields(ir):
            return fields_and_values(ir)

        # It's tempting to use `dataclasses.asdict` here, but that does a deep
        # copy which is overkill for the current usage; mainly as an intermediary
        # for `to_json` and `repr`.
        return self._to_dict(self.ir, non_empty if exclude_none else all_fields)

    def to_json(self, *args, **kwargs):
        """Converts the IR data class to a JSON string.

        None values and empty lists are omitted. Extra arguments are forwarded
        to `json.dumps`.
        """
        return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs)

    @staticmethod
    def from_json(data_cls, data):
        """Constructs an instance of `data_cls` from the given JSON string."""
        as_dict = json.loads(data)
        return IrDataSerializer.from_dict(data_cls, as_dict)

    def copy_from_dict(self, data):
        """Deserializes `data` and overwrites every field of the IR object with it.

        A new object of the same class is built from `data`. Each field listed
        by `field_specs` is then copied onto `self.ir`, including fields that
        are None in the new object.
        """
        cls = type(self.ir)
        data_copy = IrDataSerializer.from_dict(cls, data)
        for k in field_specs(cls):
            setattr(self.ir, k, getattr(data_copy, k))

    @staticmethod
    def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
        """Converts `val` to a member of `enum_cls`.

        Strings are looked up as member names. Any other value is passed to the
        enum constructor, i.e. it is looked up as a member value.
        """
        if isinstance(val, str):
            return getattr(enum_cls, val)
        return enum_cls(val)

    @staticmethod
    def _enum_type_hook(enum_cls: type[enum.Enum]):
        """Returns a one-argument converter bound to `enum_cls`."""
        return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val)

    @staticmethod
    def _from_dict(data_cls: type[MessageT], data):
        """Recursively builds a `data_cls` instance from a serialized dict.

        Keys missing from `data`, or mapped to None, are not passed to the
        constructor. Nested IR data classes, and sequences of them, are
        deserialized recursively. `ir_data.FunctionMapping` and
        `ir_data.AddressableUnit` values go through `_enum_type_converter`, so
        member names are also accepted. Other non-sequence values are passed
        through the field's `data_type`. Other sequences are used as-is.
        """
        class_fields: MutableMapping[str, Any] = {}
        for name, spec in ir_data_fields.field_specs(data_cls).items():
            if (value := data.get(name)) is not None:
                if spec.is_dataclass:
                    if spec.is_sequence:
                        class_fields[name] = [
                            IrDataSerializer._from_dict(spec.data_type, v)
                            for v in value
                        ]
                    else:
                        class_fields[name] = IrDataSerializer._from_dict(
                            spec.data_type, value
                        )
                else:
                    if spec.data_type in (
                        ir_data.FunctionMapping,
                        ir_data.AddressableUnit,
                    ):
                        class_fields[name] = IrDataSerializer._enum_type_converter(
                            spec.data_type, value
                        )
                    else:
                        if spec.is_sequence:
                            class_fields[name] = value
                        else:
                            class_fields[name] = spec.data_type(value)
        return data_cls(**class_fields)

    @staticmethod
    def from_dict(data_cls: type[MessageT], data):
        """Creates a new IR data instance from a serialized dict."""
        return IrDataSerializer._from_dict(data_cls, data)


class _IrDataSequenceBuilder(MutableSequence[MessageT]):
    """Wrapper for a list of IR elements.

    Indexed access and iteration return the elements wrapped in
    `_IrDataBuilder`s. Values written into the list are unwrapped first, so the
    underlying IR list only ever holds raw IR data objects, never builder
    proxies.
    """

    def __init__(self, target: MutableSequence[MessageT]):
        self._target = target

    def __delitem__(self, key):
        del self._target[key]

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Slicing the target yields a new list, which `_IrDataBuilder` cannot
            # wrap. Wrap each element instead: writes through the returned
            # builders still reach the shared IR objects.
            return [_IrDataBuilder(item) for item in self._target[key]]
        return _IrDataBuilder(self._target[key])

    def __setitem__(self, key, value):
        # Mixin methods such as `reverse()` write back values obtained from
        # `__getitem__`, which are builders, so unwrap before storing.
        if isinstance(key, slice):
            value = [_unwrap_builder(item) for item in value]
        else:
            value = _unwrap_builder(value)
        self._target[key] = value

    def __iter__(self):
        for item in self._target:
            yield _IrDataBuilder(item)

    def __contains__(self, value):
        # Compare against the raw elements. The `MutableSequence` default would
        # iterate fresh `_IrDataBuilder` wrappers, which define no `__eq__` of
        # their own.
        return _unwrap_builder(value) in self._target

    def __repr__(self):
        return repr(self._target)

    def __len__(self):
        return len(self._target)

    def __eq__(self, other):
        return self._target == other

    def __ne__(self, other):
        return self._target != other

    def insert(self, index, value):
        self._target.insert(index, _unwrap_builder(value))

    def extend(self, values):
        # Materialize first so that extending with a view of this same list
        # does not keep iterating while the target grows.
        self._target.extend([_unwrap_builder(value) for value in values])


class _IrDataBuilder(Generic[MessageT]):
    """Wrapper for an IR element that creates missing fields on access."""

    def __init__(self, ir: MessageT) -> None:
        assert ir is not None
        self.ir: MessageT = ir

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "ir":
            # This is the wrapped (proxied) object itself.
            object.__setattr__(self, name, value)
        else:
            # Pass the assignment through to the wrapped object. Store the raw IR
            # object rather than a builder proxy.
            ir: MessageT = object.__getattribute__(self, "ir")
            setattr(ir, name, _unwrap_builder(value))

    def __getattribute__(self, name: str) -> Any:
        """Hook for `getattr` that handles adding missing fields.

        If the field is unset (None), a default built from its spec is stored
        on the wrapped object. Basic-typed fields return the raw value. IR data
        class fields return an `_IrDataBuilder`, and sequences of them return an
        `_IrDataSequenceBuilder`, so the next access in a longer chain is also
        handled.

        Raises:
          AttributeError: If `name` is not a field of the wrapped IR class (other
            than the passthrough names `CopyFrom`, `ir`, `HasField` and
            `WhichOneof`).
        """

        # Check if getting one of the builder attributes.
        if name in ("CopyFrom", "ir"):
            return object.__getattribute__(self, name)

        # Get our target object by bypassing our getattr hook.
        ir: MessageT = object.__getattribute__(self, "ir")
        if ir is None:
            return object.__getattribute__(self, name)

        if name in ("HasField", "WhichOneof"):
            return getattr(ir, name)

        field_spec = field_specs(ir).get(name)
        if field_spec is None:
            raise AttributeError(
                f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
            )

        obj = getattr(ir, name, None)
        if obj is None:
            # Create a default and store it.
            obj = ir_data_fields.build_default(field_spec)
            setattr(ir, name, obj)

        if field_spec.is_dataclass:
            obj = (
                _IrDataSequenceBuilder(obj)
                if field_spec.is_sequence
                else _IrDataBuilder(obj)
            )

        return obj

    def CopyFrom(self, template: MessageT):  # pylint:disable=invalid-name
        """Updates the fields of the wrapped object with values set in the template."""
        # Update the raw object directly rather than routing every access through
        # this autovivifying proxy.
        update(cast(MessageT, self.ir), template)


def _unwrap_builder(value: Any) -> Any:
    """Returns the IR object wrapped by an `_IrDataBuilder`, or `value` unchanged."""
    if isinstance(value, _IrDataBuilder):
        return value.ir
    return value


def builder(target: MessageT) -> MessageT:
    """Creates a wrapper around the target to help build an IR data structure.

    Raises:
      TypeError: If `target`'s type has no `IR_DATACLASS` attribute.
    """
    # Check if the target is already a builder.
    if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)):
        return target

    # Builders are only valid for IR data classes.
    if not hasattr(type(target), "IR_DATACLASS"):
        raise TypeError(f"Builder target {type(target)} is not an ir_data.message")

    # Create a builder and cast it to the target type to expose type hinting for
    # the wrapped type.
    return cast(MessageT, _IrDataBuilder(target))


def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
    """Builds a placeholder FieldChecker that pretends to be an unset IR field.

    Sequences become an empty list, IR data classes become a read-only checker
    over the spec, and anything else becomes the spec's default value.
    """
    if spec.is_sequence:
        return []
    if spec.is_dataclass:
        return _ReadOnlyFieldChecker(spec)
    return ir_data_fields.build_default(spec)


def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type:
    """Returns the data type described by a spec, or the type of an IR instance."""
    if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
        return ir_or_spec.data_type
    return type(ir_or_spec)


class _ReadOnlyFieldChecker:
    """Class used to chain calls to fields that aren't set.

    Wraps either a real IR object or, for an unset field, the `FieldSpec`
    describing that field (a placeholder). Attribute assignment is rejected.
    """

    def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
        self.ir_or_spec = ir_or_spec

    def __setattr__(self, name: str, value: Any) -> None:
        if name == "ir_or_spec":
            return object.__setattr__(self, name, value)

        raise AttributeError(f"Cannot set {name} on read-only wrapper")

    def __getattribute__(self, name: str) -> Any:  # pylint:disable=too-many-return-statements
        ir_or_spec = object.__getattribute__(self, "ir_or_spec")
        if name == "ir_or_spec":
            return ir_or_spec

        field_type = _field_type(ir_or_spec)
        spec = field_specs(field_type).get(name)
        if not spec:
            # Not a field. Placeholders answer the protobuf-style queries as if
            # nothing is set; everything else is looked up on the wrapped object.
            if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
                if name == "HasField":
                    return lambda x: False
                if name == "WhichOneof":
                    return lambda x: None
            return object.__getattribute__(ir_or_spec, name)

        if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
            # Just pretending: the parent is unset, so its fields are too.
            return _field_checker_from_spec(spec)

        value = getattr(ir_or_spec, name)
        if value is None:
            return _field_checker_from_spec(spec)

        if spec.is_dataclass:
            if spec.is_sequence:
                return [_ReadOnlyFieldChecker(i) for i in value]
            return _ReadOnlyFieldChecker(value)

        return value

    def __eq__(self, other):
        if isinstance(other, _ReadOnlyFieldChecker):
            other = other.ir_or_spec
        return self.ir_or_spec == other

    def __ne__(self, other):
        return not self == other


def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT:
    """Builds a read-only wrapper that can be used to check chains of possibly
    unset fields.

    This wrapper explicitly does not alter the wrapped object and is only
    intended for reading contents.

    For example, a `reader` lets you do:
    ```
    def get_function_name_end_column(function: ir_data.Function):
        return reader(function).function_name.source_location.end.column
    ```

    Instead of:
    ```
    def get_function_name_end_column(function: ir_data.Function):
        if function.function_name:
            if function.function_name.source_location:
                if function.function_name.source_location.end:
                    return function.function_name.source_location.end.column
        return 0
    ```
    """
    # Create a read-only wrapper if it's not already one.
    if not isinstance(obj, _ReadOnlyFieldChecker):
        obj = _ReadOnlyFieldChecker(obj)

    # Cast it back to the original type.
    return cast(MessageT, obj)


def _extract_ir(
    ir_or_wrapper: Union[MessageT, _ReadOnlyFieldChecker, _IrDataBuilder, None],
) -> Optional[ir_data_fields.IrDataclassInstance]:
    """Returns the raw IR object behind a reader/builder wrapper.

    Returns None for a `reader` placeholder of an unset field. Any other value
    is returned unchanged.
    """
    if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
        ir_or_spec = ir_or_wrapper.ir_or_spec
        if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
            # This is a placeholder entry, no fields are set.
            return None
        ir_or_wrapper = ir_or_spec
    elif isinstance(ir_or_wrapper, _IrDataBuilder):
        ir_or_wrapper = ir_or_wrapper.ir
    return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)


def fields_and_values(
    ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker],
    value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
    """Retrieves the fields and their values for a given IR data class.

    Args:
      ir_wrapper: The IR data class, or a read-only/builder wrapper of one.
      value_filt: Optional filter used to exclude values.

    Returns:
      The (spec, value) pairs from `ir_data_fields.fields_and_values`, or an
      empty list for a `reader` placeholder of an unset field.
    """
    if (ir := _extract_ir(ir_wrapper)) is None:
        return []

    return ir_data_fields.fields_and_values(ir, value_filt)


def get_set_fields(ir: MessageT):
    """Retrieves the field spec and value of fields that are set in the given IR data class.

    A value is considered "set" if it is not None.
    """
    return fields_and_values(ir, lambda v: v is not None)


def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
    """Creates a copy of the given IR data class.

    Wrappers are unwrapped first. Returns None if there is no underlying IR
    object (None input or an unset `reader` placeholder).
    """
    if (ir := _extract_ir(ir_wrapper)) is None:
        return None
    ir_copy = ir_data_fields.copy(ir)
    return cast(MessageT, ir_copy)


def update(ir: MessageT, template: MessageT):
    """Updates `ir`'s fields with all set fields in the template.

    Args:
      ir: The IR data object to update. A `builder` wrapper is unwrapped so
        that the underlying object is updated directly.
      template: The IR data object (or wrapper) to take set fields from. Nothing
        is done if no IR object can be extracted from it.
    """
    if not (template_ir := _extract_ir(template)):
        return

    ir_data_fields.update(
        cast(ir_data_fields.IrDataclassInstance, _unwrap_builder(ir)), template_ir
    )