1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A Python test reporter that generates test reports in JUnit XML format."""
16
17import datetime
18import re
19import sys
20import threading
21import time
22import traceback
23import unittest
24from xml.sax import saxutils
25from absl.testing import _pretty_print_reporter
26
27
# Code points below 0x20 that XML 1.0 does not allow in a document; tab, line
# feed and carriage return are the only legal characters in that range.
# See http://www.w3.org/TR/REC-xml/#NT-Char
_bad_control_character_codes = set(range(0, 0x20)) - {0x9, 0xA, 0xD}


# Maps each illegal control character to a printable backslash-hex spelling
# (e.g. chr(0x1b) becomes the four characters \x1b) so that it stays visible
# in the report without making the XML invalid.
_control_character_conversions = {
    chr(i): '\\x{:02x}'.format(i) for i in _bad_control_character_codes}


# Replacements applied when escaping attribute values, on top of the &, < and
# > escaping that saxutils.escape always performs. The values must be XML
# entity / character references:
# - a double quote would otherwise end the double-quoted attribute value;
# - newline, tab and carriage return would otherwise be turned into plain
#   spaces by XML attribute-value normalization.
_escape_xml_attr_conversions = {
    '"': '&quot;',
    "'": '&apos;',
    '\n': '&#xA;',
    '\t': '&#x9;',
    '\r': '&#xD;',
    ' ': '&#x20;'}
_escape_xml_attr_conversions.update(_control_character_conversions)


# When class or module level function fails, unittest/suite.py adds a
# _ErrorHolder instance instead of a real TestCase, and it has a description
# like "setUpClass (__main__.MyTestCase)".
_CLASS_OR_MODULE_LEVEL_TEST_DESC_REGEX = re.compile(r'^(\w+) \((\S+)\)$')
50
51
# NOTE: saxutils.quoteattr() is meant to do the same job, but in practice it
# tends to be too clever for its own good and does not always escape the way
# this reporter needs. Escaping explicitly here is far more reliable.
def _escape_xml_attr(content):
  """Escapes a string for use inside a double-quoted XML attribute value.

  Args:
    content: the raw attribute value.

  Returns:
    content with &, < and > escaped by saxutils.escape, and every character
    listed in _escape_xml_attr_conversions replaced by its mapped form.
  """
  # saxutils.escape leaves quotes alone, so they are covered by the extra
  # entity map instead.
  return saxutils.escape(content, entities=_escape_xml_attr_conversions)
59
60
def _escape_cdata(s):
  """Makes a string safe to place inside an XML CDATA section.

  Text in a CDATA section is not parsed as markup, but it still may not
  contain the control characters XML forbids, nor the ']]>' sequence that
  terminates the section.

  Args:
    s: the string to be escaped.
  Returns:
    A copy of s in which every forbidden control character is replaced by its
    printable backslash-hex form and each ']]>' is broken up as ']] >'.
  """
  without_controls = s.translate(str.maketrans(_control_character_conversions))
  return without_controls.replace(']]>', ']] >')
75
76
def _iso8601_timestamp(timestamp):
  """Converts an epoch timestamp into an ISO 8601 string in UTC.

  Args:
    timestamp: seconds since the Unix epoch; may be None.

  Returns:
    The timestamp rendered by datetime.isoformat() with a UTC offset, or None
    when timestamp is None or negative (-1 is used as the "unset" marker for
    start times elsewhere in this module).
  """
  if timestamp is None:
    return None
  if timestamp < 0:
    return None
  moment = datetime.datetime.fromtimestamp(timestamp,
                                           tz=datetime.timezone.utc)
  return moment.isoformat()
90
91
def _print_xml_element_header(element, attributes, stream, indentation=''):
  """Writes the opening tag of an XML element, terminated by a newline.

  Args:
    element: element name (testsuites, testsuite, testcase).
    attributes: sequence of (name, value) pairs whose values are already
      escaped; entries whose length is not 2, or whose name or value is None,
      are skipped.
    stream: output stream the test report XML is written to.
    indentation: string written before the opening '<'.
  """
  stream.write('%s<%s' % (indentation, element))
  for pair in attributes:
    if len(pair) != 2:
      continue
    attr_name, attr_value = pair[0], pair[1]
    if attr_name is None or attr_value is None:
      continue
    stream.write(' %s="%s"' % (attr_name, attr_value))
  stream.write('>\n')
107
# Keep a private reference to the real time.time, so that tests which stub
# out time.time do not disturb the timings this reporter records.
_time_copy = time.time

# Prefer the traceback module's own (private) _some_str helper to format
# values safely; fall back to the built-in str when it is unavailable.
_safe_str = getattr(traceback, '_some_str', str)  # pylint: disable=invalid-name
117
118
class _TestCaseResult(object):
  """Private helper for _TextAndXMLTestResult that represents a test result.

  Attributes:
    test: A TestCase instance of an individual test method.
    name: The name of the individual test method, already escaped for use in
        an XML attribute.
    full_class_name: The full name of the test class, already escaped for use
        in an XML attribute.
    run_time: The duration (in seconds) it took to run the test; -1 until set.
    start_time: Epoch relative timestamp of when test started (in seconds);
        -1 until set.
    errors: A list of error 4-tuples. Error tuple entries are
        1) a string identifier of either "failure" or "error"
        2) an exception_type
        3) an exception_message
        4) a string version of a sys.exc_info()-style tuple of values
           ('error', err[0], err[1], self._exc_info_to_string(err))
           If the length of errors is 0, then the test is either passed or
           skipped.
    skip_reason: A string explaining why the test was skipped, or None if the
        test was not skipped.
  """

  def __init__(self, test):
    self.run_time = -1
    self.start_time = -1
    self.skip_reason = None
    self.errors = []
    self.test = test

    # Parse the test id to get its test name and full class path.
    # Unfortunately there is no better way of knowing the test and class.
    # Worse, unittest uses _ErrorHolder instances to represent class / module
    # level failures.
    test_desc = test.id() or str(test)
    # Check if it's something like "setUpClass (__main__.TestCase)".
    match = _CLASS_OR_MODULE_LEVEL_TEST_DESC_REGEX.match(test_desc)
    if match:
      name = match.group(1)
      full_class_name = match.group(2)
    else:
      class_name = unittest.util.strclass(test.__class__)
      if isinstance(test, unittest.case._SubTest):
        # If the test case is a _SubTest, the real TestCase instance is
        # available as _SubTest.test_case.
        class_name = unittest.util.strclass(test.test_case.__class__)
      if test_desc.startswith(class_name + '.'):
        # In a typical unittest.TestCase scenario, test.id() returns with
        # a class name formatted using unittest.util.strclass.
        name = test_desc[len(class_name)+1:]
        full_class_name = class_name
      else:
        # Otherwise make a best effort to guess the test name and full class
        # path.
        parts = test_desc.rsplit('.', 1)
        name = parts[-1]
        full_class_name = parts[0] if len(parts) == 2 else ''
    self.name = _escape_xml_attr(name)
    self.full_class_name = _escape_xml_attr(full_class_name)

  def set_run_time(self, time_in_secs):
    """Sets how long the test took to run, in seconds."""
    self.run_time = time_in_secs

  def set_start_time(self, time_in_secs):
    """Sets when the test started, in seconds since the epoch."""
    self.start_time = time_in_secs

  def print_xml_summary(self, stream):
    """Prints an XML Summary of a TestCase.

    Status and result are populated as per JUnit XML test result reporter.
    The test is reported as skipped (status="notrun", result="suppressed")
    exactly when skip_reason is not None, and as status="run",
    result="completed" otherwise. Any recorded failures/errors are written as
    child elements in both cases.

    Args:
      stream: output stream to write test report XML to
    """

    if self.skip_reason is None:
      status = 'run'
      result = 'completed'
    else:
      status = 'notrun'
      result = 'suppressed'

    test_case_attributes = [
        ('name', self.name),
        ('status', status),
        ('result', result),
        ('time', '%.3f' % self.run_time),
        ('classname', self.full_class_name),
        ('timestamp', _iso8601_timestamp(self.start_time)),
    ]
    _print_xml_element_header('testcase', test_case_attributes, stream, '  ')
    self._print_testcase_details(stream)
    stream.write('  </testcase>\n')

  def _print_testcase_details(self, stream):
    """Writes one <failure> or <error> element per recorded error.

    The message and exception type are written as escaped attributes; the
    formatted traceback is written as a CDATA section.

    Args:
      stream: output stream to write test report XML to
    """
    for error in self.errors:
      outcome, exception_type, message, error_msg = error  # pylint: disable=unpacking-non-sequence
      message = _escape_xml_attr(_safe_str(message))
      exception_type = _escape_xml_attr(str(exception_type))
      error_msg = _escape_cdata(error_msg)
      stream.write('  <%s message="%s" type="%s"><![CDATA[%s]]></%s>\n'
                   % (outcome, message, exception_type, error_msg, outcome))
221
222
class _TestSuiteResult(object):
  """Private helper for _TextAndXMLTestResult.

  Collects test case results grouped into suites (one per test class), tracks
  per-suite failure and error counts, and writes the whole run as a
  JUnit-style <testsuites> XML document.
  """

  def __init__(self):
    # Suite name -> list of test case results belonging to that suite.
    self.suites = {}
    # Suite name -> number of test cases counted as failed / errored. Each
    # test case is counted at most once (see add_test_case_result).
    self.failure_counts = {}
    self.error_counts = {}
    # Start / end of the whole run in seconds since the epoch; -1 means unset.
    self.overall_start_time = -1
    self.overall_end_time = -1
    # Name -> value pairs written as <property> elements under <testsuites>.
    self._testsuites_properties = {}

  def add_test_case_result(self, test_case_result):
    """Adds a test case result to its suite and updates that suite's counts.

    The suite is named after the class of the test. There are two exceptions:
    - for unittest's _ErrorHolder (class / module level failures), the last
      dotted component of the result's full_class_name is used;
    - for subtests, the class of the parent TestCase is used.

    Args:
      test_case_result: an object exposing `test`, `full_class_name` and
        `errors` as described on _TestCaseResult.
    """
    suite_name = type(test_case_result.test).__name__
    if suite_name == '_ErrorHolder':
      # _ErrorHolder is a special case created by unittest for class / module
      # level functions.
      suite_name = test_case_result.full_class_name.rsplit('.')[-1]
    if isinstance(test_case_result.test, unittest.case._SubTest):
      # If the test case is a _SubTest, the real TestCase instance is
      # available as _SubTest.test_case.
      suite_name = type(test_case_result.test.test_case).__name__

    self._setup_test_suite(suite_name)
    self.suites[suite_name].append(test_case_result)
    for error in test_case_result.errors:
      # Only count the first failure or error so that the sum is equal to the
      # total number of *testcases* that have failures or errors.
      if error[0] == 'failure':
        self.failure_counts[suite_name] += 1
        break
      elif error[0] == 'error':
        self.error_counts[suite_name] += 1
        break

  def print_xml_summary(self, stream):
    """Writes the <testsuites> document for all collected results.

    Emits the optional <properties> block, then one <testsuite> element per
    suite. Each suite's test cases are sorted by name, because their arrival
    order is not user-friendly (especially with subtests).

    Args:
      stream: output stream to write test report XML to.
    """
    overall_test_count = sum(len(x) for x in self.suites.values())
    overall_failures = sum(self.failure_counts.values())
    overall_errors = sum(self.error_counts.values())
    overall_attributes = [
        ('name', ''),
        ('tests', '%d' % overall_test_count),
        ('failures', '%d' % overall_failures),
        ('errors', '%d' % overall_errors),
        ('time', '%.3f' % (self.overall_end_time - self.overall_start_time)),
        ('timestamp', _iso8601_timestamp(self.overall_start_time)),
    ]
    _print_xml_element_header('testsuites', overall_attributes, stream)
    if self._testsuites_properties:
      stream.write('    <properties>\n')
      for name, value in sorted(self._testsuites_properties.items()):
        stream.write('      <property name="%s" value="%s"></property>\n' %
                     (_escape_xml_attr(name), _escape_xml_attr(str(value))))
      stream.write('    </properties>\n')

    for suite_name, suite in self.suites.items():
      # A suite spans from its earliest test start to its latest test end.
      suite_end_time = max(x.start_time + x.run_time for x in suite)
      suite_start_time = min(x.start_time for x in suite)
      failures = self.failure_counts[suite_name]
      errors = self.error_counts[suite_name]
      suite_attributes = [
          ('name', '%s' % suite_name),
          ('tests', '%d' % len(suite)),
          ('failures', '%d' % failures),
          ('errors', '%d' % errors),
          ('time', '%.3f' % (suite_end_time - suite_start_time)),
          ('timestamp', _iso8601_timestamp(suite_start_time)),
      ]
      _print_xml_element_header('testsuite', suite_attributes, stream)

      # test_case_result entries are not guaranteed to be in any user-friendly
      # order, especially when using subtests. So sort them.
      for test_case_result in sorted(suite, key=lambda t: t.name):
        test_case_result.print_xml_summary(stream)
      stream.write('</testsuite>\n')
    stream.write('</testsuites>\n')

  def _setup_test_suite(self, suite_name):
    """Adds a test suite to the set of suites tracked by this test run.

    Does nothing if the suite is already tracked.

    Args:
      suite_name: string, The name of the test suite being initialized.
    """
    if suite_name in self.suites:
      return
    self.suites[suite_name] = []
    self.failure_counts[suite_name] = 0
    self.error_counts[suite_name] = 0

  def set_end_time(self, timestamp_in_secs):
    """Sets the end timestamp of the overall test run.

    Args:
      timestamp_in_secs: timestamp in seconds since epoch
    """
    self.overall_end_time = timestamp_in_secs

  def set_start_time(self, timestamp_in_secs):
    """Sets the start timestamp of the overall test run.

    Args:
      timestamp_in_secs: timestamp in seconds since epoch
    """
    self.overall_start_time = timestamp_in_secs
327
328
class _TextAndXMLTestResult(_pretty_print_reporter.TextTestResult):
  """Private TestResult class that produces both formatted text results and XML.

  Used by TextAndXMLTestRunner. The lifecycle of a test's XML entry is:
  - the add* callbacks record the outcome in pending_test_case_results, keyed
    by id() of the test object;
  - stopTest moves the test's entry into self.suite;
  - stopTestRun flushes any entries that are still pending;
  - printErrors writes the XML report to xml_stream.
  """

  _TEST_SUITE_RESULT_CLASS = _TestSuiteResult
  _TEST_CASE_RESULT_CLASS = _TestCaseResult

  def __init__(self, xml_stream, stream, descriptions, verbosity,
               time_getter=_time_copy, testsuites_properties=None):
    """Initializes the result.

    Args:
      xml_stream: file-like object the XML report is written to.
      stream: passed unmodified to the base class (text output).
      descriptions: passed unmodified to the base class.
      verbosity: passed unmodified to the base class.
      time_getter: zero-argument callable returning the current time in
          seconds; used for every timestamp and duration recorded.
      testsuites_properties: optional dict of name -> value pairs that the
          suite result writes as <property> elements of <testsuites>.
    """
    super(_TextAndXMLTestResult, self).__init__(stream, descriptions, verbosity)
    self.xml_stream = xml_stream
    self.pending_test_case_results = {}
    self.suite = self._TEST_SUITE_RESULT_CLASS()
    if testsuites_properties:
      self.suite._testsuites_properties = testsuites_properties
    self.time_getter = time_getter

    # This lock guards any mutations on pending_test_case_results.
    self._pending_test_case_results_lock = threading.RLock()

  def startTest(self, test):
    # Remember when this test began; stopTest derives its run time from it.
    self.start_time = self.time_getter()
    super(_TextAndXMLTestResult, self).startTest(test)

  def stopTest(self, test):
    # Grabbing the write lock to avoid conflicting with stopTestRun.
    with self._pending_test_case_results_lock:
      super(_TextAndXMLTestResult, self).stopTest(test)
      result = self.get_pending_test_case_result(test)
      if not result:
        test_name = test.id() or str(test)
        sys.stderr.write('No pending test case: %s\n' % test_name)
        return
      if getattr(self, 'start_time', None) is None:
        # startTest may not be called for skipped tests since Python 3.12.1.
        self.start_time = self.time_getter()
      # NOTE(review): start_time is never reset after a test. If startTest
      # was not called for this test but was called for an earlier one, the
      # earlier test's start time is used here -- confirm this is intended.
      test_id = id(test)
      run_time = self.time_getter() - self.start_time
      result.set_run_time(run_time)
      result.set_start_time(self.start_time)
      self.suite.add_test_case_result(result)
      del self.pending_test_case_results[test_id]

  def startTestRun(self):
    self.suite.set_start_time(self.time_getter())
    super(_TextAndXMLTestResult, self).startTestRun()

  def stopTestRun(self):
    self.suite.set_end_time(self.time_getter())
    # All pending_test_case_results will be added to the suite and removed from
    # the pending_test_case_results dictionary. Grabbing the write lock to avoid
    # results from being added during this process to avoid duplicating adds or
    # accidentally erasing newly appended pending results.
    with self._pending_test_case_results_lock:
      # Errors in the test fixture (setUpModule, tearDownModule,
      # setUpClass, tearDownClass) can leave a pending result which
      # never gets added to the suite.  The runner calls stopTestRun
      # which gives us an opportunity to add these errors for
      # reporting here.
      for test_id in self.pending_test_case_results:
        result = self.pending_test_case_results[test_id]
        if getattr(self, 'start_time', None) is not None:
          run_time = self.suite.overall_end_time - self.start_time
          result.set_run_time(run_time)
          result.set_start_time(self.start_time)
        self.suite.add_test_case_result(result)
      self.pending_test_case_results.clear()

  def _exc_info_to_string(self, err, test=None):
    """Converts a sys.exc_info()-style tuple of values into a string.

    Unlike the base implementation, `test` is optional here. With a test,
    formatting is delegated to the base class; without one, the exception is
    formatted directly with traceback.format_exception.

    Args:
      err: A sys.exc_info() tuple of values for an error.
      test: The test method, or None.

    Returns:
      A formatted exception string.
    """
    if test:
      return super(_TextAndXMLTestResult, self)._exc_info_to_string(err, test)
    return ''.join(traceback.format_exception(*err))

  def add_pending_test_case_result(self, test, error_summary=None,
                                   skip_reason=None):
    """Adds result information to a test case result which may still be running.

    If a result entry for the test already exists, add_pending_test_case_result
    will add error summary tuples and/or overwrite skip_reason for the result.
    If it does not yet exist, a result entry will be created.
    Note that a test result is considered to have been run and passed
    only if there are no errors and skip_reason is None.

    Args:
      test: A test method as defined by unittest
      error_summary: A 4-tuple with the following entries:
          1) a string identifier of either "failure" or "error"
          2) an exception_type
          3) an exception_message
          4) a string version of a sys.exc_info()-style tuple of values
             ('error', err[0], err[1], self._exc_info_to_string(err))
             If the length of errors is 0, then the test is either passed or
             skipped.
      skip_reason: a string explaining why the test was skipped. Any value
          other than None (including the empty string) marks the test as
          skipped.
    """
    with self._pending_test_case_results_lock:
      test_id = id(test)
      if test_id not in self.pending_test_case_results:
        self.pending_test_case_results[test_id] = self._TEST_CASE_RESULT_CLASS(
            test)
      if error_summary:
        self.pending_test_case_results[test_id].errors.append(error_summary)
      # Compare against None rather than relying on truthiness: unittest
      # accepts an empty skip reason (e.g. skipTest('')), and such a test
      # must still be reported as skipped instead of as passed.
      if skip_reason is not None:
        self.pending_test_case_results[test_id].skip_reason = skip_reason

  def delete_pending_test_case_result(self, test):
    """Removes the pending result for `test`; raises KeyError if absent."""
    with self._pending_test_case_results_lock:
      test_id = id(test)
      del self.pending_test_case_results[test_id]

  def get_pending_test_case_result(self, test):
    """Returns the pending result for `test`, or None if there is none."""
    test_id = id(test)
    return self.pending_test_case_results.get(test_id, None)

  def addSuccess(self, test):
    super(_TextAndXMLTestResult, self).addSuccess(test)
    self.add_pending_test_case_result(test)

  def addError(self, test, err):
    super(_TextAndXMLTestResult, self).addError(test, err)
    error_summary = ('error', err[0], err[1],
                     self._exc_info_to_string(err, test=test))
    self.add_pending_test_case_result(test, error_summary=error_summary)

  def addFailure(self, test, err):
    super(_TextAndXMLTestResult, self).addFailure(test, err)
    error_summary = ('failure', err[0], err[1],
                     self._exc_info_to_string(err, test=test))
    self.add_pending_test_case_result(test, error_summary=error_summary)

  def addSkip(self, test, reason):
    super(_TextAndXMLTestResult, self).addSkip(test, reason)
    self.add_pending_test_case_result(test, skip_reason=reason)

  def addExpectedFailure(self, test, err):
    super(_TextAndXMLTestResult, self).addExpectedFailure(test, err)
    # Expected failures are reported as passing; the traceback is kept as a
    # test property when the test supports recordProperty.
    if callable(getattr(test, 'recordProperty', None)):
      test.recordProperty('EXPECTED_FAILURE',
                          self._exc_info_to_string(err, test=test))
    self.add_pending_test_case_result(test)

  def addUnexpectedSuccess(self, test):
    super(_TextAndXMLTestResult, self).addUnexpectedSuccess(test)
    test_name = test.id() or str(test)
    error_summary = ('error', '', '',
                     'Test case %s should have failed, but passed.'
                     % (test_name))
    self.add_pending_test_case_result(test, error_summary=error_summary)

  def addSubTest(self, test, subtest, err):  # pylint: disable=invalid-name
    super(_TextAndXMLTestResult, self).addSubTest(test, subtest, err)
    # Each subtest gets its own pending entry. The error is classified as a
    # failure when it is the test's failureException, otherwise as an error.
    if err is not None:
      if issubclass(err[0], test.failureException):
        error_summary = ('failure', err[0], err[1],
                         self._exc_info_to_string(err, test=test))
      else:
        error_summary = ('error', err[0], err[1],
                         self._exc_info_to_string(err, test=test))
    else:
      error_summary = None
    self.add_pending_test_case_result(subtest, error_summary=error_summary)

  def printErrors(self):
    """Prints the text error summary, then writes the XML report."""
    super(_TextAndXMLTestResult, self).printErrors()
    self.xml_stream.write('<?xml version="1.0"?>\n')
    self.suite.print_xml_summary(self.xml_stream)
509
510
class TextAndXMLTestRunner(unittest.TextTestRunner):
  """A test runner that produces both formatted text results and XML.

  It prints out the names of tests as they are run, errors as they
  occur, and a summary of the results at the end of the test run.
  XML output is produced only when an XML stream is available, either
  supplied per instance or configured class-wide via set_default_xml_stream.
  """

  _TEST_RESULT_CLASS = _TextAndXMLTestResult

  # Class-wide defaults, changed via set_default_xml_stream and
  # set_testsuites_property respectively.
  _xml_stream = None
  _testsuites_properties = {}

  def __init__(self, xml_stream=None, *args, **kwargs):
    """Initialize a TextAndXMLTestRunner.

    Args:
      xml_stream: file-like or None; XML-formatted test results are output
          via this object's write() method.  If None (the default), the
          instance falls back to the class-wide stream configured with
          set_default_xml_stream.
      *args: passed unmodified to unittest.TextTestRunner.__init__.
      **kwargs: passed unmodified to unittest.TextTestRunner.__init__.
    """
    super().__init__(*args, **kwargs)
    # Only shadow the class attribute when a stream was actually supplied;
    # leaving the instance attribute unset keeps the class default in effect.
    if xml_stream is None:
      return
    self._xml_stream = xml_stream

  @classmethod
  def set_default_xml_stream(cls, xml_stream):
    """Sets the default XML stream for the class.

    Args:
      xml_stream: file-like or None; used by instances constructed without an
          xml_stream (or with xml_stream=None). Passing None makes such
          instances behave like plain TextTestRunner instances, which is also
          the state before this method is ever called.
    """
    cls._xml_stream = xml_stream

  def _makeResult(self):
    """Creates the TestResult used for a run.

    Returns:
      An instance of _TEST_RESULT_CLASS writing to the configured XML stream,
      or the base TextTestRunner result when no XML stream is configured.
    """
    if self._xml_stream is not None:
      return self._TEST_RESULT_CLASS(
          self._xml_stream, self.stream, self.descriptions, self.verbosity,
          testsuites_properties=self._testsuites_properties)
    return super()._makeResult()

  @classmethod
  def set_testsuites_property(cls, key, value):
    """Records a property passed to every result created by _makeResult.

    The property is stored in a class-level dict. That dict is shared with
    subclasses that do not define their own.

    Args:
      key: property name.
      value: property value.
    """
    cls._testsuites_properties[key] = value
564