# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TraceType implementations for common Python types."""

from typing import Any, Hashable, Optional, Sequence, Type
from typing import Dict as PythonDict
from typing import Tuple as PythonTuple
import weakref

from tensorflow.core.function.trace_type import default_types_pb2
from tensorflow.core.function.trace_type import serialization
from tensorflow.python.types import trace


class Literal(trace.TraceType, serialization.Serializable):
  """Represents a Literal type like bool, int or string.

  Two Literals are equal when their wrapped values compare equal with `==`;
  subtyping and common-supertyping both reduce to that equality.
  """

  def __init__(self, value: Any):
    self.value = value
    # Computed once up front; raises TypeError immediately for unhashable
    # values instead of on first use as a dictionary key.
    self._value_hash = hash(value)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Literal"]:
    return self if all(self == other for other in types) else None

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedLiteral]:
    return default_types_pb2.SerializedLiteral

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedLiteral) -> "Literal":
    """Rebuilds a Literal from whichever value field is set in `proto`.

    Raises:
      ValueError: if none of the supported value fields is set.
    """
    if proto.HasField("bool_value"):
      return Literal(proto.bool_value)

    if proto.HasField("int_value"):
      return Literal(proto.int_value)

    if proto.HasField("float_value"):
      return Literal(proto.float_value)

    if proto.HasField("str_value"):
      return Literal(proto.str_value)

    if proto.HasField("none_value"):
      return Literal(None)

    raise ValueError("Malformed Literal proto can not be deserialized")

  def experimental_as_proto(self) -> default_types_pb2.SerializedLiteral:
    """Serializes the wrapped value (bool, int, float, str or None).

    Raises:
      ValueError: if the wrapped value has any other type.
    """
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(self.value, bool):
      return default_types_pb2.SerializedLiteral(bool_value=self.value)

    if isinstance(self.value, int):
      return default_types_pb2.SerializedLiteral(int_value=self.value)

    if isinstance(self.value, float):
      return default_types_pb2.SerializedLiteral(float_value=self.value)

    if isinstance(self.value, str):
      return default_types_pb2.SerializedLiteral(str_value=self.value)

    if self.value is None:
      return default_types_pb2.SerializedLiteral(
          none_value=default_types_pb2.SerializedLiteral.NoneValue())

    raise ValueError("Can not serialize Literal of type " +
                     type(self.value).__name__)

  def _placeholder_value(self) -> Any:
    return self.value

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    return isinstance(other, Literal) and self.value == other.value

  def __hash__(self) -> int:
    return self._value_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(value={self.value!r})"


class Weakref(trace.TraceType):
  """Represents a weakref to an arbitrary Python object.

  When a function argument is a custom class, a weakref is kept instead of a
  copy of the object, so the function cache does not hold on to extra memory.
  """

  def __init__(self, ref: weakref.ReferenceType):
    self._ref = ref
    # Hashed eagerly: hashing a weakref hashes its referent, which is only
    # possible while the referent is still alive.
    self._ref_hash = hash(ref)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Weakref"]:
    return self if all(self == other for other in types) else None

  def _placeholder_value(self) -> Any:
    # Dereferences the weakref; this is None if the referent was collected.
    return self._ref()

  def __eq__(self, other):
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Weakref):
      return False

    # A Weakref whose referent has been collected never equals anything.
    if self._ref() is None or other._ref() is None:
      return False

    if self._ref() is other._ref():
      return True

    # Live weakrefs compare equal exactly when their referents do.
    return self._ref == other._ref

  def __hash__(self):
    return self._ref_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(ref={self._ref!r})"


class Tuple(trace.TraceType, serialization.Serializable):
  """Represents a tuple of TraceType objects."""

  def __init__(self, *components: trace.TraceType):
    self.components = components

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if `other` is a same-length Tuple supertyping every component."""
    if (not isinstance(other, Tuple) or
        len(self.components) != len(other.components)):
      return False

    return all(
        self_component.is_subtype_of(other_component) for self_component,
        other_component in zip(self.components, other.components))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
    """See base class."""
    if not all(
        isinstance(other, Tuple) and
        len(self.components) == len(other.components) for other in others):
      return None

    # Supertype position by position; a single failure fails the whole tuple.
    supertyped_components = []
    for i, component in enumerate(self.components):
      supertyped_component = component.most_specific_common_supertype(
          [other.components[i] for other in others])
      if supertyped_component is None:
        return None
      supertyped_components.append(supertyped_component)

    return Tuple(*supertyped_components)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
    return default_types_pb2.SerializedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
    return Tuple(*[serialization.deserialize(c) for c in proto.components])

  def experimental_as_proto(self) -> default_types_pb2.SerializedTuple:
    return default_types_pb2.SerializedTuple(
        components=[serialization.serialize(c) for c in self.components])

  def _placeholder_value(self) -> Any:
    components = [
        component._placeholder_value()  # pylint: disable=protected-access
        for component in self.components
    ]
    return tuple(components)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Tuple):
      return False

    return self.components == other.components

  def __hash__(self) -> int:
    return hash(self.components)

  def __repr__(self):
    return f"Tuple(components={self.components!r})"


class List(trace.TraceType, serialization.Serializable):
  """Represents a list of TraceType objects.

  The components are stored in a `Tuple` (`components_tuple`), and all the
  comparison logic is delegated to it.
  """

  def __init__(self, *components: trace.TraceType):
    self.components_tuple = Tuple(*components)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    if not isinstance(other, List):
      return False

    return self.components_tuple.is_subtype_of(other.components_tuple)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["List"]:
    """See base class."""
    if not all(isinstance(other, List) for other in others):
      return None

    supertyped_components_tuple = (
        self.components_tuple.most_specific_common_supertype(
            [other.components_tuple for other in others]))

    if supertyped_components_tuple is None:
      return None

    return List(*supertyped_components_tuple.components)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedList]:
    return default_types_pb2.SerializedList

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedList) -> "List":
    return List(
        *Tuple.experimental_from_proto(proto.components_tuple).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedList:
    return default_types_pb2.SerializedList(
        components_tuple=self.components_tuple.experimental_as_proto())

  def _placeholder_value(self) -> Any:
    return list(self.components_tuple._placeholder_value())  # pylint: disable=protected-access

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, List):
      return False

    return self.components_tuple == other.components_tuple

  def __hash__(self) -> int:
    return hash(self.components_tuple)

  def __repr__(self):
    return f"List(components={self.components_tuple.components!r})"


class NamedTuple(trace.TraceType, serialization.Serializable):
  """Represents a NamedTuple of TraceType objects.

  Identity is determined by `type_name`, `attribute_names` and `attributes`.
  `placeholder_type` is only used to build placeholder values, does not take
  part in equality or hashing, and is not serialized.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    self.type_name = type_name
    self.attribute_names = attribute_names
    self.attributes = Tuple(*attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, named_tuple_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "NamedTuple":
    """Builds a NamedTuple from a namedtuple class and its attribute types."""
    return NamedTuple(named_tuple_type.__name__, named_tuple_type._fields,
                      attributes, named_tuple_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    if not isinstance(other, NamedTuple):
      return False

    return (self.type_name == other.type_name and
            self.attribute_names == other.attribute_names and
            self.attributes.is_subtype_of(other.attributes))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
    """See base class."""
    if not all(
        isinstance(other, NamedTuple) and self.type_name == other.type_name and
        self.attribute_names == other.attribute_names for other in others):
      return None

    supertyped_attributes = self.attributes.most_specific_common_supertype(
        [other.attributes for other in others])

    if supertyped_attributes is None:
      return None

    return NamedTuple(self.type_name, self.attribute_names,
                      supertyped_attributes.components, self._placeholder_type)

  @classmethod
  def experimental_type_proto(
      cls) -> Type[default_types_pb2.SerializedNamedTuple]:
    return default_types_pb2.SerializedNamedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedNamedTuple) -> "NamedTuple":
    # placeholder_type is not part of the proto, so it stays None here.
    return NamedTuple(
        proto.type_name, tuple(proto.attribute_names),
        Tuple.experimental_from_proto(proto.attributes).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedNamedTuple:
    return default_types_pb2.SerializedNamedTuple(
        type_name=self.type_name,
        attribute_names=list(self.attribute_names),
        attributes=self.attributes.experimental_as_proto())

  def _placeholder_value(self) -> Any:
    """Instantiates `placeholder_type` with each attribute's placeholder.

    Raises:
      ValueError: if no placeholder_type is known (e.g. after deserialization).
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for NamedTuple with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        attribute._placeholder_value()  # pylint: disable=protected-access
        for attribute in self.attributes.components
    ]
    return self._placeholder_type(*attribute_placeholders)

  def __hash__(self) -> int:
    return hash((self.type_name, self.attribute_names, self.attributes))

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, NamedTuple):
      return False

    return (self.type_name == other.type_name and
            self.attribute_names == other.attribute_names and
            self.attributes == other.attributes)

  def __repr__(self):
    return (f"NamedTuple(type_name={self.type_name}, "
            f"attribute_names={self.attribute_names}, "
            f"attributes={self.attributes.components})")


# Attrs is passed to serialization.register_serializable below and implements
# the Serializable protocol methods, so it declares Serializable as a base
# just like every other registered type in this module.
class Attrs(trace.TraceType, serialization.Serializable):
  """Represents a class annotated by attr.s.

  The type name, attribute names and attribute TraceTypes are held in a
  `NamedTuple` (`named_attributes`), which handles all comparisons.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    self.named_attributes = NamedTuple(type_name, attribute_names, attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, attrs_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "Attrs":
    """Builds an Attrs from an attr.s class and its attribute types."""
    return Attrs(attrs_type.__name__,
                 tuple(attr.name for attr in attrs_type.__attrs_attrs__),
                 attributes, attrs_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    if not isinstance(other, Attrs):
      return False

    return self.named_attributes.is_subtype_of(other.named_attributes)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Attrs"]:
    """See base class."""
    if not all(isinstance(other, Attrs) for other in others):
      return None

    supertyped_attributes = self.named_attributes.most_specific_common_supertype(
        [other.named_attributes for other in others])

    if supertyped_attributes is None:
      return None

    return Attrs(self.named_attributes.type_name,
                 self.named_attributes.attribute_names,
                 supertyped_attributes.attributes.components,
                 self._placeholder_type)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedAttrs]:
    return default_types_pb2.SerializedAttrs

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedAttrs) -> "Attrs":
    # placeholder_type is not part of the proto, so it stays None here.
    return Attrs(
        proto.named_attributes.type_name,
        tuple(proto.named_attributes.attribute_names),
        Tuple.experimental_from_proto(
            proto.named_attributes.attributes).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedAttrs:
    return default_types_pb2.SerializedAttrs(
        named_attributes=self.named_attributes.experimental_as_proto())

  def _placeholder_value(self) -> Any:
    """Instantiates `placeholder_type` with each attribute's placeholder.

    Raises:
      ValueError: if no placeholder_type is known (e.g. after deserialization).
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for Attrs with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        attribute._placeholder_value()  # pylint: disable=protected-access
        for attribute in self.named_attributes.attributes.components
    ]
    return self._placeholder_type(*attribute_placeholders)

  def __hash__(self) -> int:
    return hash(self.named_attributes)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Attrs):
      return False

    return self.named_attributes == other.named_attributes

  def __repr__(self):
    return (f"Attrs(type_name={self.named_attributes.type_name}, "
            f"attribute_names={self.named_attributes.attribute_names}, "
            f"attributes={self.named_attributes.attributes.components})")


class Dict(trace.TraceType, serialization.Serializable):
  """Represents a dictionary of TraceType objects.

  Attributes:
    mapping: A mapping from keys to corresponding TraceTypes of the dict values.
  """

  def __init__(self, mapping: PythonDict[Hashable, trace.TraceType]):
    self.mapping = mapping

  def _has_same_structure(self, other):
    """True if `other` is a Dict with exactly the same set of keys."""
    if not isinstance(other, Dict):
      return False

    return self.mapping.keys() == other.mapping.keys()

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """See base class."""
    if not self._has_same_structure(other):
      return False

    # We need all keys to be present because there can be logic relying on
    # their existence or lack thereof and hence can not guarantee subtype based
    # on a subset or superset of keys.
    # Only the tracing code can explicitly check for key dependencies and inform
    # that decision.
    return all(self.mapping[key].is_subtype_of(other.mapping[key])
               for key in self.mapping)

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
    """See base class."""
    if not all(self._has_same_structure(other) for other in types):
      return None

    # Supertype key by key; a single failure fails the whole dict.
    new_mapping = {}
    for key in self.mapping.keys():
      common = self.mapping[key].most_specific_common_supertype(
          [other.mapping[key] for other in types])
      if common is None:
        return None
      new_mapping[key] = common

    return Dict(new_mapping)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedDict]:
    return default_types_pb2.SerializedDict

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedDict) -> "Dict":
    # Keys are stored as Literal protos and values as serialized TraceTypes,
    # in parallel repeated fields.
    return Dict({
        Literal.experimental_from_proto(k).value: serialization.deserialize(v)
        for k, v in zip(proto.keys, proto.values)
    })

  def experimental_as_proto(self) -> default_types_pb2.SerializedDict:
    # Keys go through Literal, so only Literal-serializable keys
    # (bool/int/float/str/None) are supported.
    return default_types_pb2.SerializedDict(
        keys=[Literal(k).experimental_as_proto() for k in self.mapping.keys()],
        values=[serialization.serialize(v) for v in self.mapping.values()])

  def _placeholder_value(self) -> Any:
    return {
        key: value._placeholder_value()  # pylint: disable=protected-access
        for key, value in self.mapping.items()
    }

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Dict):
      return False

    return self.mapping == other.mapping

  def __hash__(self) -> int:
    # Only keys are hashed; this stays consistent with __eq__ because equal
    # mappings always have equal key sets.
    return hash(frozenset(self.mapping.keys()))

  def __repr__(self):
    return f"{self.__class__.__name__}(mapping={self.mapping!r})"


class Reference(trace.TraceType, serialization.Serializable):
  """Represents a resource with an identifier.

  Resource identifiers are useful to denote identical resources, that is,
  resources which are known at compilation time to point to the same thing.
  This information is useful in automatic control dependencies for instance,
  where ops using the same resource don't run concurrently.
  """

  def __init__(self, base: trace.TraceType, identifier: Hashable):
    self.base = base
    self.identifier = identifier

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if `other` has the same identifier and a supertype base."""
    if isinstance(other, Reference) and self.identifier == other.identifier:
      return self.base.is_subtype_of(other.base)
    return False

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Reference"]:
    """Supertypes the bases when every Reference shares this identifier."""
    if all(
        isinstance(other, Reference) and self.identifier == other.identifier
        for other in types):
      base_supertype = self.base.most_specific_common_supertype(
          [other.base for other in types])
      if base_supertype is not None:
        return Reference(base_supertype, self.identifier)
    return None

  @classmethod
  def experimental_type_proto(
      cls) -> Type[default_types_pb2.SerializedReference]:
    return default_types_pb2.SerializedReference

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedReference) -> "Reference":
    return Reference(
        serialization.deserialize(proto.base),
        Literal.experimental_from_proto(proto.identifier).value)

  def experimental_as_proto(self) -> default_types_pb2.SerializedReference:
    # The identifier is serialized as a Literal, so it must be a
    # bool/int/float/str/None value.
    return default_types_pb2.SerializedReference(
        identifier=Literal(self.identifier).experimental_as_proto(),
        base=serialization.serialize(self.base))

  def _placeholder_value(self) -> Any:
    return self.base._placeholder_value()  # pylint: disable=protected-access

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    return isinstance(
        other, Reference
    ) and self.identifier == other.identifier and self.base == other.base

  def __hash__(self) -> int:
    return hash((self.identifier, self.base))

  def __repr__(self):
    return (f"{self.__class__.__name__}(base={self.base!r}, "
            f"identifier={self.identifier!r})")


serialization.register_serializable(Literal)
serialization.register_serializable(Tuple)
serialization.register_serializable(List)
serialization.register_serializable(NamedTuple)
serialization.register_serializable(Attrs)
serialization.register_serializable(Dict)
serialization.register_serializable(Reference)