# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Context information for a tf.function."""

from typing import Any, NamedTuple, Tuple

from tensorflow.core.function import trace_type
from tensorflow.core.function.polymorphism import function_cache
from tensorflow.python.eager import context
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.saved_model import save_context


# tf.function compares EagerContext values to detect that tracing must happen
# again because some condition other than the call arguments has changed.
class EagerContext(NamedTuple):
  # Outermost (init_scope) graph, or None when that scope is eager.
  parent_graph: Any
  # Device functions/scopes that participate in the cache key (may be empty).
  device_functions: Any
  # Colocation objects active in the calling graph (empty when eager).
  colocation_stack: Any
  # True when the innermost distribution strategy has no replica context.
  in_cross_replica_context: Any
  # SavedModel experimental variable policy while saving, otherwise None.
  variable_policy: Any
  # id() of an enclosing XLA context that demands unique retracing, else 0.
  xla_context_id: Any


def make_function_context() -> function_cache.FunctionContext:
  """Builds a FunctionContext snapshot of the current tracing environment.

  The snapshot records the non-argument conditions that should force a
  retrace when they differ between calls:
  - the outer graph;
  - the device and colocation scopes;
  - the distribution-strategy replica context;
  - the SavedModel variable policy;
  - any enclosing XLA context that requires unique retracing.

  Returns:
    A `function_cache.FunctionContext` wrapping an `EagerContext`.
  """
  ctx = context.context()

  xla_context_id = 0
  parent_graph = None
  # A call made in eager mode needs no init_scope; only a graph-mode call site
  # has to look outward for an XLA context and for the outermost graph.
  executing_eagerly = ctx.executing_eagerly()
  if not executing_eagerly:
    # Every XLAControlFlowContext that asks for it gets its own trace, so its
    # identity becomes part of the key.
    enclosing_xla = _enclosing_xla_context()
    if (enclosing_xla is not None and
        enclosing_xla.RequiresUniqueFunctionRetracing()):
      xla_context_id = id(enclosing_xla)

    with ops.init_scope():
      # Keying on the outer graph (or on eager mode) stops a trace that
      # captured tensors such as variables from being reused elsewhere.
      executing_eagerly = ctx.executing_eagerly()
      if not executing_eagerly:
        parent_graph = ops.get_default_graph()

  # pylint: disable=protected-access
  graph = ops.get_default_graph()
  # TODO(b/117617952): The current distribution strategy will affect graph
  # building (e.g. accessing different variables from different devices) and
  # so requires retracing for each device.
  strategy_stack = graph._distribution_strategy_stack
  per_device_retrace = (
      strategy_stack and
      strategy_stack[-1].strategy.extended._retrace_functions_for_each_device)

  if executing_eagerly:
    colocation_stack = ()
    device_functions = ((pydev.merge_device(ctx.device_name),)
                        if per_device_retrace else ())
  else:
    colocation_stack = tuple(graph._colocation_stack.peek_objs())
    # Putting the device in the cache key ensures that call-site device
    # annotations are respected.
    key_on_devices = (
        per_device_retrace or
        func_graph_module.device_stack_has_callable(
            graph._device_function_stack))
    device_functions = (tuple(graph._device_functions_outer_to_inner)
                        if key_on_devices else ())
  # pylint: enable=protected-access

  # An empty strategy stack (IndexError) or a missing attribute both mean
  # "not in cross-replica context".
  try:
    in_cross_replica_context = strategy_stack[-1].replica_context is None
  except (AttributeError, IndexError):
    in_cross_replica_context = False

  variable_policy = None
  if save_context.in_save_context():
    variable_policy = (
        save_context.get_save_options().experimental_variable_policy)

  eager_context = EagerContext(
      parent_graph=parent_graph,
      device_functions=device_functions,
      colocation_stack=colocation_stack,
      in_cross_replica_context=in_cross_replica_context,
      variable_policy=variable_policy,
      xla_context_id=xla_context_id)
  return function_cache.FunctionContext(eager_context)


def _enclosing_xla_context():
  """Returns the XLAControlFlowContext, which exists inside a tpu.rewrite().

  The search starts at the default graph's control-flow context and follows
  `outer_context` links. When a graph's chain is exhausted, it moves on to that
  graph's `outer_graph`, if it has one.

  Returns:
    The first `control_flow_ops.XLAControlFlowContext` encountered while
    walking outward, or None if there is none.
  """
  graph = ops.get_default_graph()
  while graph is not None:
    # pylint: disable=protected-access
    flow_ctx = graph._get_control_flow_context()
    # pylint: enable=protected-access
    while (flow_ctx is not None and
           not isinstance(flow_ctx, control_flow_ops.XLAControlFlowContext)):
      flow_ctx = flow_ctx.outer_context
    if flow_ctx is not None:
      return flow_ctx
    # This may be a FuncGraph due to defuns or v2 control flow; keep climbing
    # to reach the original graph that holds the XLAControlFlowContext.
    graph = getattr(graph, "outer_graph", None)
  return None


def make_cache_key(
    args: Any,
    captures: Any = None,
) -> Tuple[function_cache.FunctionCacheKey, trace_type.WeakrefDeletionObserver]:
  """Computes the function-cache key for a call.

  Args:
    args: The call arguments, encoded into a TraceType signature.
    captures: Captured values to encode alongside `args`; None is treated as
      an empty dict. Presumably a mapping, since the encoded result's
      `.mapping` is read. TODO confirm against callers.

  Returns:
    A pair `(cache_key, deletion_observer)`. `cache_key` combines the argument
    signature, a capture snapshot and the current function context.
    `deletion_observer` is the one owned by the tracing context used for the
    encoding.
  """
  tracing_context = trace_type.InternalTracingContext()
  args_signature = trace_type.from_object(args, tracing_context)
  captures_type = trace_type.from_object(
      {} if captures is None else captures, tracing_context)
  cache_key = function_cache.FunctionCacheKey(
      args_signature,
      function_cache.CaptureSnapshot(captures_type.mapping),
      make_function_context())
  return cache_key, tracing_context.deletion_observer