xref: /aosp_15_r20/external/tensorflow/tensorflow/python/eager/tape.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Gradient tape utilities."""
16
17import contextlib
18
19from tensorflow.python import pywrap_tfe
20from tensorflow.python.util.lazy_loader import LazyLoader
21
# ops.py and distribution_strategy_context both (transitively) depend on this
# module, so the import must be deferred to first use.
# TODO(b/117329403): Remove this circular dependency.
distribution_strategy_context = LazyLoader(
    "distribution_strategy_context", globals(),
    "tensorflow.python.distribute.distribution_strategy_context")
29
30
class Tape(object):
  """Thin Python wrapper around a C++ gradient-propagation trace handle."""

  # Only the raw C++ tape handle is stored; no per-instance __dict__.
  __slots__ = ["_tape"]

  def __init__(self, tape):
    # `tape` is the handle returned by pywrap_tfe.TFE_Py_TapeSetNew.
    self._tape = tape

  def watched_variables(self):
    """Returns the variables this tape is watching."""
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
41
42
def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Creates a new tape, pushes it onto the tape stack, and returns it wrapped."""
  return Tape(
      pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables))
47
48
def push_tape(tape):
  """Pushes an already-created tape onto the tape stack."""
  raw_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetAdd(raw_tape)
52
53
def watch(tape, tensor):
  """Marks this tensor to be watched by the given tape."""
  raw_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeWatch(raw_tape, tensor)
57
58
class VariableWatcher(object):
  """Context manager recording all trainable-variable accesses in its scope.

  Variables that are not marked trainable are explicitly ignored.

  Sample usage:

  var = tf.Variable(0.0)
  with VariableWatcher() as variable_watcher:
    var.assign_add(1.0)

  assert variable_watcher.watched_variables == [var]
  """

  # Only the raw C++ watcher handle is stored.
  __slots__ = ["_variable_watcher"]

  def __init__(self):
    # The underlying watcher is created lazily when the scope is entered.
    self._variable_watcher = None

  def __enter__(self):
    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
    return self

  def __exit__(self, typ, value, traceback):
    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)

  def watched_variables(self):
    """Returns a tuple of variables accessed under this scope."""
    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
        self._variable_watcher)
89
90
def watch_variable(tape, variable):
  """Marks this variable to be watched by the given tape."""
  strategy, replica_context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if replica_context:
    # In a replica context, watch the container variable that aggregates the
    # per-replica values.
    targets = [strategy.extended.value_container(variable)]
  else:
    # In cross-replica context, watch each local (per-device) component.
    targets = strategy.experimental_local_results(variable)
  for var in targets:
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
102
103
def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  Args:
    variable: variable to be watched.
  """
  strategy, replica_context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if replica_context:
    # In a replica context, record the container variable that aggregates the
    # per-replica values.
    targets = [strategy.extended.value_container(variable)]
  else:
    # In cross-replica context, record each local (per-device) component.
    targets = strategy.experimental_local_results(variable)
  for var in targets:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
119
120
def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only trainable variables are marked as accessed.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, replica_context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if replica_context:
    # In a replica context, record each trainable variable's aggregating
    # container rather than the per-replica value.
    accessed = [strategy.extended.value_container(variable)
                for variable in variables if variable.trainable]
  else:
    # In cross-replica context, record every local (per-device) component of
    # each trainable variable.
    accessed = [var for variable in variables if variable.trainable
                for var in strategy.experimental_local_results(variable)]

  for var in accessed:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)
143
144
def pop_tape(tape):
  """Removes the given tape from the tape stack."""
  raw_tape = tape._tape  # pylint: disable=protected-access
  pywrap_tfe.TFE_Py_TapeSetRemove(raw_tape)
148
149
@contextlib.contextmanager
def stop_recording():
  """Stop all gradient recording (backprop and forwardprop)."""
  # If recording is already stopped (e.g. a nested stop_recording scope),
  # this scope is a no-op and must not restart recording on exit.
  already_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
  try:
    if not already_stopped:
      pywrap_tfe.TFE_Py_TapeSetStopOnThread()
    yield
  finally:
    if not already_stopped:
      pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
161
162
def should_record_backprop(tensors):
  """Returns true if any tape in the stack watches any of these tensors.

  Only takes GradientTapes into account, not forward accumulators.

  Args:
    tensors: Tensors to check, typically inputs to an operation.

  Returns:
    Boolean, whether any tape watches any of `tensors`.
  """
  watched = pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)
  return watched
175
176
def record_operation(op_type, output_tensors, input_tensors, backward_function,
                     forward_function=None):
  """Records the operation on all tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetRecordOperation(
      op_type, output_tensors, input_tensors, backward_function,
      forward_function)
183
184
def record_operation_backprop_only(op_type, output_tensors, input_tensors,
                                   backward_function):
  """Records the operation on all backward tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(
      op_type, output_tensors, input_tensors, backward_function)
191
192
def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
                                      backward_function,
                                      forwardprop_output_indices):
  """Records the operation on all forward accumulators in the stack.

  Args:
    op_type: a string for the operation type, used in the backprop code
    output_tensors: a list of Python Tensor objects output by the operation
    input_tensors: a list of input Tensors to the recorded operation
    backward_function: the function to be called to, given the gradients of
      the output tensors, produce the gradients of the input tensors. This
      function is automatically transposed to produce output gradients given
      input gradients.
    forwardprop_output_indices: indicates any output_tensors which contain
      JVPs. Typically these will have come from TFE_Py_PackForwardGradients.
      May be None or an empty sequence if there are no JVP outputs from the
      operation.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
      op_type, output_tensors, input_tensors, backward_function,
      forwardprop_output_indices)
213
214
def delete_trace(tensor_id):
  """Deletes traces for this Tensor from all tapes in the stack."""
  # Tapes are keyed by tensor id, not by the tensor object itself.
  pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)
218
219
def could_possibly_record():
  """Returns True if any tape is active."""
  no_active_tapes = pywrap_tfe.TFE_Py_TapeSetIsEmpty()
  return not no_active_tapes
223