xref: /aosp_15_r20/external/tensorflow/tensorflow/python/debug/lib/debug_events_monitors.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Monitors for Debug Events in the tfdbg2 format.
16
17Monitors get access to graph-building- and execution-related data
18objects as the DebugDataReader (see `debug_events_reader.py`) reads the
19data in a continuous fashion, via a set of callbacks. This mechanism enables
20hooking custom logic into the DebugEvent reading stream without the need for
21any polling or iterating over the entire data held by DebugDataReader.
22
23This module includes the following built-in hooks:
24  - InfNanMonitor: Monitors infinity and nan values in top-level execution and
25    intra-graph execution events.
26
27When a monitor (subtype of `BaseMonitor`) is constructed with a DebugDataReader
28as the first argument of the constructor call, the monitor is automatically
29registered with the DebugDataReader. For example:
30
31```py
32debug_data_reader = debug_events_reader.DebugDataReader(dump_dir)
33inf_nan_monitor = debug_events_monitors.InfNanMonitor(debug_data_reader)
34
35debug_data_reader.update()
36# `inf_nan_monitor`'s on_* methods will get called as the execution-related
37# and other types of data are read by `debug_data_reader`.
38```
39"""
40import numpy as np
41
42from tensorflow.core.protobuf import debug_event_pb2
43
44
class BaseMonitor(object):
  """Abstract base class for monitors of debug event data.

  Subclasses override the `on_*` callback methods below. A monitor
  registers itself with the `DebugDataReader` passed to its constructor,
  and that reader invokes the callbacks as it reads debug-event data.
  """

  def __init__(self, debug_events_reader):
    self._debug_data_reader = debug_events_reader
    # Self-register so the reader will invoke this monitor's callbacks.
    debug_events_reader._add_monitor(self)  # pylint:disable=protected-access

  def on_execution(self, execution_index, execution):
    """Callback invoked for each top-level execution event.

    Any return value is discarded by the associated DebugDataReader.

    Args:
      execution_index: Index of the top-level execution event, as an int.
      execution: An `Execution` data object describing a top-level op or
        function execution event.
    """

  def on_graph_execution_trace(self,
                               graph_execution_trace_index,
                               graph_execution_trace):
    """Callback invoked for each intra-graph execution event.

    Any return value is discarded by the associated DebugDataReader.

    Args:
      graph_execution_trace_index: Index of the intra-graph execution event,
        as an int.
      graph_execution_trace: A `GraphExecutionTrace` data object for an
        intra-graph tensor event.
    """

  # TODO(cais): Add more monitor methods such as on_graph_op_creation().
78
79
class InfNanAlert(object):
  """Alert for Infinity and NaN values.

  An immutable record describing a single occurrence of inf/nan values in
  one output tensor of an op, in either top-level (eager) execution or
  intra-graph execution. Instances are created and collected by
  `InfNanMonitor`.
  """

  def __init__(self,
               wall_time,
               op_type,
               output_slot,
               size=None,
               num_neg_inf=None,
               num_pos_inf=None,
               num_nan=None,
               execution_index=None,
               graph_execution_trace_index=None):
    """Constructor of InfNanAlert.

    Args:
      wall_time: Wall timestamp of the event that produced the tensor.
      op_type: Type of the op that produced the tensor (e.g., "RealDiv").
      output_slot: Output slot index of the tensor for the op.
      size: Total element count of the tensor, if known.
      num_neg_inf: Count of negative-infinity elements, if known.
      num_pos_inf: Count of positive-infinity elements, if known.
      num_nan: Count of NaN elements, if known.
      execution_index: Index to the top-level execution event, if applicable.
      graph_execution_trace_index: Index to the intra-graph execution trace,
        if applicable.
    """
    self._wall_time = wall_time
    self._op_type = op_type
    self._output_slot = output_slot
    self._size = size
    self._num_neg_inf = num_neg_inf
    self._num_pos_inf = num_pos_inf
    self._num_nan = num_nan
    self._execution_index = execution_index
    self._graph_execution_trace_index = graph_execution_trace_index

  def __repr__(self):
    # Added for debuggability: alerts are typically inspected from the list
    # returned by `InfNanMonitor.alerts()`.
    return (
        "%s(wall_time=%r, op_type=%r, output_slot=%r, size=%r, "
        "num_neg_inf=%r, num_pos_inf=%r, num_nan=%r, execution_index=%r, "
        "graph_execution_trace_index=%r)" % (
            type(self).__name__, self._wall_time, self._op_type,
            self._output_slot, self._size, self._num_neg_inf,
            self._num_pos_inf, self._num_nan, self._execution_index,
            self._graph_execution_trace_index))

  @property
  def wall_time(self):
    return self._wall_time

  @property
  def op_type(self):
    return self._op_type

  @property
  def output_slot(self):
    return self._output_slot

  @property
  def size(self):
    return self._size

  @property
  def num_neg_inf(self):
    return self._num_neg_inf

  @property
  def num_pos_inf(self):
    return self._num_pos_inf

  @property
  def num_nan(self):
    return self._num_nan

  @property
  def execution_index(self):
    return self._execution_index

  @property
  def graph_execution_trace_index(self):
    return self._graph_execution_trace_index
138
139
class InfNanMonitor(BaseMonitor):
  """Monitor for Infinity and NaN in tensor values."""

  def __init__(self, debug_events_reader, limit=0):
    """Constructor of InfNanMonitor.

    Args:
      debug_events_reader: The DebugDataReader to register this monitor with.
      limit: If > 0, stop checking new events once this many alerts have
        been accumulated (for efficiency). A value <= 0 means no limit.
    """
    super(InfNanMonitor, self).__init__(debug_events_reader)
    self._limit = limit  # Track only the first _ alert events, for efficiency.
    self._alerts = []

  def _add_alert(self,
                 wall_time,
                 op_type,
                 output_slot,
                 size=None,
                 num_neg_inf=None,
                 num_pos_inf=None,
                 num_nan=None,
                 execution_index=None,
                 graph_execution_trace_index=None):
    """Append one InfNanAlert to the tracked alerts.

    Consolidates the alert-construction code previously duplicated across
    the full-tensor path and the CURT_HEALTH / CONCISE_HEALTH / FULL_HEALTH
    branches.
    """
    self._alerts.append(InfNanAlert(
        wall_time,
        op_type,
        output_slot,
        size=size,
        num_neg_inf=num_neg_inf,
        num_pos_inf=num_pos_inf,
        num_nan=num_nan,
        execution_index=execution_index,
        graph_execution_trace_index=graph_execution_trace_index))

  def _check_full_tensor_value(self,
                               tensor_value,
                               wall_time,
                               op_type,
                               output_slot,
                               execution_index=None,
                               graph_execution_trace_index=None):
    """Check a full tensor value.

    Appends to the list of alerts if any inf or nan is found in the full tensor
    value.

    Args:
      tensor_value: The full tensor value as a `np.ndarray`.
      wall_time: Wall timestamp for the execution event that generated the
        tensor value.
      op_type: Op type executed.
      output_slot: The output slot of the op.
      execution_index: Index to the top-level execution event.
      graph_execution_trace_index: Index to the intra-graph execution trace
        (if applicable.)
    """
    size = np.size(tensor_value)
    # Only floating-point dtypes can hold inf/nan; empty tensors cannot
    # contain any bad values, so skip both cases.
    if not size or not np.issubdtype(tensor_value.dtype, np.floating):
      return
    is_inf = np.isinf(tensor_value)
    num_neg_inf = np.count_nonzero(
        np.logical_and(is_inf, np.less(tensor_value, 0.0)))
    num_pos_inf = np.count_nonzero(
        np.logical_and(is_inf, np.greater(tensor_value, 0.0)))
    num_nan = np.count_nonzero(np.isnan(tensor_value))
    if num_neg_inf or num_pos_inf or num_nan:
      self._add_alert(
          wall_time,
          op_type,
          output_slot,
          size=size,
          num_neg_inf=num_neg_inf,
          num_pos_inf=num_pos_inf,
          num_nan=num_nan,
          execution_index=execution_index,
          graph_execution_trace_index=graph_execution_trace_index)

  def _check_debug_tensor_value(self,
                                tensor_debug_mode,
                                debug_tensor_value,
                                wall_time,
                                op_type,
                                output_slot,
                                execution_index=None,
                                graph_execution_trace_index=None):
    """Check for bad numerical values based on debug summary of tensor value.

    If tensor_debug_mode is one in which debug_tensor_value does not carry
    information about the presence or count of inf / nan values (e.g., SHAPE),
    this method is a no-op.

    When infs and/or nans are found, `InfNanAlert` objects are created and
    appended to `self._alerts`.

    Args:
      tensor_debug_mode: TensorDebugMode proto enum.
      debug_tensor_value: Debug tensor value as a list of numbers.
      wall_time: Wall timestamp for the tensor event.
      op_type: Type of the op that generated the tensor (e.g., "Conv2D").
      output_slot: Output slot index of the tensor for the op.
      execution_index: Top-level execution index.
      graph_execution_trace_index: Intra-graph execution index.
    """
    # FULL_TENSOR mode is handled by a separate code path
    # (_check_full_tensor_value).
    assert tensor_debug_mode != debug_event_pb2.TensorDebugMode.FULL_TENSOR
    if not debug_tensor_value:
      return
    if tensor_debug_mode == debug_event_pb2.TensorDebugMode.CURT_HEALTH:
      # CURT_HEALTH carries only a single inf-or-nan flag, so no size or
      # per-category counts are available for the alert.
      _, any_nan_inf = debug_tensor_value
      if any_nan_inf:
        self._add_alert(
            wall_time,
            op_type,
            output_slot,
            execution_index=execution_index,
            graph_execution_trace_index=graph_execution_trace_index)
    elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.CONCISE_HEALTH:
      _, size, num_neg_inf, num_pos_inf, num_nan = debug_tensor_value
      if num_neg_inf or num_pos_inf or num_nan:
        self._add_alert(
            wall_time,
            op_type,
            output_slot,
            size=size,
            num_neg_inf=num_neg_inf,
            num_pos_inf=num_pos_inf,
            num_nan=num_nan,
            execution_index=execution_index,
            graph_execution_trace_index=graph_execution_trace_index)
    elif tensor_debug_mode == debug_event_pb2.TensorDebugMode.FULL_HEALTH:
      # FULL_HEALTH packs additional fields (dtype, rank, etc.); only the
      # size and inf/nan counts are relevant here.
      (_, _, _, _, size, num_neg_inf, num_pos_inf, num_nan,
       _, _, _) = debug_tensor_value
      if num_neg_inf or num_pos_inf or num_nan:
        self._add_alert(
            wall_time,
            op_type,
            output_slot,
            size=size,
            num_neg_inf=num_neg_inf,
            num_pos_inf=num_pos_inf,
            num_nan=num_nan,
            execution_index=execution_index,
            graph_execution_trace_index=graph_execution_trace_index)

  def on_execution(self,
                   execution_index,
                   execution):
    """Monitor method for top-level Execution data objects."""
    # Skip further checking once the alert limit has been reached.
    if self._limit > 0 and len(self._alerts) >= self._limit:
      return
    if (execution.tensor_debug_mode ==
        debug_event_pb2.TensorDebugMode.FULL_TENSOR):
      tensor_values = self._debug_data_reader.execution_to_tensor_values(
          execution)
      for output_slot, tensor_value in enumerate(tensor_values):
        self._check_full_tensor_value(
            tensor_value, execution.wall_time, execution.op_type, output_slot,
            execution_index=execution_index)
    elif execution.debug_tensor_values:
      for output_slot, debug_tensor_value in enumerate(
          execution.debug_tensor_values):
        self._check_debug_tensor_value(
            execution.tensor_debug_mode,
            debug_tensor_value,
            execution.wall_time,
            execution.op_type,
            output_slot,
            execution_index=execution_index)

  def on_graph_execution_trace(self,
                               graph_execution_trace_index,
                               graph_execution_trace):
    """Monitor method for GraphExecutionTrace data object."""
    # Skip further checking once the alert limit has been reached.
    if self._limit > 0 and len(self._alerts) >= self._limit:
      return
    if (graph_execution_trace.tensor_debug_mode ==
        debug_event_pb2.TensorDebugMode.FULL_TENSOR):
      tensor_value = (
          self._debug_data_reader.graph_execution_trace_to_tensor_value(
              graph_execution_trace))
      self._check_full_tensor_value(
          tensor_value, graph_execution_trace.wall_time,
          graph_execution_trace.op_type, graph_execution_trace.output_slot,
          graph_execution_trace_index=graph_execution_trace_index)
    elif graph_execution_trace.debug_tensor_value:
      self._check_debug_tensor_value(
          graph_execution_trace.tensor_debug_mode,
          graph_execution_trace.debug_tensor_value,
          graph_execution_trace.wall_time,
          graph_execution_trace.op_type,
          graph_execution_trace.output_slot,
          graph_execution_trace_index=graph_execution_trace_index)

  def alerts(self):
    """Return the list of InfNanAlert objects accumulated so far."""
    return self._alerts
308