xref: /aosp_15_r20/external/tensorflow/tensorflow/python/types/trace.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""tf.function tracing types.
16
17See `core.GenericFunction` and `core.ConcreteFunction`.
18
19`GenericFunction` assigns types to call arguments, forming a signature.
20Function signatures are used to match arguments to `ConcreteFunction`s.
21For example, when a new `ConcreteFunction` is traced, it is assigned a
22the signature of the arguments it was traced with. Subsequent call arguments
23which match its signature will be dispatched to the same `ConcreteFunction`.
24If no `ConcreteFunction` with a matching signature is found, a new one may be
25traced (a process known as retracing).
26"""
27
28import abc
29from typing import Optional, Sequence
30from typing_extensions import Protocol
31from typing_extensions import runtime_checkable
32from tensorflow.python.util.tf_export import tf_export
33from tensorflow.tools.docs import doc_controls
34
35
@tf_export("types.experimental.TraceType", v1=[])
class TraceType(metaclass=abc.ABCMeta):
  """Represents the type of object(s) for tf.function tracing purposes.

  `TraceType` is an abstract class that other classes might inherit from to
  provide information regarding associated class(es) for the purposes of
  tf.function tracing. The typing logic provided through this mechanism will be
  used to make decisions regarding usage of cached concrete functions and
  retracing.

  For example, if we have the following tf.function and classes:
  ```python
  @tf.function
  def get_mixed_flavor(fruit_a, fruit_b):
    return fruit_a.flavor + fruit_b.flavor

  class Fruit:
    flavor = tf.constant([0, 0])

  class Apple(Fruit):
    flavor = tf.constant([1, 2])

  class Mango(Fruit):
    flavor = tf.constant([3, 4])
  ```

  tf.function does not know when to re-use an existing concrete function in
  regards to the `Fruit` class so naively it retraces for every new instance.
  ```python
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
  ```

  However, we, as the designers of the `Fruit` class, know that each subclass
  has a fixed flavor and we can reuse an existing traced concrete function if
  it was the same subclass. Avoiding such unnecessary tracing of concrete
  functions can have significant performance benefits.

  ```python
  class FruitTraceType(tf.types.experimental.TraceType):
    def __init__(self, fruit_type):
      self.fruit_type = fruit_type

    def is_subtype_of(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

    def most_specific_common_supertype(self, others):
      return self if all(self == other for other in others) else None

    # `__eq__` and `__hash__` are abstract in `TraceType`, so they must be
    # implemented for `FruitTraceType` to be instantiable.
    def __eq__(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

    def __hash__(self):
      return hash(self.fruit_type)

  class Fruit:

    def __tf_tracing_type__(self, context):
      return FruitTraceType(type(self))
  ```

  Now if we try calling it again:
  ```python
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
  get_mixed_flavor(Apple(), Mango()) # Re-uses the traced concrete function
  ```
  """

  @abc.abstractmethod
  def is_subtype_of(self, other: "TraceType") -> bool:
    """Returns True if `self` is a subtype of `other`.

    For example, `tf.function` uses subtyping for dispatch:
    if `a.is_subtype_of(b)` is True, then an argument of `TraceType`
    `a` can be used as argument to a `ConcreteFunction` traced with a
    `TraceType` `b`.

    Args:
      other: A TraceType object to be compared against.

    Returns:
      True if `self` is a subtype of `other`, False otherwise.

    Example:

    ```python
    class Dimension(TraceType):
      def __init__(self, value: Optional[int]):
        self.value = value

      def is_subtype_of(self, other):
        # Either the value is the same or other has a generalized value that
        # can represent any specific ones.
        return (self.value == other.value) or (other.value is None)
    ```
    """

  @abc.abstractmethod
  def most_specific_common_supertype(
      self, others: Sequence["TraceType"]) -> Optional["TraceType"]:
    """Returns the most specific supertype of `self` and `others`, if it exists.

    The returned `TraceType` is a supertype of `self` and `others`, that is,
    they are all subtypes (see `is_subtype_of`) of it.
    It is also most specific, that is, it has no subtype that is also
    a common supertype of `self` and `others`.

    If `self` and `others` have no common supertype, this returns `None`.

    Args:
      others: A sequence of TraceTypes.

    Returns:
      The most specific common supertype as a `TraceType`, or `None` if no
      common supertype exists.

    Example:
    ```python
    class Dimension(TraceType):
      def __init__(self, value: Optional[int]):
        self.value = value

      def most_specific_common_supertype(self, others):
        # Either all values are the same, or a generalized value that can
        # represent any specific one is needed.
        if all(self.value == other.value for other in others):
          return self
        else:
          return Dimension(None)
    ```
    """

  # TODO(b/221309709): Polish into a stable placeholder_value.
  @doc_controls.do_not_doc_inheritable
  def _placeholder_value(self):
    """Creates a placeholder for tracing.

    Often it is more useful to trace with a placeholder value than an actual
    one. For example, a placeholder value can represent multiple different
    actual values. This means that the trace generated with that placeholder
    value is more general and reusable which saves expensive retracing.

    For the `Fruit` example shared above, implementing:

    ```python
    class FruitTraceType:
      def _placeholder_value(self):
        return Fruit()
    ```
    instructs tf.function to trace with the `Fruit()` objects
    instead of the actual `Apple()` and `Mango()` objects when it receives a
    call to `get_mixed_flavor(Apple(), Mango())`. Similarly, Tensor arguments
    are replaced with Tensors of similar shape and dtype, output from
    a `Placeholder` op.

    More generally, placeholder values are the arguments of a tf.function,
    as seen from the function's body:
    ```python
    @tf.function
    def foo(x):
      # Here `x` can be the placeholder value
      ...

    foo(x) # Here `x` is the actual value
    ```

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def __hash__(self) -> int:
    """Returns a hash of this TraceType.

    Must be consistent with `__eq__`: TraceTypes that compare equal must
    produce the same hash.
    """

  @abc.abstractmethod
  def __eq__(self, other) -> bool:
    """Returns True if `self` and `other` represent the same TraceType."""
199
200
@tf_export("types.experimental.TracingContext", v1=[])
class TracingContext(metaclass=abc.ABCMeta):
  """Holds state that is shared while tracing a group of objects.

  A `TracingContext` is a container for flags and variables that can
  influence how a class implementing `__tf_tracing_type__` is traced. One
  context instance is shared by every `__tf_tracing_type__` call made while
  the TraceType for a particular set of objects is being constructed.
  """
212
213
@runtime_checkable
class SupportsTracingProtocol(Protocol):
  """Protocol through which custom classes can control tf.function retracing."""

  @doc_controls.doc_private
  @abc.abstractmethod
  def __tf_tracing_type__(self, context: TracingContext) -> TraceType:
    """Computes the tracing type of this object.

    tf.function relies on the tracing type in two places: to build the
    signature of a function when it is traced, and to match later call
    arguments against signatures that already exist. Whenever a Function
    object is called, the tracing types of the call arguments are inspected;
    if a signature of matching type is already present it is reused, and
    otherwise a new function is traced whose signature is built from the
    tracing types of these call arguments.

    Args:
      context: a context object, created for each function call, that tracks
        information about the call arguments as a whole.

    Returns:
      The tracing type of this object.
    """
236
# TODO(b/219556836): Using tf_export directly as a class decorator would add
# non-method members to the Protocol, and @runtime_checkable cannot handle
# those; so the class is exported as a module constant instead.
tf_export("types.experimental.SupportsTracingProtocol",
          v1=[]).export_constant(__name__, "SupportsTracingProtocol")
242