xref: /aosp_15_r20/external/tensorflow/tensorflow/core/function/trace_type/trace_type_builder.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Utilities for Cache Key generation based on Function Trace Type."""
16
17import collections.abc
18from typing import Any, Callable, Hashable
19import weakref
20
21from tensorflow.core.function.trace_type import default_types
22from tensorflow.core.function.trace_type import util
23from tensorflow.python.types import trace
24
25
class WeakrefDeletionObserver:
  """Notifies registered listeners once a watched weakref has been deleted.

  Users of FunctionTraceType can subscribe here to learn that an instance
  which depends on a weakref has become invalid because that weakref was
  deleted. For example, tf.function caches can use this to drop keys that
  are no longer valid.

  The observer pattern is used rather than plain callbacks because keys are
  typically created before the cache starts using them.
  """

  def __init__(self):
    # Listeners waiting for the deletion event, in registration order.
    self._callables = []
    # Set to True by `weakref_deleted`.
    self._triggered = False

  def add_listener(self, on_delete: Callable[[], None]):
    """Stores `on_delete`, or runs it at once if already triggered."""
    if not self._triggered:
      self._callables.append(on_delete)
      return
    on_delete()

  def weakref_deleted(self):
    """Marks the observer as triggered and invokes every stored listener."""
    self._triggered = True
    for listener in self._callables:
      listener()

  def __call__(self, _):
    """Lets the observer be passed directly as a weakref callback.

    The single positional argument is ignored.
    """
    self.weakref_deleted()
56
57
class InternalTracingContext(trace.TracingContext):
  """Holds the state shared across TraceType generation calls."""

  def __init__(self):
    # Maps each global id to a local id: the number of distinct global ids
    # that had been registered before it.
    self._global_to_local_id = {}
    self._deletion_observer = WeakrefDeletionObserver()

  # TODO(b/202772221): Consider dropping after alias pattern matching is
  # supported.
  def make_reference_type(self, base_type: trace.TraceType,
                          global_id: Hashable) -> trace.TraceType:
    """Wraps `base_type` in a Reference identified by a context-local id.

    Args:
      base_type: The TraceType to wrap.
      global_id: Hashable identifier; repeated calls with an equal
        `global_id` on this context reuse the same local id.

    Returns:
      A `default_types.Reference` built from `base_type` and the local id.
    """
    id_table = self._global_to_local_id
    # A first-seen global id receives the current table size as its local id.
    local_id = id_table.setdefault(global_id, len(id_table))
    return default_types.Reference(base_type, local_id)

  @property
  def deletion_observer(self):
    """The WeakrefDeletionObserver created together with this context."""
    return self._deletion_observer
79
80
def from_object(obj: Any,
                context: trace.TracingContext = None) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  The first matching case below wins:
    1. Objects supporting the tracing protocol build their own type.
    2. Objects with a `__wrapped__` attribute are typed as the wrapped object.
    3. Lists, tuples (including namedtuples), Mappings and attrs instances
       are typed recursively from their elements.
    4. Anything else is held through a weak reference, or as a Literal when
       that attempt raises TypeError.

  Args:
    obj: The object to generate a TraceType for.
    context: The TracingContext to be shared during protocol calls. A new
      InternalTracingContext is created when None.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If none of the cases above apply and `obj` cannot be
      represented as a Literal either.
  """

  if context is None:
    context = InternalTracingContext()

  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if hasattr(obj, "__wrapped__"):
    return from_object(obj.__wrapped__, context)

  if isinstance(obj, list):
    return default_types.List(*(from_object(c, context) for c in obj))

  if isinstance(obj, tuple):
    if util.is_namedtuple(obj):
      return default_types.NamedTuple.from_type_and_attributes(
          type(obj), tuple(from_object(c, context) for c in obj))
    return default_types.Tuple(*(from_object(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict({k: from_object(obj[k], context) for k in obj})

  if util.is_attrs(obj):
    return default_types.Attrs.from_type_and_attributes(
        type(obj),
        tuple(
            from_object(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  try:
    # `weakref.ref` raises TypeError for objects that do not support weak
    # references; it never returns None, so no None check is needed. The
    # context's deletion observer is the weakref callback. Building the
    # Weakref type stays inside the `try` so that a TypeError from it also
    # falls back to Literal.
    return default_types.Weakref(weakref.ref(obj, context.deletion_observer))
  except TypeError:
    pass  # Fall back to the Literal representation below.

  try:
    return default_types.Literal(obj)
  except Exception as e:
    # Only ordinary errors are converted; KeyboardInterrupt and SystemExit
    # propagate unchanged.
    raise TypeError(
        f"Python object could not be represented through the generic tracing "
        f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
    ) from e
137