# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context information for a tf.function."""

from typing import Any, NamedTuple, Tuple

from tensorflow.core.function import trace_type
from tensorflow.core.function.polymorphism import function_cache
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.saved_model import save_context


# EagerContext is used by tf.function to identify cases where tracing
# needs to occur due to a change in conditions other than the arguments.
class EagerContext(NamedTuple):
  parent_graph: Any  # Graph the function is traced in; None in eager mode.
  device_functions: Any  # Device annotations that affect op placement.
  colocation_stack: Any  # Colocation constraints active at the call site.
  in_cross_replica_context: Any  # True when in a cross-replica context.
  variable_policy: Any  # Variable policy active during SavedModel export.
  xla_context_id: Any  # id() of the enclosing XLA context, or 0 if none.
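
# Example (illustrative): the function cache relies on NamedTuple equality, so
# contexts differing in any field compare unequal, forcing a retrace:
#
#   a = EagerContext(None, (), (), False, None, 0)
#   b = a._replace(xla_context_id=1)
#   assert a != b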


def make_function_context() -> function_cache.FunctionContext:
  """Generates a FunctionContext based on current contextual info."""
  ctx = context.context()

  # No need to open an init_scope if the tf.function call is already in eager
  # mode.
  executing_eagerly = ctx.executing_eagerly()
  parent_graph = None
  xla_context_id = 0
  if not executing_eagerly:
    # We want to force function retracing for each different
    # XLAControlFlowContext, so add `xla_context_id` to the context.
    xla_context = _enclosing_xla_context()
    if (xla_context is not None and
        xla_context.RequiresUniqueFunctionRetracing()):
      xla_context_id = id(xla_context)

    with ops.init_scope():
      # The graph, or whether we're executing eagerly, should be a part of the
      # cache key so we don't improperly capture tensors such as variables.
      executing_eagerly = ctx.executing_eagerly()
      parent_graph = None if executing_eagerly else ops.get_default_graph()
  # pylint: disable=protected-access
  default_graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategy_stack = default_graph._distribution_strategy_stack
  uses_distribution_strategy = (
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device)
  if executing_eagerly:
    colocation_stack = ()
    if uses_distribution_strategy:
      device_functions = (pydev.merge_device(ctx.device_name),)
    else:
      device_functions = ()
  else:
    colocation_stack = tuple(default_graph._colocation_stack.peek_objs())
    if (uses_distribution_strategy or
        func_graph_module.device_stack_has_callable(
            default_graph._device_function_stack)):
      # Putting the device in the cache key ensures that call-site device
      # annotations are respected.
      device_functions = tuple(default_graph._device_functions_outer_to_inner)
    else:
      device_functions = ()

  in_cross_replica_context = False
  try:
    in_cross_replica_context = (strategy_stack[-1].replica_context is None)  # pylint: disable=protected-access
  except (AttributeError, IndexError):
    pass

  if save_context.in_save_context():
    variable_policy = (
        save_context.get_save_options().experimental_variable_policy)
  else:
    variable_policy = None

  return function_cache.FunctionContext(
      EagerContext(parent_graph, device_functions, colocation_stack,
                   in_cross_replica_context, variable_policy, xla_context_id))
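

# Usage sketch (illustrative only; this function is internal to tf.function's
# tracing machinery):
#
#   eager_ctx = make_function_context()   # parent_graph is None in eager mode
#   with ops.Graph().as_default():
#     graph_ctx = make_function_context() # parent_graph is the default graph
#   # The two contexts wrap different EagerContext values, so a tf.function
#   # called from both places would be traced once for each.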


def _enclosing_xla_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite()."""
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    context_ = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while context_ is not None:
      if isinstance(context_, control_flow_ops.XLAControlFlowContext):
        return context_
      context_ = context_.outer_context
    # This may be a FuncGraph due to defuns or v2 control flow. We need to
    # find the original graph with the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None
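

# Illustrative note: tpu.rewrite() installs an XLAControlFlowContext on the
# graph it builds, but the code being traced may run inside a nested FuncGraph
# (e.g. one created by v2 control flow); the `outer_graph` walk above is what
# recovers the XLA context in that case.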


def make_cache_key(
    args: Any,
    captures: Any = None,
) -> Tuple[function_cache.FunctionCacheKey, trace_type.WeakrefDeletionObserver]:
  """Computes the cache key given the function arguments."""
  if captures is None:
    captures = dict()
  signature_context = trace_type.InternalTracingContext()
  args_signature = trace_type.from_object(args, signature_context)
  captures_dict_tracetype = trace_type.from_object(captures, signature_context)
  captures_signature = function_cache.CaptureSnapshot(
      captures_dict_tracetype.mapping)

  return function_cache.FunctionCacheKey(
      args_signature,
      captures_signature,
      make_function_context()), signature_context.deletion_observer
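

# Usage sketch (illustrative only; `constant_op` below means
# tensorflow.python.framework.constant_op, which is not imported by this
# module):
#
#   key1, observer = make_cache_key((constant_op.constant(1.0),))
#   key2, _ = make_cache_key((constant_op.constant(2.0),))
#   # Both tensors trace to the same type (float32 scalar) and the ambient
#   # context is unchanged, so the two keys should compare equal and the
#   # function cache can serve both calls with one ConcreteFunction.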