xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/docs/tf_doctest_lib.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Run doctests for tensorflow."""
16
17import doctest
18import re
19import textwrap
20
21import numpy as np
22
23
24class _FloatExtractor(object):
25  """Class for extracting floats from a string.
26
27  For example:
28
29  >>> text_parts, floats = _FloatExtractor()("Text 1.0 Text")
30  >>> text_parts
31  ["Text ", " Text"]
32  >>> floats
33  np.array([1.0])
34  """
35
36  # Note: non-capturing groups "(?" are not returned in matched groups, or by
37  # re.split.
38  _FLOAT_RE = re.compile(
39      r"""
40      (                          # Captures the float value.
41        (?:
42           [-+]|                 # Start with a sign is okay anywhere.
43           (?:                   # Otherwise:
44               ^|                # Start after the start of string
45               (?<=[^\w.])       # Not after a word char, or a .
46           )
47        )
48        (?:                      # Digits and exponent - something like:
49          {digits_dot_maybe_digits}{exponent}?|   # "1.0" "1." "1.0e3", "1.e3"
50          {dot_digits}{exponent}?|                # ".1" ".1e3"
51          {digits}{exponent}|                     # "1e3"
52          {digits}(?=j)                           # "300j"
53        )
54      )
55      j?                         # Optional j for cplx numbers, not captured.
56      (?=                        # Only accept the match if
57        $|                       # * At the end of the string, or
58        [^\w.]                   # * Next char is not a word char or "."
59      )
60      """.format(
61          # Digits, a "." and optional more digits: "1.1".
62          digits_dot_maybe_digits=r'(?:[0-9]+\.(?:[0-9]*))',
63          # A "." with trailing digits ".23"
64          dot_digits=r'(?:\.[0-9]+)',
65          # digits: "12"
66          digits=r'(?:[0-9]+)',
67          # The exponent: An "e" or "E", optional sign, and at least one digit.
68          # "e-123", "E+12", "e12"
69          exponent=r'(?:[eE][-+]?[0-9]+)'),
70      re.VERBOSE)
71
72  def __call__(self, string):
73    """Extracts floats from a string.
74
75    >>> text_parts, floats = _FloatExtractor()("Text 1.0 Text")
76    >>> text_parts
77    ["Text ", " Text"]
78    >>> floats
79    np.array([1.0])
80
81    Args:
82      string: the string to extract floats from.
83
84    Returns:
85      A (string, array) pair, where `string` has each float replaced by "..."
86      and `array` is a `float32` `numpy.array` containing the extracted floats.
87    """
88    texts = []
89    floats = []
90    for i, part in enumerate(self._FLOAT_RE.split(string)):
91      if i % 2 == 0:
92        texts.append(part)
93      else:
94        floats.append(float(part))
95
96    return texts, np.array(floats)
97
98
class TfDoctestOutputChecker(doctest.OutputChecker, object):
  """Customizes how `want` and `got` are compared, see `check_output`.

  On top of the standard `doctest.OutputChecker` comparison, this checker:

    * replaces Python object addresses (`at 0x...>`) in `want` with `...`,
    * reduces `<tf.Tensor ... numpy=VALUE>` reprs to `VALUE` in both `want`
      and `got` whenever `want` contains such a repr, and
    * compares the floats found in the output numerically with `np.allclose`
      instead of textually.

  The substitutions above insert `...` into the expected text. Those only act
  as wildcards when the doctest runs with the `doctest.ELLIPSIS` option flag.

  The checker is stateful. `check_output` records the results of its latest
  comparison in `text_good`, `float_size_good`, `want_floats` and `got_floats`,
  and `output_difference` reads them to explain a failure.
  """

  def __init__(self, *args, **kwargs):
    super(TfDoctestOutputChecker, self).__init__(*args, **kwargs)
    self.extract_floats = _FloatExtractor()
    # Results of the most recent `check_output` call.
    self.text_good = None
    self.float_size_good = None
    self.want_floats = None
    self.got_floats = None

  # Default reprs end with "at 0x<address>>", and the address changes on every
  # run. The case of the hex digits depends on the platform, so accept both.
  _ADDRESS_RE = re.compile(r'\bat 0x[0-9a-fA-F]*?>')
  # TODO(yashkatariya): Add other tensor's string substitutions too.
  # tf.RaggedTensor doesn't need one.
  # Matches "<tf.Tensor...numpy=VALUE>" and captures VALUE, i.e. everything
  # after "numpy=" up to the first ">".
  _NUMPY_OUTPUT_RE = re.compile(r'<tf\.Tensor.*?numpy=(.*?)>', re.DOTALL)

  def _allclose(self, want, got, rtol=1e-3, atol=1e-3):
    """Returns whether `want` and `got` are element-wise close (np.allclose)."""
    return np.allclose(want, got, rtol=rtol, atol=atol)

  def _tf_tensor_numpy_output(self, string):
    """Replaces each matched tf.Tensor repr in `string` by its numpy value.

    Args:
      string: The text to rewrite.

    Returns:
      A `(modified_string, changed)` pair. `changed` is True if at least one
      substitution altered the text.
    """
    modified_string = self._NUMPY_OUTPUT_RE.sub(r'\1', string)
    return modified_string, modified_string != string

  MESSAGE = textwrap.dedent("""\n
        #############################################################
        Check the documentation (https://www.tensorflow.org/community/contribute/docs_ref) on how to
        write testable docstrings.
        #############################################################""")

  def check_output(self, want, got, optionflags):
    """Compares the docstring output to the output gotten by running the code.

    Python addresses in the output are replaced with wildcards.

    Float values in the output are compared using `np.allclose`:

      * Float values are extracted from the text and replaced with wildcards.
      * The wildcard text is compared to the actual output.
      * The float values are compared using `np.allclose`.

    The method returns `True` if both the text comparison and the numeric
    comparison are successful.

    The numeric comparison will fail if either:

      * The wrong number of floats are found.
      * The float values are not within tolerance.

    The wildcards are `...`, so the text comparison relies on
    `doctest.ELLIPSIS` being set in `optionflags`.

    Args:
      want: The output in the docstring.
      got: The output generated after running the snippet.
      optionflags: Flags passed to the doctest.

    Returns:
      A bool, indicating if the check was successful or not.
    """
    # Forget the previous call's results, so `output_difference` only ever
    # explains the comparison that was just made.
    self.text_good = None
    self.float_size_good = None

    # If the docstring's output is empty and there is some output generated
    # after running the snippet, return True. This is because if the user
    # doesn't want to display output, respect that over what the doctest wants.
    if got and not want:
      return True

    if want is None:
      want = ''

    # Replace python's addresses with ellipsis (`...`) since it can change on
    # each execution.
    want = self._ADDRESS_RE.sub('at ...>', want)

    # Replace tf.Tensor strings with only their numpy field values.
    want, want_changed = self._tf_tensor_numpy_output(want)
    if want_changed:
      got, _ = self._tf_tensor_numpy_output(got)

    # Separate out the floats, and replace `want` with the wild-card version
    # "result=7.0" => "result=..."
    want_text_parts, self.want_floats = self.extract_floats(want)
    want_text_parts = [part.strip() for part in want_text_parts]
    want_text_wild = '...'.join(want_text_parts)

    # Find the floats in the string returned by the test
    _, self.got_floats = self.extract_floats(got)

    self.text_good = super(TfDoctestOutputChecker, self).check_output(
        want=want_text_wild, got=got, optionflags=optionflags)
    if not self.text_good:
      return False

    if self.want_floats.size == 0:
      # If there are no floats in the "want" string, ignore all the floats in
      # the result. "np.array([ ... ])" matches "np.array([ 1.0, 2.0 ])"
      return True

    # The floats are compared as flat arrays, so their counts must agree
    # before a numeric comparison makes sense.
    self.float_size_good = (self.want_floats.size == self.got_floats.size)
    if not self.float_size_good:
      return False
    return self._allclose(self.want_floats, self.got_floats)

  def output_difference(self, example, got, optionflags):
    """Returns the standard doctest failure report with TF-specific hints.

    Args:
      example: The `doctest.Example` that failed.
      got: The output generated after running the snippet.
      optionflags: Flags passed to the doctest.

    Returns:
      The report string from `doctest.OutputChecker.output_difference`. The
      hints are appended to `got` before that report is built.
    """
    got = [got]

    # If some of the float output in `want` is hidden with `...`, fewer floats
    # are extracted from `want` than from `got`. The text still matches, but
    # `float_size_good` is False, because the extracted floats are compared as
    # flat 1-D numpy arrays. Hiding only some floats is therefore not
    # supported.
    if self.text_good and not self.float_size_good:
      got.append("\n\nCAUTION: tf_doctest doesn't work if *some* of the "
                 "*float output* is hidden with a \"...\".")

    got.append(self.MESSAGE)
    got = '\n'.join(got)
    return (super(TfDoctestOutputChecker,
                  self).output_difference(example, got, optionflags))
214