xref: /aosp_15_r20/external/tensorflow/tensorflow/core/function/trace_type/default_types.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TraceType implementations for common Python types."""
16
17from typing import Any, Hashable, Optional, Sequence, Type
18from typing import Dict as PythonDict
19from typing import Tuple as PythonTuple
20import weakref
21
22from tensorflow.core.function.trace_type import default_types_pb2
23from tensorflow.core.function.trace_type import serialization
24from tensorflow.python.types import trace
25
26
class Literal(trace.TraceType, serialization.Serializable):
  """Represents a Literal type like bool, int or string.

  Equality and hashing follow the wrapped value, so two Literals are equal
  when their values compare equal with `==`. The one exception is float NaN:
  `nan == nan` is False (and on Python 3.10+ `hash(nan)` depends on object
  identity), which would make `Literal(nan)` unequal to an identical Literal
  and even to itself. All NaN Literals are therefore treated as equal and
  share a single hash.

  NOTE(review): because equality defers to `==`, `Literal(1)`, `Literal(1.0)`
  and `Literal(True)` are all equal even though `experimental_as_proto`
  serializes them to different proto fields. Confirm this is intended.
  """

  # Single hash shared by every NaN-valued Literal (see class docstring).
  _NAN_HASH = hash("Literal(nan)")

  def __init__(self, value: Any):
    self.value = value
    self._value_hash = (
        Literal._NAN_HASH if Literal._is_nan(value) else hash(value))

  @staticmethod
  def _is_nan(value: Any) -> bool:
    """Returns True iff `value` is a float NaN."""
    # NaN is the only float value that compares unequal to itself.
    return isinstance(value, float) and value != value  # pylint: disable=comparison-with-itself

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """A Literal is a subtype only of an equal Literal."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Literal"]:
    """Returns self if every entry of `types` equals it, otherwise None."""
    return self if all(self == other for other in types) else None

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedLiteral]:
    return default_types_pb2.SerializedLiteral

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedLiteral) -> "Literal":
    """Rebuilds a Literal from whichever value field of `proto` is set.

    Raises:
      ValueError: if none of the known value fields is set.
    """
    if proto.HasField("bool_value"):
      return Literal(proto.bool_value)

    if proto.HasField("int_value"):
      return Literal(proto.int_value)

    if proto.HasField("float_value"):
      return Literal(proto.float_value)

    if proto.HasField("str_value"):
      return Literal(proto.str_value)

    if proto.HasField("none_value"):
      return Literal(None)

    raise ValueError("Malformed Literal proto can not be deserialized")

  def experimental_as_proto(self) -> default_types_pb2.SerializedLiteral:
    """Serializes the wrapped value into the matching proto field.

    Raises:
      ValueError: if the value is not a bool, int, float, str or None.
    """
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(self.value, bool):
      return default_types_pb2.SerializedLiteral(bool_value=self.value)

    if isinstance(self.value, int):
      return default_types_pb2.SerializedLiteral(int_value=self.value)

    if isinstance(self.value, float):
      return default_types_pb2.SerializedLiteral(float_value=self.value)

    if isinstance(self.value, str):
      return default_types_pb2.SerializedLiteral(str_value=self.value)

    if self.value is None:
      return default_types_pb2.SerializedLiteral(
          none_value=default_types_pb2.SerializedLiteral.NoneValue())

    raise ValueError("Can not serialize Literal of type " +
                     type(self.value).__name__)

  def _placeholder_value(self) -> Any:
    return self.value

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Literal):
      return False

    # `nan == nan` is False; treat NaN Literals as equal so that equality is
    # reflexive and lookups keyed on NaN-valued Literals can succeed.
    if self._is_nan(self.value) and self._is_nan(other.value):
      return True

    return self.value == other.value

  def __hash__(self) -> int:
    return self._value_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(value={self.value!r})"
99
100
class Weakref(trace.TraceType):
  """TraceType that holds a weak reference to an arbitrary Python object.

  Custom-class function arguments are tracked through a weakref rather than a
  copy of the object, so the function cache does not need its own copy and
  memory is saved.
  """

  def __init__(self, ref: weakref.ReferenceType):
    self._ref = ref
    # Hash computed once, at construction time.
    self._ref_hash = hash(ref)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """A Weakref is a subtype only of an equal Weakref."""
    return self == other

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Weakref"]:
    """Returns self when all of `types` equal it; None on the first mismatch."""
    for candidate in types:
      if self == candidate:
        continue
      return None
    return self

  def _placeholder_value(self) -> Any:
    # Dereference; yields None once the referent has been collected.
    return self._ref()

  def __eq__(self, other):
    if not isinstance(other, trace.TraceType):
      return NotImplemented
    if not isinstance(other, Weakref):
      return False

    mine = self._ref()
    theirs = other._ref()
    # A dead reference never compares equal, not even to itself.
    if mine is None or theirs is None:
      return False
    # Same referent object: equal without consulting the referent's __eq__.
    if mine is theirs:
      return True
    return self._ref == other._ref

  def __hash__(self):
    return self._ref_hash

  def __repr__(self):
    return f"{self.__class__.__name__}(ref={self._ref!r})"
142
143
class Tuple(trace.TraceType, serialization.Serializable):
  """TraceType for a Python tuple: one TraceType per element, in order."""

  def __init__(self, *components: trace.TraceType):
    self.components = components

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Elementwise subtype check against a Tuple of the same length."""
    if not isinstance(other, Tuple):
      return False
    if len(self.components) != len(other.components):
      return False
    for mine, theirs in zip(self.components, other.components):
      if not mine.is_subtype_of(theirs):
        return False
    return True

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Tuple"]:
    """See base class."""
    width = len(self.components)
    for other in others:
      if not isinstance(other, Tuple) or len(other.components) != width:
        return None

    merged = []
    for index, component in enumerate(self.components):
      peers = [other.components[index] for other in others]
      supertype = component.most_specific_common_supertype(peers)
      if supertype is None:
        return None
      merged.append(supertype)
    return Tuple(*merged)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedTuple]:
    return default_types_pb2.SerializedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedTuple) -> "Tuple":
    """Deserializes each component proto in order and wraps them in a Tuple."""
    decoded = [serialization.deserialize(item) for item in proto.components]
    return Tuple(*decoded)

  def experimental_as_proto(self) -> default_types_pb2.SerializedTuple:
    """Serializes every component, preserving order."""
    encoded = [serialization.serialize(item) for item in self.components]
    return default_types_pb2.SerializedTuple(components=encoded)

  def _placeholder_value(self) -> Any:
    # Placeholder for each component, assembled back into a real tuple.
    return tuple(
        item._placeholder_value()  # pylint: disable=protected-access
        for item in self.components)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented
    if not isinstance(other, Tuple):
      return False
    return self.components == other.components

  def __hash__(self) -> int:
    return hash(self.components)

  def __repr__(self):
    return f"Tuple(components={self.components!r})"
211
212
class List(trace.TraceType, serialization.Serializable):
  """Represents a list of TraceType objects.

  The components are held in a `Tuple` TraceType (`components_tuple`).
  Subtyping, equality, hashing and serialization all delegate to it; being a
  `List` rather than a `Tuple` is what distinguishes the two types.
  """

  def __init__(self, *components: trace.TraceType):
    self.components_tuple = Tuple(*components)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """True if `other` is a List whose components this List's subtype."""
    if not isinstance(other, List):
      return False

    return self.components_tuple.is_subtype_of(other.components_tuple)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["List"]:
    """See base class.

    Returns:
      A List of the component-wise common supertypes, or None if any entry of
      `others` is not a List or the underlying Tuple merge fails (different
      length, or some component has no common supertype).
    """
    if not all(isinstance(other, List) for other in others):
      return None

    supertyped_components_tuple = (
        self.components_tuple.most_specific_common_supertype(
            [other.components_tuple for other in others]))

    if supertyped_components_tuple is None:
      return None

    return List(*supertyped_components_tuple.components)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedList]:
    return default_types_pb2.SerializedList

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedList) -> "List":
    """Rebuilds a List from its serialized components Tuple."""
    return List(
        *Tuple.experimental_from_proto(proto.components_tuple).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedList:
    return default_types_pb2.SerializedList(
        components_tuple=self.components_tuple.experimental_as_proto())

  def _placeholder_value(self) -> Any:
    # The Tuple yields a tuple of placeholders; convert it to a real list.
    return list(self.components_tuple._placeholder_value())  # pylint: disable=protected-access

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, List):
      return False

    return self.components_tuple == other.components_tuple

  def __hash__(self) -> int:
    return hash(self.components_tuple)

  def __repr__(self):
    return f"List(components={self.components_tuple.components!r})"
270
271
class NamedTuple(trace.TraceType, serialization.Serializable):
  """TraceType for a named tuple: type name, field names, field TraceTypes.

  `placeholder_type` is the concrete Python class used to build placeholder
  values; it is not serialized and is therefore None after deserialization.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    self.type_name = type_name
    self.attribute_names = attribute_names
    self.attributes = Tuple(*attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, named_tuple_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "NamedTuple":
    """Builds a NamedTuple from the class's `__name__` and `_fields`."""
    return NamedTuple(
        type_name=named_tuple_type.__name__,
        attribute_names=named_tuple_type._fields,  # pylint: disable=protected-access
        attributes=attributes,
        placeholder_type=named_tuple_type)

  def _matches_names(self, other: Any) -> bool:
    """True if `other` is a NamedTuple with the same type and field names."""
    return (isinstance(other, NamedTuple) and
            self.type_name == other.type_name and
            self.attribute_names == other.attribute_names)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    return (self._matches_names(other) and
            self.attributes.is_subtype_of(other.attributes))

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["NamedTuple"]:
    """See base class."""
    if not all(self._matches_names(other) for other in others):
      return None

    merged = self.attributes.most_specific_common_supertype(
        [other.attributes for other in others])
    if merged is None:
      return None

    return NamedTuple(self.type_name, self.attribute_names, merged.components,
                      self._placeholder_type)

  @classmethod
  def experimental_type_proto(
      cls) -> Type[default_types_pb2.SerializedNamedTuple]:
    return default_types_pb2.SerializedNamedTuple

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedNamedTuple) -> "NamedTuple":
    """Deserializes a NamedTuple; the placeholder type is not recoverable."""
    type_name = proto.type_name
    field_names = tuple(proto.attribute_names)
    field_types = Tuple.experimental_from_proto(proto.attributes).components
    return NamedTuple(type_name, field_names, field_types)

  def experimental_as_proto(self) -> default_types_pb2.SerializedNamedTuple:
    return default_types_pb2.SerializedNamedTuple(
        type_name=self.type_name,
        attribute_names=list(self.attribute_names),
        attributes=self.attributes.experimental_as_proto())

  def _placeholder_value(self) -> Any:
    """Instantiates `placeholder_type` with each field's placeholder value.

    Raises:
      ValueError: if no placeholder type is known (e.g. after deserializing).
    """
    if self._placeholder_type is None:
      # Tracing never happens after deserialization, so this is acceptable;
      # a placeholder class could be synthesized from the names if needed.
      raise ValueError("Can not generate placeholder value for NamedTuple with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    field_values = [
        field._placeholder_value()  # pylint: disable=protected-access
        for field in self.attributes.components
    ]
    return self._placeholder_type(*field_values)

  def __hash__(self) -> int:
    return hash((self.type_name, self.attribute_names, self.attributes))

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented
    return self._matches_names(other) and self.attributes == other.attributes

  def __repr__(self):
    return (f"NamedTuple(type_name={self.type_name}, "
            f"attribute_names={self.attribute_names}, "
            f"attributes={self.attributes.components})")
366
367
class Attrs(trace.TraceType, serialization.Serializable):
  """Represents a class annotated by attr.s.

  The type name, attribute names and attribute TraceTypes are stored in an
  inner `NamedTuple` (`named_attributes`), to which comparisons, hashing and
  serialization delegate.

  Attrs implements the same proto methods as the other serializable types in
  this file and is registered via `serialization.register_serializable`, so it
  also inherits `serialization.Serializable`. That lets instance checks
  against `Serializable` accept it, e.g. when an Attrs is nested inside a
  serialized Tuple.
  """

  def __init__(self,
               type_name: str,
               attribute_names: PythonTuple[str],
               attributes: PythonTuple[trace.TraceType],
               placeholder_type: Optional[Type[Any]] = None):
    self.named_attributes = NamedTuple(type_name, attribute_names, attributes)
    self._placeholder_type = placeholder_type

  @classmethod
  def from_type_and_attributes(
      cls, attrs_type: Any,
      attributes: PythonTuple[trace.TraceType]) -> "Attrs":
    """Builds an Attrs from the class name and its `__attrs_attrs__` names."""
    return Attrs(attrs_type.__name__,
                 tuple(attr.name for attr in attrs_type.__attrs_attrs__),
                 attributes, attrs_type)

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    if not isinstance(other, Attrs):
      return False

    return self.named_attributes.is_subtype_of(other.named_attributes)

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["Attrs"]:
    """See base class."""
    if not all(isinstance(other, Attrs) for other in others):
      return None

    supertyped_attributes = (
        self.named_attributes.most_specific_common_supertype(
            [other.named_attributes for other in others]))

    if supertyped_attributes is None:
      return None

    return Attrs(self.named_attributes.type_name,
                 self.named_attributes.attribute_names,
                 supertyped_attributes.attributes.components,
                 self._placeholder_type)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedAttrs]:
    return default_types_pb2.SerializedAttrs

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedAttrs) -> "Attrs":
    """Deserializes an Attrs; the placeholder type is not recoverable."""
    return Attrs(
        proto.named_attributes.type_name,
        tuple(proto.named_attributes.attribute_names),
        Tuple.experimental_from_proto(
            proto.named_attributes.attributes).components)

  def experimental_as_proto(self) -> default_types_pb2.SerializedAttrs:
    return default_types_pb2.SerializedAttrs(
        named_attributes=self.named_attributes.experimental_as_proto())

  def _placeholder_value(self) -> Any:
    """Instantiates `placeholder_type` with each attribute's placeholder.

    Raises:
      ValueError: if no placeholder type is known (e.g. after deserializing).
    """
    if self._placeholder_type is None:
      # We don't need to trace after serialization so it is not needed but we
      # can generate a placeholder type using the description if ever needed.
      raise ValueError("Can not generate placeholder value for Attrs with"
                       " unspecified placeholder_type. Note: placeholder_type "
                       "is lost during serialization.")
    attribute_placeholders = [
        attribute._placeholder_value()  # pylint: disable=protected-access
        for attribute in self.named_attributes.attributes.components
    ]
    return self._placeholder_type(*attribute_placeholders)

  def __hash__(self) -> int:
    return hash(self.named_attributes)

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented

    if not isinstance(other, Attrs):
      return False

    return self.named_attributes == other.named_attributes

  def __repr__(self):
    return (f"Attrs(type_name={self.named_attributes.type_name}, "
            f"attribute_names={self.named_attributes.attribute_names}, "
            f"attributes={self.named_attributes.attributes.components})")
456
457
class Dict(trace.TraceType, serialization.Serializable):
  """TraceType for a dict, keyed by the original keys.

  Attributes:
    mapping: Maps each dict key to the TraceType of its value.
  """

  def __init__(self, mapping: PythonDict[Hashable, trace.TraceType]):
    self.mapping = mapping

  def _has_same_structure(self, other):
    """True if `other` is a Dict with exactly the same key set."""
    return isinstance(other, Dict) and (
        self.mapping.keys() == other.mapping.keys())

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """See base class."""
    # The key sets must match exactly: traced code may branch on whether a
    # key is present, so a subset or superset of keys cannot be assumed to be
    # compatible. Only the tracing code itself could decide that.
    if not self._has_same_structure(other):
      return False
    return all(value.is_subtype_of(other.mapping[key])
               for key, value in self.mapping.items())

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Dict"]:
    """See base class."""
    if not all(self._has_same_structure(other) for other in types):
      return None

    merged = {}
    for key, value in self.mapping.items():
      supertype = value.most_specific_common_supertype(
          [other.mapping[key] for other in types])
      if supertype is None:
        return None
      merged[key] = supertype
    return Dict(merged)

  @classmethod
  def experimental_type_proto(cls) -> Type[default_types_pb2.SerializedDict]:
    return default_types_pb2.SerializedDict

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedDict) -> "Dict":
    """Pairs the parallel `keys` / `values` proto lists back into a Dict."""
    mapping = {}
    for key_proto, value_proto in zip(proto.keys, proto.values):
      key = Literal.experimental_from_proto(key_proto).value
      mapping[key] = serialization.deserialize(value_proto)
    return Dict(mapping)

  def experimental_as_proto(self) -> default_types_pb2.SerializedDict:
    """Serializes keys as Literals and values as TraceTypes, in dict order."""
    key_protos = [Literal(key).experimental_as_proto()
                  for key in self.mapping.keys()]
    value_protos = [serialization.serialize(value)
                    for value in self.mapping.values()]
    return default_types_pb2.SerializedDict(keys=key_protos,
                                            values=value_protos)

  def _placeholder_value(self) -> Any:
    placeholders = {}
    for key, value in self.mapping.items():
      placeholders[key] = value._placeholder_value()  # pylint: disable=protected-access
    return placeholders

  def __eq__(self, other) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented
    if not isinstance(other, Dict):
      return False
    return self.mapping == other.mapping

  def __hash__(self) -> int:
    # Hash only the key set; equal Dicts always share the same keys.
    return hash(frozenset(self.mapping))

  def __repr__(self):
    return f"{self.__class__.__name__}(mapping={self.mapping!r})"
541
542
class Reference(trace.TraceType, serialization.Serializable):
  """A resource TraceType tagged with an identifier.

  Two resources with the same identifier are known at compilation time to
  refer to the same underlying thing. Automatic control dependencies, for
  example, use this so that ops on the same resource are not run
  concurrently.
  """

  def __init__(self, base: trace.TraceType, identifier: Hashable):
    self.base = base
    self.identifier = identifier

  def is_subtype_of(self, other: trace.TraceType) -> bool:
    """Subtype iff identifiers match and `base` is a subtype of other's."""
    same_resource = (isinstance(other, Reference) and
                     self.identifier == other.identifier)
    return self.base.is_subtype_of(other.base) if same_resource else False

  def most_specific_common_supertype(
      self, types: Sequence[trace.TraceType]) -> Optional["Reference"]:
    """Merges bases when every entry shares this Reference's identifier."""
    for other in types:
      if not (isinstance(other, Reference) and
              self.identifier == other.identifier):
        return None

    merged_base = self.base.most_specific_common_supertype(
        [other.base for other in types])
    if merged_base is None:
      return None
    return Reference(merged_base, self.identifier)

  @classmethod
  def experimental_type_proto(
      cls) -> Type[default_types_pb2.SerializedReference]:
    return default_types_pb2.SerializedReference

  @classmethod
  def experimental_from_proto(
      cls, proto: default_types_pb2.SerializedReference) -> "Reference":
    """Rebuilds the base TraceType and the Literal-encoded identifier."""
    base = serialization.deserialize(proto.base)
    identifier = Literal.experimental_from_proto(proto.identifier).value
    return Reference(base, identifier)

  def experimental_as_proto(self) -> default_types_pb2.SerializedReference:
    """Serializes the identifier (as a Literal) and then the base."""
    identifier_proto = Literal(self.identifier).experimental_as_proto()
    base_proto = serialization.serialize(self.base)
    return default_types_pb2.SerializedReference(
        identifier=identifier_proto, base=base_proto)

  def _placeholder_value(self) -> Any:
    # The identifier does not affect the placeholder; defer to the base.
    return self.base._placeholder_value()  # pylint: disable=protected-access

  def __eq__(self, other: Any) -> bool:
    if not isinstance(other, trace.TraceType):
      return NotImplemented
    if not isinstance(other, Reference):
      return False
    return self.identifier == other.identifier and self.base == other.base

  def __hash__(self) -> int:
    return hash((self.identifier, self.base))

  def __repr__(self):
    return (f"{self.__class__.__name__}(base={self.base!r}, "
            f"identifier={self.identifier!r})")
606
# Register each serializable default type with the serialization module
# (presumably so `serialization.deserialize`, used above, can resolve their
# protos back to these classes — confirm in serialization.py). Order is
# preserved from the original one-call-per-line form.
for _serializable_type in (Literal, Tuple, List, NamedTuple, Attrs, Dict,
                           Reference):
  serialization.register_serializable(_serializable_type)
del _serializable_type
614