xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/debug_gradients.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Debugger: Tools for debugging gradients."""
16
17import re
18import uuid
19
20from tensorflow.python.debug.lib import debug_data
21from tensorflow.python.debug.lib import debug_graphs
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import gen_array_ops
24from tensorflow.python.ops import variables
25
# Tag embedded in the name of every gradient-debugging identity op; the
# owning debugger's UUID is appended to it (see _tensor_to_grad_debug_op_name).
_GRADIENT_DEBUG_TAG = "gradient_debug_"

# Global registry mapping GradientsDebugger UUIDs to instances, so the
# registered gradient functions can locate the debugger that created an op.
_gradient_debuggers = {}
29
30
def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid):
  """Compose the gradient-debug op name for a tensor.

  Args:
    tensor: the `tf.Tensor` being instrumented for gradient debugging.
    grad_debugger_uuid: hex UUID string of the owning `GradientsDebugger`.

  Returns:
    An op name of the form "<op_name>_<slot>/gradient_debug_<uuid>", which
    `_parse_grad_debug_op_name` can later invert.
  """
  base_op_name, output_slot = debug_graphs.parse_node_or_tensor_name(
      tensor.name)
  return "{}_{}/{}{}".format(base_op_name, output_slot, _GRADIENT_DEBUG_TAG,
                             grad_debugger_uuid)
34
35
def _parse_grad_debug_op_name(op_name):
  """Parse the name of a debug gradient op.

  Args:
    op_name: the name of the debug gradient op.

  Returns:
    1) The UUID of the GradientsDebugger that created the debug gradient op.
    2) Name of the original tensor whose gradient is debugged by the debug
       gradient op.
  """
  scopes = op_name.split("/")
  assert len(scopes) > 1
  assert scopes[-1].startswith(_GRADIENT_DEBUG_TAG)

  # The UUID is everything after the tag, up to (but not including) any
  # further "_" that TensorFlow may have appended to uniquify the op name.
  grad_debugger_uuid = scopes[-1][len(_GRADIENT_DEBUG_TAG):].partition("_")[0]

  # The enclosing scope encodes the original op name and output slot as
  # "<op_name>_<slot>" (see _tensor_to_grad_debug_op_name).
  parent_scope = scopes[-2]
  sep_index = parent_scope.rfind("_")
  orig_tensor_slot = int(parent_scope[sep_index + 1:])
  orig_base_op_name = parent_scope[:sep_index]
  orig_tensor_name = "%s:%d" % (
      "/".join(scopes[:-2] + [orig_base_op_name]), orig_tensor_slot)

  return grad_debugger_uuid, orig_tensor_name
60
61
class GradientsDebugger:
  """Gradients Debugger.

  Allows retrieval of gradient tensors created by TensorFlow's automatic
  differentiation algorithm, i.e., `tf.gradients` and optimizer classes that
  use it.
  """
  # TODO(cais): Add examples code in the doc string?

  def __init__(self, y_tensor=None):
    """Constructor of GradientsDebugger.

    Args:
      y_tensor: optional: the `tf.Tensor` to be differentiated, i.e., the tensor
        on the numerator of the differentiation.
    """

    # Unique hex ID that ties debug-gradient op names back to this instance
    # through the module-level registry.
    self._uuid = uuid.uuid4().hex
    _gradient_debuggers[self._uuid] = self

    # A dict mapping x-tensor names to gradient tensor. x-tensor refers to the
    # independent tf.Tensor, i.e., the tensor on the denominator of the
    # differentiation.
    self._gradient_tensors = {}
    self._y_tensor = y_tensor

    # Graph that registered gradient tensors must belong to. If no y_tensor
    # is supplied, this is set lazily from the first registered tensor (see
    # _check_same_graph).
    self._graph = None
    if y_tensor:
      self._graph = y_tensor.graph

    # True while inside this debugger's `with` context. Gates gradient
    # registration when more than one debugger exists (see
    # register_gradient_tensor).
    self._is_active_context = False

  @property
  def y_tensor(self):
    # The differentiated (numerator) tensor supplied at construction, or None.
    return self._y_tensor

  @property
  def graph(self):
    # The tf.Graph this debugger is bound to, or None if not yet determined.
    return self._graph

  def __enter__(self):
    # Mark this debugger as the active one so gradient registration occurs
    # even when multiple debuggers are alive.
    self._is_active_context = True

  def __exit__(self, unused_type, unused_value, unused_traceback):
    self._is_active_context = False

  def identify_gradient(self, input_tensor):
    """Create a debug identity tensor that registers and forwards gradients.

    The side effect of this method is that when gradient tensor(s) are created
    with respect to any paths that include the `input_tensor`, the gradient
    tensor(s) with respect to `input_tensor` will be registered with this
    `GradientsDebugger` instance and can later be retrieved, with the
    methods `gradient_tensor` and `gradient_tensors`.

    Example:

    ```python
    x = tf.Variable(1.0)
    y = tf.add(x, x)

    grad_debugger = tf_debug.GradientsDebugger()
    debug_y = grad_debugger.identify_gradient(y)
    z = tf.square(debug_y)

    # Create a train op under the grad_debugger context.
    with grad_debugger:
      train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(z)

    # Now we can reflect through grad_debugger to get the gradient tensor
    # with respect to y.
    y_grad = grad_debugger.gradient_tensor(y)
    ```

    Args:
      input_tensor: the input `tf.Tensor` object whose related gradient tensors
        are to be registered with this `GradientsDebugger` instance when they
        are created, e.g., during `tf.gradients` calls or the construction
        of optimization (training) op that uses `tf.gradients`.

    Returns:
      A forwarded identity of `input_tensor`, as a `tf.Tensor`.

    Raises:
      ValueError: If an op with name that duplicates the gradient-debugging op
        already exists in the graph (highly unlikely).
    """
    # TODO(cais): Allow overriding gradient.
    # TODO(cais): Implement value_stack.
    grad_debug_op_name = _tensor_to_grad_debug_op_name(input_tensor, self._uuid)
    # Ref-dtype tensors (e.g. variables accessed by reference) need the
    # ref-identity variant of the debug op.
    # pylint: disable=protected-access
    identity_op = (
        gen_array_ops.debug_gradient_ref_identity
        if input_tensor.dtype._is_ref_dtype else
        gen_array_ops.debug_gradient_identity)
    # pylint: enable=protected-access
    debug_grad_identity = identity_op(input_tensor, name=grad_debug_op_name)
    assert debug_grad_identity.dtype == input_tensor.dtype
    # If TensorFlow uniquified the requested name, the name already existed
    # in the graph; the encoded UUID/tensor name would then be unparseable.
    if debug_grad_identity.op.name != grad_debug_op_name:
      raise ValueError(
          "The graph already contains an op named %s" % grad_debug_op_name)
    return debug_grad_identity

  def watch_gradients_by_tensors(self, graph, tensors):
    """Watch gradient tensors by x-tensor(s).

    The side effect of this method is that when gradient tensor(s) are created
    with respect to any paths that include the `x_tensor`s, the gradient
    tensor(s) with respect to the tensor will be registered with this
    `GradientsDebugger` instance and can later be retrieved, with the
    methods `gradient_tensor` and `gradient_tensors`.

    Unlike the method `identify_gradient`, this method is used to retrieve
    gradient tensors after the construction of the forward subgraph has
    completed (but before the construction of the backward subgraph).

    This method is the same as `watch_gradients_by_x_tensor_names` except that
    the tensors are specified by the Python `tf.Tensor` or `tf.Variable`
    objects, instead by name patterns.

    Example:

    ```python
    x = tf.Variable(1.0)
    y = tf.add(x, x, name="y")
    z = tf.square(y)

    # Create a train op under the grad_debugger context.
    grad_debugger = tf_debug.GradientsDebugger()
    with grad_debugger.watch_gradients_by_tensors(
        tf.compat.v1.get_default_graph(), y):
      train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(z)

    # Now we can reflect through grad_debugger to get the gradient tensor
    # with respect to y.
    y_grad = grad_debugger.gradient_tensor(y)
    # or
    y_grad = grad_debugger.gradient_tensor("y:0")
    ```

    Args:
      graph: the `tf.Graph` to watch the gradients on.
      tensors: a `tf.Tensor` or `tf.Variable` object, or a list of such objects.

    Returns:
      The GradientsDebugger instance itself.
    """

    if not isinstance(tensors, list):
      tensors = [tensors]

    # Build an exact-match alternation pattern from the tensor names and
    # delegate to the name-based watcher.
    tensor_name_regex = []
    for tensor in tensors:
      tensor_name_regex.append(re.escape(tensor.name) + "$")
    tensor_name_regex = "(" + "|".join(tensor_name_regex) + ")"
    return self.watch_gradients_by_tensor_names(graph, tensor_name_regex)

  def watch_gradients_by_tensor_names(self, graph, tensor_name_regex):
    """Watch gradient tensors by name(s) of the x-tensor(s).

    The side effect of this method is that when gradient tensor(s) are created
    with respect to the x-tensors, the gradient tensor(s) will be registered
    with this `GradientsDebugger` instance and can later be retrieved.

    Unlike the `identify_gradient` method, this method is used after the
    construction of the forward graph has completed. Unlike the
    `watch_gradients_by_tensor` method, this method does not use handles to the
    tensors of interest; it uses their names.

    This method is the same as `watch_gradients_by_tensors` except that the
    x-tensors are specified by name patterns, instead of `tf.Tensor` or
    `tf.Variable` objects.

    Example:

    ```python
    x = tf.Variable(1.0, name="x")
    y = tf.add(x, x, name="y")
    z = tf.square(y)

    # Create a train op under the grad_debugger context.
    grad_debugger = tf_debug.GradientsDebugger()
    with grad_debugger.watch_gradients_by_tensor_names(
        tf.compat.v1.get_default_graph(), r"(x|y):0$"):
      train_op = tf.compat.v1.train.GradientDescentOptimizer(0.1).minimize(z)

    # Now we can reflect through grad_debugger to get the gradient tensor
    # with respect to x and y.
    x_grad = grad_debugger.gradient_tensor("x:0")
    y_grad = grad_debugger.gradient_tensor("y:0")
    ```

    Args:
      graph: the `tf.Graph` to watch the gradients on.
      tensor_name_regex: the regular-expression pattern of the name(s) of the
        x-tensor(s) to watch. x-tensor refers to the tensors on the denominator
        of the differentiation.

    Returns:
      The GradientsDebugger instance itself.
    """
    tensor_name_pattern = re.compile(tensor_name_regex)
    with graph.as_default():
      for op in graph.get_operations():
        for output in op.outputs:
          if tensor_name_pattern.match(output.name):
            # Splice a debug identity op between the matched tensor and each
            # of its existing consumers, so gradients flow through it.
            debug_op = self.identify_gradient(output)

            # Make a copy of output.consumers() since we'll modify the consumers
            # TODO(skyewm): this is unnecessary once the C API is enabled
            for consumer in list(output.consumers()):
              if consumer == debug_op.op:
                continue

              # Locate the slot index of the original input.
              for i, consumer_input in enumerate(consumer.inputs):
                if consumer_input == output:
                  consumer._update_input(i, debug_op)  # pylint: disable=protected-access
    return self

  def _check_same_graph(self, tensor):
    # Bind this debugger to the first graph it sees; thereafter, reject
    # tensors from any other graph.
    if self._graph is None:
      self._graph = tensor.graph
    elif self._graph != tensor.graph:
      raise ValueError(
          "The graph of the value (%s) is not the same as the graph %s" %
          (tensor.graph, self._graph))

  def register_gradient_tensor(self,
                               x_tensor_name,
                               gradient_tensor):
    """Register the gradient tensor for an x-tensor.

    Args:
      x_tensor_name: (`str`) the name of the independent `tf.Tensor`, i.e.,
        the tensor on the denominator of the differentiation.
      gradient_tensor: the gradient `tf.Tensor`.
    """
    # Register when this is the only debugger alive, or when this debugger is
    # the one explicitly activated via its `with` context.
    if len(_gradient_debuggers) == 1 or self._is_active_context:
      self._check_same_graph(gradient_tensor)
      self._gradient_tensors[x_tensor_name] = gradient_tensor

  def gradient_tensor(self, x_tensor):
    """Get the gradient tensor of an x-tensor.

    Args:
      x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
        name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
        on the denominator of the differentiation.

    Returns:
      If found, the gradient tensor.

    Raises:
      TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
      LookupError: If the `x_tensor` has not been registered with a gradient
        tensor.
    """
    x_tensor_name = self._get_tensor_name(x_tensor)
    if x_tensor_name not in self._gradient_tensors:
      raise LookupError(
          "This GradientsDebugger has not received any gradient tensor for "
          "x-tensor %s" % x_tensor_name)
    return self._gradient_tensors[x_tensor_name]

  def gradient_tensors(self):
    """Get the gradient tensors that this object is aware of.

    Returns:
      A dict mapping x-tensor names to gradient tensor objects. x-tensor refers
      to the tensors on the denominator of the differentiation.
    """
    return self._gradient_tensors

  def _get_tensor_name(self, tensor):
    # Normalize a tensor-like argument to its string name.
    if isinstance(tensor, (ops.Tensor, variables.Variable)):
      return tensor.name
    elif isinstance(tensor, str):
      return tensor
    else:
      raise TypeError(
          "x_tensor must be a str or tf.Tensor or tf.Variable, "
          "but instead has type %s" % type(tensor))
343
344
def clear_gradient_debuggers():
  """Remove every gradient debugger from the global registry."""
  _gradient_debuggers.clear()
348
349
@ops.RegisterGradient("DebugGradientIdentity")
def _identify_gradient_grad(op, dy):
  """Gradient function for the DebugGradientIdentity op.

  Recovers which `GradientsDebugger` instrumented the tensor (from the op
  name), registers the incoming gradient `dy` with it under the original
  tensor's name, and passes `dy` through unchanged (identity gradient).

  Args:
    op: the DebugGradientIdentity `Operation` being differentiated.
    dy: the incoming gradient `tf.Tensor`.

  Returns:
    `dy`, unchanged.
  """
  # TODO(cais): Allow overriding gradient.
  grad_debugger_uuid, orig_tensor_name = _parse_grad_debug_op_name(op.name)
  grad_debugger = _gradient_debuggers[grad_debugger_uuid]
  grad_debugger.register_gradient_tensor(orig_tensor_name, dy)
  return dy
358
359
@ops.RegisterGradient("DebugGradientRefIdentity")
def _identify_gradient_grad_ref(op, dy):
  """Gradient function for the DebugGradientRefIdentity op.

  The ref-dtype variant behaves identically to DebugGradientIdentity, so it
  delegates to `_identify_gradient_grad`.
  """
  return _identify_gradient_grad(op, dy)
364
365
def gradient_values_from_dump(grad_debugger, x_tensor, dump):
  """Find gradient values from a `DebugDumpDir` object.

  Args:
    grad_debugger: the `tf_debug.GradientsDebugger` instance to be used.
    x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
      name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
      on the denominator of the differentiation.
    dump: A `tfdbg.DebugDumpDir` object.

  Returns:
    If this `GradientsDebugger` instance has the gradient tensor of `x_tensor`
      registered: a list of `numpy.ndarray` representing the value of the
      gradient tensor from `dump`. The list could be empty, if the gradient
      tensor is not executed in the `tf.Session.run()` call that generated
      the `dump`. The list could also contain multiple values of the gradient
      tensor, e.g., if gradient tensor is computed repeatedly in a
      `tf.while_loop` during the run that generated the `dump`.

  Raises:
    LookupError: If this `GradientsDebugger` instance does not have the
      gradient tensor of `x_tensor` registered.
    ValueError: If this `GradientsDebugger` has a `tf.Graph` object that
      does not match the `tf.Graph` object of the `dump`.
    TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
  """
  # TODO(cais): Use this method in LocalCLIDebugWrapperSession to present the
  # gradient tensors to the TFDBG CLI.

  # When both sides expose a Python graph, verify that the dump and this
  # GradientsDebugger refer to the same one.
  if dump.python_graph and grad_debugger.graph:
    if dump.python_graph != grad_debugger.graph:
      raise ValueError(
          "This GradientsDebugger instance has a graph (%s) that differs from "
          "the graph of the DebugDumpDir object (%s)." %
          (grad_debugger.graph, dump.python_graph))

  # Resolve the registered gradient tensor (raises LookupError/TypeError as
  # documented) and split its name into node name and output slot.
  grad_tensor = grad_debugger.gradient_tensor(x_tensor)
  node_name, output_slot = debug_graphs.parse_node_or_tensor_name(
      grad_tensor.name)

  try:
    return dump.get_tensors(node_name, output_slot, "DebugIdentity")
  except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
    # The gradient tensor was never executed in the dumped run.
    return []
412