# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.function tracing types.

See `core.GenericFunction` and `core.ConcreteFunction`.

`GenericFunction` assigns types to call arguments, forming a signature.
Function signatures are used to match arguments to `ConcreteFunction`s.
For example, when a new `ConcreteFunction` is traced, it is assigned
the signature of the arguments it was traced with. Subsequent call arguments
which match its signature will be dispatched to the same `ConcreteFunction`.
If no `ConcreteFunction` with a matching signature is found, a new one may be
traced (a process known as retracing).
"""

import abc
from typing import Optional, Sequence
from typing_extensions import Protocol
from typing_extensions import runtime_checkable
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


@tf_export("types.experimental.TraceType", v1=[])
class TraceType(metaclass=abc.ABCMeta):
  """Represents the type of object(s) for tf.function tracing purposes.

  `TraceType` is an abstract class that other classes might inherit from to
  provide information regarding associated class(es) for the purposes of
  tf.function tracing. The typing logic provided through this mechanism will be
  used to make decisions regarding usage of cached concrete functions and
  retracing.

  For example, if we have the following tf.function and classes:
  ```python
  @tf.function
  def get_mixed_flavor(fruit_a, fruit_b):
    return fruit_a.flavor + fruit_b.flavor

  class Fruit:
    flavor = tf.constant([0, 0])

  class Apple(Fruit):
    flavor = tf.constant([1, 2])

  class Mango(Fruit):
    flavor = tf.constant([3, 4])
  ```

  tf.function does not know when to re-use an existing concrete function with
  regard to the `Fruit` class, so naively it retraces for every new instance.
  ```python
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
  ```

  However, we, as the designers of the `Fruit` class, know that each subclass
  has a fixed flavor and we can reuse an existing traced concrete function if
  it was the same subclass. Avoiding such unnecessary tracing of concrete
  functions can have significant performance benefits.

  ```python
  class FruitTraceType(tf.types.experimental.TraceType):
    def __init__(self, fruit_type):
      self.fruit_type = fruit_type

    def is_subtype_of(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

    def most_specific_common_supertype(self, others):
      return self if all(self == other for other in others) else None

    # `__eq__` and `__hash__` are abstract in `TraceType`, so they must be
    # implemented for `FruitTraceType` to be instantiable.
    def __eq__(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

    def __hash__(self):
      return hash(self.fruit_type)

  class Fruit:

    def __tf_tracing_type__(self, context):
      return FruitTraceType(type(self))
  ```

  Now if we try calling it again:
  ```python
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
  get_mixed_flavor(Apple(), Mango()) # Re-uses the traced concrete function
  ```
  """

  @abc.abstractmethod
  def is_subtype_of(self, other: "TraceType") -> bool:
    """Returns True if `self` is a subtype of `other`.

    For example, `tf.function` uses subtyping for dispatch:
    if `a.is_subtype_of(b)` is True, then an argument of `TraceType`
    `a` can be used as argument to a `ConcreteFunction` traced with
    a `TraceType` `b`.

    Args:
      other: A TraceType object to be compared against.

    Returns:
      True if `self` is a subtype of `other`, False otherwise.

    Example:

    ```python
    class Dimension(TraceType):
      def __init__(self, value: Optional[int]):
        self.value = value

      def is_subtype_of(self, other):
        # Either the value is the same or other has a generalized value that
        # can represent any specific ones.
        return (self.value == other.value) or (other.value is None)
    ```
    """

  @abc.abstractmethod
  def most_specific_common_supertype(
      self, others: Sequence["TraceType"]) -> Optional["TraceType"]:
    """Returns the most specific supertype of `self` and `others`, if it exists.

    The returned `TraceType` is a supertype of `self` and `others`, that is,
    they are all subtypes (see `is_subtype_of`) of it.
    It is also most specific, that is, it has no subtype that is also
    a common supertype of `self` and `others`.

    If `self` and `others` have no common supertype, this returns `None`.

    Args:
      others: A sequence of TraceTypes.

    Returns:
      The most specific common supertype, as a `TraceType`, or `None` if no
      common supertype exists.

    Example:
    ```python
    class Dimension(TraceType):
      def __init__(self, value: Optional[int]):
        self.value = value

      def most_specific_common_supertype(self, others):
        # If every value is the same, `self` already is the most specific
        # common supertype; otherwise only the generalized value (`None`),
        # which can represent any specific one, covers all of them.
        if all(self.value == other.value for other in others):
          return self
        else:
          return Dimension(None)
    ```
    """

  # TODO(b/221309709): Polish into a stable placeholder_value.
  @doc_controls.do_not_doc_inheritable
  def _placeholder_value(self):
    """Creates a placeholder for tracing.

    Often it is more useful to trace with a placeholder value than an actual
    one. For example, a placeholder value can represent multiple different
    actual values. This means that the trace generated with that placeholder
    value is more general and reusable, which saves expensive retracing.

    For the `Fruit` example shared above, implementing:

    ```python
    class FruitTraceType(tf.types.experimental.TraceType):
      # ... other methods as above ...

      def _placeholder_value(self):
        return Fruit()
    ```
    instructs tf.function to trace with the `Fruit()` objects
    instead of the actual `Apple()` and `Mango()` objects when it receives a
    call to `get_mixed_flavor(Apple(), Mango())`. Similarly, Tensor arguments
    are replaced with Tensors of similar shape and dtype, produced by
    a `Placeholder` op.

    More generally, placeholder values are the arguments of a tf.function,
    as seen from the function's body:
    ```python
    @tf.function
    def foo(x):
      # Here `x` can be the placeholder value
      ...

    foo(x) # Here `x` is the actual value
    ```

    Raises:
      NotImplementedError: If a subclass does not override this method.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def __hash__(self) -> int:
    """Returns a hash of this `TraceType`; equal TraceTypes must hash equally."""

  @abc.abstractmethod
  def __eq__(self, other) -> bool:
    """Returns True if `self` and `other` represent the same type."""


@tf_export("types.experimental.TracingContext", v1=[])
class TracingContext(metaclass=abc.ABCMeta):
  """Contains information scoped to the tracing of multiple objects.

  `TracingContext` is a container class for flags and variables that have
  any kind of influence on the tracing behaviour of the class implementing
  the __tf_tracing_type__. This context will be shared across all
  __tf_tracing_type__ calls while constructing the TraceType for a particular
  set of objects.
  """


@runtime_checkable
class SupportsTracingProtocol(Protocol):
  """A protocol allowing custom classes to control tf.function retracing."""

  @doc_controls.doc_private
  @abc.abstractmethod
  def __tf_tracing_type__(self, context: TracingContext) -> TraceType:
    """Returns the tracing type of this object.

    The tracing type is used to build the signature of a tf.function
    when traced, and to match arguments with existing signatures.
    When a Function object is called, tf.function looks at the tracing type
    of the call arguments. If an existing signature of matching type exists,
    it will be used. Otherwise, a new function is traced, and its signature
    will use the tracing type of the call arguments.

    Args:
      context: a context object created for each function call for tracking
        information about the call arguments as a whole.

    Returns:
      The tracing type of this object.
    """


# TODO(b/219556836): Direct tf_export decorator adds non-method members to the
# Protocol which breaks @runtime_checkable since it does not support them.
tf_export(
    "types.experimental.SupportsTracingProtocol",
    v1=[]).export_constant(__name__, "SupportsTracingProtocol")