1# Copyright 2024 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Provides helpers for defining and working with the fields of IR data classes. 16 17Field specs 18----------- 19Various utilities are provided for working with `dataclasses.Field`s: 20 - Classes: 21 - `FieldSpec` - used to track data about an IR data field 22 - `FieldContainer` - used to track the field container type 23 - Functions that work with `FieldSpec`s: 24 - `make_field_spec`, `build_default` 25 - Functions for retrieving a set of `FieldSpec`s for a given class: 26 - `field_specs` 27 - Functions for retrieving fields and their values: 28 - `fields_and_values` 29 - Functions for copying and updating IR data classes 30 - `copy`, `update` 31 - Functions to help defining IR data fields 32 - `oneof_field`, `list_field`, `str_field` 33""" 34 35import dataclasses 36import enum 37import sys 38from typing import ( 39 Any, 40 Callable, 41 ClassVar, 42 ForwardRef, 43 Iterable, 44 Mapping, 45 MutableMapping, 46 NamedTuple, 47 Optional, 48 Protocol, 49 SupportsIndex, 50 Tuple, 51 TypeVar, 52 Union, 53 cast, 54 get_args, 55 get_origin, 56) 57 58 59class IrDataclassInstance(Protocol): 60 """Type bound for an IR dataclass instance.""" 61 62 __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] 63 IR_DATACLASS: ClassVar[object] 64 field_specs: ClassVar["FilteredIrFieldSpecs"] 65 66 67IrDataT = TypeVar("IrDataT", bound=IrDataclassInstance) 68 69CopyValuesListT = 
TypeVar("CopyValuesListT", bound=type) 70 71_IR_DATACLASSS_ATTR = "IR_DATACLASS" 72 73 74def _is_ir_dataclass(obj): 75 return hasattr(obj, _IR_DATACLASSS_ATTR) 76 77 78class CopyValuesList(list[CopyValuesListT]): 79 """A list that makes copies of any value that is inserted""" 80 81 def __init__( 82 self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None 83 ): 84 if iterable: 85 super().__init__(iterable) 86 else: 87 super().__init__() 88 self.value_type = value_type 89 90 def _copy(self, obj: Any): 91 if _is_ir_dataclass(obj): 92 return copy(obj) 93 return self.value_type(obj) 94 95 def extend(self, iterable: Iterable) -> None: 96 return super().extend([self._copy(i) for i in iterable]) 97 98 def shallow_copy(self, iterable: Iterable) -> None: 99 """Explicitly performs a shallow copy of the provided list""" 100 return super().extend(iterable) 101 102 def append(self, obj: Any) -> None: 103 return super().append(self._copy(obj)) 104 105 def insert(self, index: SupportsIndex, obj: Any) -> None: 106 return super().insert(index, self._copy(obj)) 107 108 109class TemporaryCopyValuesList(NamedTuple): 110 """Class used to temporarily hold a CopyValuesList while copying and 111 constructing an IR dataclass. 112 """ 113 114 temp_list: CopyValuesList 115 116 117class FieldContainer(enum.Enum): 118 """Indicates a fields container type""" 119 120 NONE = 0 121 OPTIONAL = 1 122 LIST = 2 123 124 125class FieldSpec(NamedTuple): 126 """Indicates the container and type of a field. 127 128 `FieldSpec` objects are accessed millions of times during runs so we cache as 129 many operations as possible. 130 - `is_dataclass`: `dataclasses.is_dataclass(data_type)` 131 - `is_sequence`: `container is FieldContainer.LIST` 132 - `is_enum`: `issubclass(data_type, enum.Enum)` 133 - `is_oneof`: `oneof is not None` 134 135 Use `make_field_spec` to automatically fill in the cached operations. 
136 """ 137 138 name: str 139 data_type: type 140 container: FieldContainer 141 oneof: Optional[str] 142 is_dataclass: bool 143 is_sequence: bool 144 is_enum: bool 145 is_oneof: bool 146 147 148def make_field_spec( 149 name: str, data_type: type, container: FieldContainer, oneof: Optional[str] 150): 151 """Builds a field spec with cached type queries.""" 152 return FieldSpec( 153 name, 154 data_type, 155 container, 156 oneof, 157 is_dataclass=_is_ir_dataclass(data_type), 158 is_sequence=container is FieldContainer.LIST, 159 is_enum=issubclass(data_type, enum.Enum), 160 is_oneof=oneof is not None, 161 ) 162 163 164def build_default(field_spec: FieldSpec): 165 """Builds a default instance of the given field""" 166 if field_spec.is_sequence: 167 return CopyValuesList(field_spec.data_type) 168 if field_spec.is_enum: 169 return field_spec.data_type(int()) 170 return field_spec.data_type() 171 172 173class FilteredIrFieldSpecs: 174 """Provides cached views of an IR dataclass' fields.""" 175 176 def __init__(self, specs: Mapping[str, FieldSpec]): 177 self.all_field_specs = specs 178 self.field_specs = tuple(specs.values()) 179 self.dataclass_field_specs = { 180 k: v for k, v in specs.items() if v.is_dataclass 181 } 182 self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof} 183 self.sequence_field_specs = tuple( 184 v for v in specs.values() if v.is_sequence 185 ) 186 self.oneof_mappings = tuple( 187 (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof 188 ) 189 190 191def all_ir_classes(mod): 192 """Retrieves a list of all IR dataclass definitions in the given module.""" 193 return ( 194 v 195 for v in mod.__dict__.values() 196 if isinstance(type, v.__class__) and _is_ir_dataclass(v) 197 ) 198 199 200class IrDataclassSpecs: 201 """Maintains a cache of all IR dataclass specs.""" 202 203 spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {} 204 205 @classmethod 206 def get_mod_specs(cls, mod): 207 """Gets the IR dataclass specs for the 
given module.""" 208 return { 209 ir_class: FilteredIrFieldSpecs(_field_specs(ir_class)) 210 for ir_class in all_ir_classes(mod) 211 } 212 213 @classmethod 214 def get_specs(cls, data_class): 215 """Gets the field specs for the given class. The specs will be cached.""" 216 if data_class not in cls.spec_cache: 217 mod = sys.modules[data_class.__module__] 218 cls.spec_cache.update(cls.get_mod_specs(mod)) 219 return cls.spec_cache[data_class] 220 221 222def cache_message_specs(mod, cls): 223 """Adds a cached `field_specs` attribute to IR dataclasses in `mod` 224 excluding the given base `cls`. 225 226 This needs to be done after the dataclass decorators run and create the 227 wrapped classes. 228 """ 229 for data_class in all_ir_classes(mod): 230 if data_class is not cls: 231 data_class.field_specs = IrDataclassSpecs.get_specs(data_class) 232 233 234def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]: 235 """Gets the IR data field names and types for the given IR data class""" 236 # Get the dataclass fields 237 class_fields = dataclasses.fields(cast(Any, cls)) 238 239 # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute 240 # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`. 241 # Instead we manually subsitute the type by extracting the list of classes 242 # from the class' module and manually substituting. 
243 mod_ns = { 244 k: v 245 for k, v in sys.modules[cls.__module__].__dict__.items() 246 if isinstance(type, v.__class__) 247 } 248 249 # Now extract the concrete type out of optionals 250 result: MutableMapping[str, FieldSpec] = {} 251 for class_field in class_fields: 252 if class_field.name.startswith("_"): 253 continue 254 container_type = FieldContainer.NONE 255 type_hint = class_field.type 256 oneof = class_field.metadata.get("oneof") 257 258 # Check if this type is wrapped 259 origin = get_origin(type_hint) 260 # Get the wrapped types if there are any 261 args = get_args(type_hint) 262 if origin is not None: 263 # Extract the type. 264 type_hint = args[0] 265 266 # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we 267 # have to check if it's a `Union` instead of just using `Optional`. 268 if origin == Union: 269 # Make sure this is an `Optional` and not another `Union` type. 270 assert len(args) == 2 and args[1] == type(None) 271 container_type = FieldContainer.OPTIONAL 272 elif origin == list: 273 container_type = FieldContainer.LIST 274 else: 275 raise TypeError(f"Field has invalid container type: {origin}") 276 277 # Resolve any forward references. 278 if isinstance(type_hint, str): 279 type_hint = mod_ns[type_hint] 280 if isinstance(type_hint, ForwardRef): 281 type_hint = mod_ns[type_hint.__forward_arg__] 282 283 result[class_field.name] = make_field_spec( 284 class_field.name, type_hint, container_type, oneof 285 ) 286 287 return result 288 289 290def field_specs(obj: Union[IrDataT, type[IrDataT]]) -> Mapping[str, FieldSpec]: 291 """Retrieves the fields specs for the the give data type. 292 293 The results of this method are cached to reduce lookup overhead. 
294 """ 295 cls = obj if isinstance(obj, type) else type(obj) 296 if cls is type(None): 297 raise TypeError("field_specs called with invalid type: NoneType") 298 return IrDataclassSpecs.get_specs(cls).all_field_specs 299 300 301def fields_and_values( 302 ir: IrDataT, 303 value_filt: Optional[Callable[[Any], bool]] = None, 304): 305 """Retrieves the fields and their values for a given IR data class. 306 307 Args: 308 ir: The IR data class or a read-only wrapper of an IR data class. 309 value_filt: Optional filter used to exclude values. 310 """ 311 set_fields: list[Tuple[FieldSpec, Any]] = [] 312 specs: FilteredIrFieldSpecs = ir.field_specs 313 for spec in specs.field_specs: 314 value = getattr(ir, spec.name) 315 if not value_filt or value_filt(value): 316 set_fields.append((spec, value)) 317 return set_fields 318 319 320# `copy` is one of the hottest paths of embossc. We've taken steps to 321# optimize this path at the expense of code clarity and modularization. 322# 323# 1. `FilteredFieldSpecs` are cached on in the class definition for IR 324# dataclasses under the `ir_data.Message.field_specs` class attribute. We 325# just assume the passed in object has that attribute. 326# 2. We cache a `field_specs` entry that is just the `values()` of the 327# `all_field_specs` dict. 328# 3. Copied lists are wrapped in a `TemporaryCopyValuesList`. This is used to 329# signal to consumers that they can take ownership of the contained list 330# rather than copying it again. See `ir_data.Message()` and `udpate()` for 331# where this is used. 332# 4. `FieldSpec` checks are cached including `is_dataclass` and `is_sequence`. 333# 5. None checks are only done in `copy()`, `_copy_set_fields` only 334# references `_copy()` to avoid this step. 
def _copy_set_fields(ir: IrDataT):
    """Returns a `{field name: copied value}` mapping of `ir`'s non-None fields.

    Nested IR dataclasses are copied with `_copy`. List fields are copied into
    a fresh `CopyValuesList` and wrapped in a `TemporaryCopyValuesList`, which
    tells the consumer it may take ownership of the new list. Elements of such
    lists are copied with `_copy` only when the element type is an IR
    dataclass. Any other value is passed through unchanged.
    """
    values: MutableMapping[str, Any] = {}

    specs: FilteredIrFieldSpecs = ir.field_specs
    for spec in specs.field_specs:
        value = getattr(ir, spec.name)
        if value is not None:
            if spec.is_sequence:
                if spec.is_dataclass:
                    copy_value = CopyValuesList(
                        spec.data_type, (_copy(v) for v in value)
                    )
                else:
                    # Elements that are not IR dataclasses are stored as-is.
                    copy_value = CopyValuesList(spec.data_type, value)
                value = TemporaryCopyValuesList(copy_value)
            elif spec.is_dataclass:
                value = _copy(value)
            values[spec.name] = value
    return values


def _copy(ir: IrDataT) -> IrDataT:
    """Copies `ir` by rebuilding its type from copies of its set fields."""
    return type(ir)(**_copy_set_fields(ir))  # type: ignore[misc]


def copy(ir: IrDataT) -> Optional[IrDataT]:
    """Creates a copy of the given IR data class.

    Returns:
      A copy of `ir`, or `None` when `ir` is falsy (e.g. `None`).
    """
    if not ir:
        return None
    return _copy(ir)


def update(ir: IrDataT, template: IrDataT):
    """Updates `ir`'s fields with copies of all set fields in `template`.

    Fields that are `None` in `template` are left untouched on `ir`. List
    fields are replaced with the freshly copied list taken out of its
    `TemporaryCopyValuesList` wrapper.
    """
    for k, v in _copy_set_fields(template).items():
        if isinstance(v, TemporaryCopyValuesList):
            v = v.temp_list
        setattr(ir, k, v)


class OneOfField:
    """Descriptor for a "oneof" field.

    Tracks when the field is set and will unset the other fields in the
    associated oneof group. The value itself is stored on the instance under a
    proxy attribute named `_<field name>`.

    Note: This descriptor only works if dataclass slots aren't used.
    """

    def __init__(self, oneof: str) -> None:
        super().__init__()
        # Name of the oneof group this field belongs to.
        self.oneof = oneof
        # Class that declared the field; set by `__set_name__`.
        self.owner_type = None
        # Instance attribute that actually holds the value.
        self.proxy_name: str = ""
        # Public field name; set by `__set_name__`.
        self.name: str = ""

    def __set_name__(self, owner, name):
        self.name = name
        self.proxy_name = f"_{name}"
        self.owner_type = owner
        # Add our empty proxy field to the class.
        setattr(owner, self.proxy_name, None)

    def __get__(self, obj, objtype=None):
        # NOTE(review): class-level access (obj is None) raises AttributeError
        # because `None` has no proxy attribute.
        return getattr(obj, self.proxy_name)

    def __set__(self, obj, value):
        if value is self:
            # This will happen if the dataclass uses the default value, we just
            # default to None.
            value = None

        if value is not None:
            # Clear the other fields in the same oneof group.
            for name, oneof in IrDataclassSpecs.get_specs(
                self.owner_type
            ).oneof_mappings:
                if oneof == self.oneof and name != self.name:
                    setattr(obj, name, None)

        setattr(obj, self.proxy_name, value)


def oneof_field(name: str):
    """Alternative for `dataclasses.field` that sets up a oneof variable.

    Args:
      name: The name of the oneof group the field belongs to.
    """
    return dataclasses.field(  # pylint:disable=invalid-field-call
        default=OneOfField(name), metadata={"oneof": name}, init=True
    )


def str_field():
    """Helper used to define a str field that defaults to the empty string."""
    return dataclasses.field(default_factory=str)  # pylint:disable=invalid-field-call


def list_field(cls_or_fn):
    """Helper used to define a list field that defaults to an empty list.

    A lambda can be used to defer resolution of a field type that references its
    container type, for example:
    ```
    class Foo:
        subtypes: list['Foo'] = list_field(lambda: Foo)
        names: list[str] = list_field(str)
    ```

    Args:
      cls_or_fn: The class type or a function that resolves to the class type.

    Returns:
      A `dataclasses.field` whose default factory builds an empty
      `CopyValuesList` of the resolved element type.
    """

    def list_factory(c):
        # Resolve the element type lazily, at instance-construction time.
        return CopyValuesList(c if isinstance(c, type) else c())

    return dataclasses.field(  # pylint:disable=invalid-field-call
        default_factory=lambda: list_factory(cls_or_fn)
    )