# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Cache Key generation based on Function Trace Type."""

import collections.abc
from typing import Any, Callable, Hashable, Optional
import weakref

from tensorflow.core.function.trace_type import default_types
from tensorflow.core.function.trace_type import util
from tensorflow.python.types import trace


class WeakrefDeletionObserver:
  """An observer for the event of deleting a weakref.

  This allows users of FunctionTraceType to be notified when an instance which
  depends on a weakref becomes invalid by the deletion of the weakref. In
  particular, tf.function caches can use this mechanism to clear the cache of
  keys that are no longer valid.

  We use the observer pattern and not just basic callbacks because the keys
  are typically created before they are used by the cache.

  One observer is shared by every weakref created under a tracing context, so
  it may be notified more than once. Listeners are still invoked at most once:
  the first deletion already invalidates the key.
  """

  def __init__(self):
    self._triggered = False
    self._callables = []

  def add_listener(self, on_delete: Callable[[], None]):
    """Registers `on_delete` to be called when a tracked weakref dies.

    If a deletion has already been observed, `on_delete` is called immediately
    instead of being stored.

    Args:
      on_delete: A no-argument callable to invoke on invalidation.
    """
    if self._triggered:
      on_delete()
    else:
      self._callables.append(on_delete)

  def weakref_deleted(self):
    """Marks the observer as triggered and notifies pending listeners once."""
    if self._triggered:
      # Already fired for an earlier deletion; re-firing would run every
      # listener (e.g. cache eviction) again for the same invalid key.
      return
    self._triggered = True
    # Detach the list first so the listeners are not retained after firing.
    callables, self._callables = self._callables, []
    for c in callables:
      c()

  def __call__(self, _):
    """Call handler for convenience of use with weakref."""
    self.weakref_deleted()


class InternalTracingContext(trace.TracingContext):
  """Container for variables and flags shared across TraceType generation."""

  def __init__(self):
    self._deletion_observer = WeakrefDeletionObserver()
    # Maps each global_id seen by this context to a dense local id.
    self._global_to_local_id = {}

  # TODO(b/202772221): Consider dropping after alias pattern matching is
  # supported.
  def make_reference_type(self, base_type: trace.TraceType,
                          global_id: Hashable) -> trace.TraceType:
    """Wraps `base_type` in a Reference keyed by a context-local id.

    Local ids are assigned 0, 1, 2, ... in order of first appearance, so the
    same `global_id` always yields the same Reference id within this context.

    Args:
      base_type: The TraceType of the referenced object.
      global_id: A hashable identifier of the referenced object.

    Returns:
      A `default_types.Reference` wrapping `base_type`.
    """
    # len() is evaluated before insertion, so a new key gets the next free id.
    local_id = self._global_to_local_id.setdefault(
        global_id, len(self._global_to_local_id))
    return default_types.Reference(base_type, local_id)

  @property
  def deletion_observer(self):
    """Returns a functor which invalidates the current key when called."""
    return self._deletion_observer


def from_object(
    obj: Any,
    context: Optional[trace.TracingContext] = None) -> trace.TraceType:
  """Returns a TraceType corresponding to the object based on the context.

  Candidates are tried in order: the Tracing Protocol, `__wrapped__` wrappers,
  lists, tuples (namedtuples keep their type), mappings, attrs classes, then a
  weak reference to the object, and finally a Literal of the object itself.

  Args:
    obj: The object to generate a TraceType for.
    context: The TracingContext to be shared during protocol calls. A fresh
      InternalTracingContext is created when omitted.

  Returns:
    A TraceType object representing the given object.

  Raises:
    TypeError: If the object can neither be weakly referenced nor represented
      as a Literal.
  """

  if context is None:
    context = InternalTracingContext()

  if isinstance(obj, trace.SupportsTracingProtocol):
    return obj.__tf_tracing_type__(context)

  if hasattr(obj, "__wrapped__"):
    return from_object(obj.__wrapped__, context)

  if isinstance(obj, list):
    return default_types.List(*(from_object(c, context) for c in obj))

  if isinstance(obj, tuple):
    if util.is_namedtuple(obj):
      named_tuple_type = type(obj)
      return default_types.NamedTuple.from_type_and_attributes(
          named_tuple_type, tuple(from_object(c, context) for c in obj))
    else:
      return default_types.Tuple(*(from_object(c, context) for c in obj))

  if isinstance(obj, collections.abc.Mapping):
    return default_types.Dict({k: from_object(obj[k], context) for k in obj})

  if util.is_attrs(obj):
    return default_types.Attrs.from_type_and_attributes(
        type(obj),
        tuple(
            from_object(getattr(obj, a.name), context)
            for a in obj.__attrs_attrs__))

  # Only weakref creation is guarded: objects such as int, str or None cannot
  # be weakly referenced and fall back to a Literal.
  try:
    ref = weakref.ref(obj, context.deletion_observer)
  except TypeError:
    try:
      return default_types.Literal(obj)
    except Exception as e:
      raise TypeError(
          f"Python object could not be represented through the generic tracing "
          f"type. Consider implementing the Tracing Protocol for it: {obj!r}"
      ) from e

  # `obj` is still alive here, so the freshly created `ref` cannot be dead.
  return default_types.Weakref(ref)