1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Run doctests for tensorflow.""" 16 17import doctest 18import re 19import textwrap 20 21import numpy as np 22 23 24class _FloatExtractor(object): 25 """Class for extracting floats from a string. 26 27 For example: 28 29 >>> text_parts, floats = _FloatExtractor()("Text 1.0 Text") 30 >>> text_parts 31 ["Text ", " Text"] 32 >>> floats 33 np.array([1.0]) 34 """ 35 36 # Note: non-capturing groups "(?" are not returned in matched groups, or by 37 # re.split. 38 _FLOAT_RE = re.compile( 39 r""" 40 ( # Captures the float value. 41 (?: 42 [-+]| # Start with a sign is okay anywhere. 43 (?: # Otherwise: 44 ^| # Start after the start of string 45 (?<=[^\w.]) # Not after a word char, or a . 46 ) 47 ) 48 (?: # Digits and exponent - something like: 49 {digits_dot_maybe_digits}{exponent}?| # "1.0" "1." "1.0e3", "1.e3" 50 {dot_digits}{exponent}?| # ".1" ".1e3" 51 {digits}{exponent}| # "1e3" 52 {digits}(?=j) # "300j" 53 ) 54 ) 55 j? # Optional j for cplx numbers, not captured. 56 (?= # Only accept the match if 57 $| # * At the end of the string, or 58 [^\w.] # * Next char is not a word char or "." 59 ) 60 """.format( 61 # Digits, a "." and optional more digits: "1.1". 62 digits_dot_maybe_digits=r'(?:[0-9]+\.(?:[0-9]*))', 63 # A "." with trailing digits ".23" 64 dot_digits=r'(?:\.[0-9]+)', 65 # digits: "12" 66 digits=r'(?:[0-9]+)', 67 # The exponent: An "e" or "E", optional sign, and at least one digit. 68 # "e-123", "E+12", "e12" 69 exponent=r'(?:[eE][-+]?[0-9]+)'), 70 re.VERBOSE) 71 72 def __call__(self, string): 73 """Extracts floats from a string. 74 75 >>> text_parts, floats = _FloatExtractor()("Text 1.0 Text") 76 >>> text_parts 77 ["Text ", " Text"] 78 >>> floats 79 np.array([1.0]) 80 81 Args: 82 string: the string to extract floats from. 83 84 Returns: 85 A (string, array) pair, where `string` has each float replaced by "..." 86 and `array` is a `float32` `numpy.array` containing the extracted floats. 87 """ 88 texts = [] 89 floats = [] 90 for i, part in enumerate(self._FLOAT_RE.split(string)): 91 if i % 2 == 0: 92 texts.append(part) 93 else: 94 floats.append(float(part)) 95 96 return texts, np.array(floats) 97 98 99class TfDoctestOutputChecker(doctest.OutputChecker, object): 100 """Customizes how `want` and `got` are compared, see `check_output`.""" 101 102 def __init__(self, *args, **kwargs): 103 super(TfDoctestOutputChecker, self).__init__(*args, **kwargs) 104 self.extract_floats = _FloatExtractor() 105 self.text_good = None 106 self.float_size_good = None 107 108 _ADDRESS_RE = re.compile(r'\bat 0x[0-9a-f]*?>') 109 # TODO(yashkatariya): Add other tensor's string substitutions too. 110 # tf.RaggedTensor doesn't need one. 111 _NUMPY_OUTPUT_RE = re.compile(r'<tf.Tensor.*?numpy=(.*?)>', re.DOTALL) 112 113 def _allclose(self, want, got, rtol=1e-3, atol=1e-3): 114 return np.allclose(want, got, rtol=rtol, atol=atol) 115 116 def _tf_tensor_numpy_output(self, string): 117 modified_string = self._NUMPY_OUTPUT_RE.sub(r'\1', string) 118 return modified_string, modified_string != string 119 120 MESSAGE = textwrap.dedent("""\n 121 ############################################################# 122 Check the documentation (https://www.tensorflow.org/community/contribute/docs_ref) on how to 123 write testable docstrings. 124 #############################################################""") 125 126 def check_output(self, want, got, optionflags): 127 """Compares the docstring output to the output gotten by running the code. 128 129 Python addresses in the output are replaced with wildcards. 130 131 Float values in the output compared as using `np.allclose`: 132 133 * Float values are extracted from the text and replaced with wildcards. 134 * The wildcard text is compared to the actual output. 135 * The float values are compared using `np.allclose`. 136 137 The method returns `True` if both the text comparison and the numeric 138 comparison are successful. 139 140 The numeric comparison will fail if either: 141 142 * The wrong number of floats are found. 143 * The float values are not within tolerence. 144 145 Args: 146 want: The output in the docstring. 147 got: The output generated after running the snippet. 148 optionflags: Flags passed to the doctest. 149 150 Returns: 151 A bool, indicating if the check was successful or not. 152 """ 153 154 # If the docstring's output is empty and there is some output generated 155 # after running the snippet, return True. This is because if the user 156 # doesn't want to display output, respect that over what the doctest wants. 157 if got and not want: 158 return True 159 160 if want is None: 161 want = '' 162 163 # Replace python's addresses with ellipsis (`...`) since it can change on 164 # each execution. 165 want = self._ADDRESS_RE.sub('at ...>', want) 166 167 # Replace tf.Tensor strings with only their numpy field values. 168 want, want_changed = self._tf_tensor_numpy_output(want) 169 if want_changed: 170 got, _ = self._tf_tensor_numpy_output(got) 171 172 # Separate out the floats, and replace `want` with the wild-card version 173 # "result=7.0" => "result=..." 174 want_text_parts, self.want_floats = self.extract_floats(want) 175 want_text_parts = [part.strip() for part in want_text_parts] 176 want_text_wild = '...'.join(want_text_parts) 177 178 # Find the floats in the string returned by the test 179 _, self.got_floats = self.extract_floats(got) 180 181 self.text_good = super(TfDoctestOutputChecker, self).check_output( 182 want=want_text_wild, got=got, optionflags=optionflags) 183 if not self.text_good: 184 return False 185 186 if self.want_floats.size == 0: 187 # If there are no floats in the "want" string, ignore all the floats in 188 # the result. "np.array([ ... ])" matches "np.array([ 1.0, 2.0 ])" 189 return True 190 191 self.float_size_good = (self.want_floats.size == self.got_floats.size) 192 193 if self.float_size_good: 194 return self._allclose(self.want_floats, self.got_floats) 195 else: 196 return False 197 198 def output_difference(self, example, got, optionflags): 199 got = [got] 200 201 # If the some of the float output is hidden with `...`, `float_size_good` 202 # will be False. This is because the floats extracted from the string is 203 # converted into a 1-D numpy array. Hence hidding floats is not allowed 204 # anymore. 205 if self.text_good: 206 if not self.float_size_good: 207 got.append("\n\nCAUTION: tf_doctest doesn't work if *some* of the " 208 "*float output* is hidden with a \"...\".") 209 210 got.append(self.MESSAGE) 211 got = '\n'.join(got) 212 return (super(TfDoctestOutputChecker, 213 self).output_difference(example, got, optionflags)) 214