# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient tape utilities."""

import contextlib

from tensorflow.python import pywrap_tfe
from tensorflow.python.util.lazy_loader import LazyLoader

# There is a circular dependency between this, ops.py, and
# distribution_strategy_context.
# TODO(b/117329403): Remove this circular dependency.
distribution_strategy_context = LazyLoader(
    "distribution_strategy_context", globals(),
    "tensorflow.python.distribute."
    "distribution_strategy_context")


class Tape(object):
  """Represents a gradient propagation trace."""

  __slots__ = ["_tape"]

  def __init__(self, tape):
    self._tape = tape

  def watched_variables(self):
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)


def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack."""
  tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
  return Tape(tape)


def push_tape(tape):
  """Pushes an existing tape onto the tape stack."""
  pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape)  # pylint: disable=protected-access


def watch(tape, tensor):
  """Marks this tensor to be watched by the given tape."""
  pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor)  # pylint: disable=protected-access


class VariableWatcher(object):
  """A scope that tracks all trainable variable accesses within it.

  This explicitly ignores variables that are not marked as trainable.

  Sample usage:

  var = tf.Variable(0.0)
  with VariableWatcher() as variable_watcher:
    var.assign_add(1.0)

  assert variable_watcher.watched_variables() == (var,)
  """

  __slots__ = ["_variable_watcher"]

  def __init__(self):
    self._variable_watcher = None

  def __enter__(self):
    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
    return self

  def __exit__(self, typ, value, traceback):
    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)

  def watched_variables(self):
    """Returns a tuple of variables accessed under this scope."""
    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
        self._variable_watcher)


def watch_variable(tape, variable):
  """Marks this variable to be watched by the given tape."""
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if context:
    variables = [strategy.extended.value_container(variable)]
  else:
    variables = strategy.experimental_local_results(variable)
  for var in variables:
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
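
# A minimal, hedged usage sketch of the tape-stack helpers above (comments
# only; nothing here runs on import). `constant_op.constant` is TensorFlow's
# standard eager constant constructor; the gradient computation itself is
# driven by backprop machinery outside this module.
#
#   from tensorflow.python.framework import constant_op
#
#   t = push_new_tape(persistent=False)  # activate a fresh tape
#   x = constant_op.constant(3.0)
#   watch(t, x)                          # ops consuming `x` are now recorded
#   ...                                  # run eager ops here
#   pop_tape(t)                          # deactivate `t` when done
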
def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  Args:
    variable: variable to be marked as accessed.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if context:
    variables = [strategy.extended.value_container(variable)]
  else:
    variables = strategy.experimental_local_results(variable)
  for var in variables:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)


def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only trainable variables are marked as accessed.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  accessed = []
  if context:
    accessed = [strategy.extended.value_container(variable)
                for variable in variables if variable.trainable]
  else:
    for variable in variables:
      if variable.trainable:
        accessed.extend(strategy.experimental_local_results(variable))

  for var in accessed:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)


def pop_tape(tape):
  """Pops the given tape from the stack."""
  pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape)  # pylint: disable=protected-access


@contextlib.contextmanager
def stop_recording():
  """Stops all gradient recording (backprop and forwardprop)."""
  is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
  try:
    if not is_stopped:
      pywrap_tfe.TFE_Py_TapeSetStopOnThread()
    yield
  finally:
    if not is_stopped:
      pywrap_tfe.TFE_Py_TapeSetRestartOnThread()


def should_record_backprop(tensors):
  """Returns true if any tape in the stack watches any of these tensors.

  Only takes GradientTapes into account, not forward accumulators.

  Args:
    tensors: Tensors to check, typically inputs to an operation.

  Returns:
    Boolean, whether any tape watches any of `tensors`.
  """
  return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)


def record_operation(op_type, output_tensors, input_tensors, backward_function,
                     forward_function=None):
  """Records the operation on all tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors,
                                           input_tensors, backward_function,
                                           forward_function)


def record_operation_backprop_only(op_type, output_tensors, input_tensors,
                                   backward_function):
  """Records the operation on all backward tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors,
                                                   input_tensors,
                                                   backward_function)
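
# A hedged, illustrative sketch of `record_operation` (comments only; the
# names `x`, `y`, and `_double_backward` are hypothetical). The backward
# function maps gradients of the outputs to gradients of the inputs.
#
#   # Suppose y = 2 * x was just computed eagerly.
#   def _double_backward(dy):
#     return [dy * 2.0]  # d(2x)/dx == 2, scaled by the incoming gradient
#
#   record_operation("Double", [y], [x], _double_backward)
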
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
                                      backward_function,
                                      forwardprop_output_indices):
  """Records the operation on all forward accumulators in the stack.

  Args:
    op_type: a string for the operation type, used in the backprop code.
    output_tensors: a list of Python Tensor objects output by the operation.
    input_tensors: a list of input Tensors to the recorded operation.
    backward_function: the function that, given the gradients of the output
      tensors, produces the gradients of the input tensors. This function is
      automatically transposed to produce output gradients given input
      gradients.
    forwardprop_output_indices: indicates any output_tensors which contain
      JVPs. Typically these will have come from TFE_Py_PackForwardGradients.
      May be None or an empty sequence if there are no JVP outputs from the
      operation.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
      op_type, output_tensors, input_tensors, backward_function,
      forwardprop_output_indices)


def delete_trace(tensor_id):
  """Deletes traces for this Tensor from all tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)


def could_possibly_record():
  """Returns True if any tape is active."""
  return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
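
# Two hedged usage sketches (comments only; `compute_statistics` and
# `build_gradient_metadata` are hypothetical helpers):
#
#   # Hide a computation from all active tapes and forward accumulators:
#   with stop_recording():
#     stats = compute_statistics(x)  # nothing in this block is recorded
#
#   # Skip gradient-only bookkeeping when no tape could be recording:
#   if could_possibly_record():
#     metadata = build_gradient_metadata(x)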