1# Copyright 2017 The Abseil Authors. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Base functionality for Abseil Python tests. 16 17This module contains base classes and high-level functions for Abseil-style 18tests. 19""" 20 21from collections import abc 22import contextlib 23import dataclasses 24import difflib 25import enum 26import errno 27import faulthandler 28import getpass 29import inspect 30import io 31import itertools 32import json 33import os 34import random 35import re 36import shlex 37import shutil 38import signal 39import stat 40import subprocess 41import sys 42import tempfile 43import textwrap 44import typing 45from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, NoReturn, Optional, Sequence, Text, TextIO, Tuple, Type, Union 46import unittest 47from unittest import mock # pylint: disable=unused-import Allow absltest.mock. 48from urllib import parse 49 50from absl import app # pylint: disable=g-import-not-at-top 51from absl import flags 52from absl import logging 53from absl.testing import _pretty_print_reporter 54from absl.testing import xml_reporter 55 56# Use an if-type-checking block to prevent leakage of type-checking only 57# symbols. We don't want people relying on these at runtime. 58if typing.TYPE_CHECKING: 59 # Unbounded TypeVar for general usage 60 _T = typing.TypeVar('_T') 61 62 import unittest.case # pylint: disable=g-import-not-at-top,g-bad-import-order 63 64 _OutcomeType = unittest.case._Outcome # pytype: disable=module-attr 65 66 67# pylint: enable=g-import-not-at-top 68 69# Re-export a bunch of unittest functions we support so that people don't 70# have to import unittest to get them 71# pylint: disable=invalid-name 72skip = unittest.skip 73skipIf = unittest.skipIf 74skipUnless = unittest.skipUnless 75SkipTest = unittest.SkipTest 76expectedFailure = unittest.expectedFailure 77# pylint: enable=invalid-name 78 79# End unittest re-exports 80 81FLAGS = flags.FLAGS 82 83_TEXT_OR_BINARY_TYPES = (str, bytes) 84 85# Suppress surplus entries in AssertionError stack traces. 86__unittest = True # pylint: disable=invalid-name 87 88 89def expectedFailureIf(condition, reason): # pylint: disable=invalid-name 90 """Expects the test to fail if the run condition is True. 91 92 Example usage:: 93 94 @expectedFailureIf(sys.version.major == 2, "Not yet working in py2") 95 def test_foo(self): 96 ... 97 98 Args: 99 condition: bool, whether to expect failure or not. 100 reason: Text, the reason to expect failure. 101 Returns: 102 Decorator function 103 """ 104 del reason # Unused 105 if condition: 106 return unittest.expectedFailure 107 else: 108 return lambda f: f 109 110 111class TempFileCleanup(enum.Enum): 112 # Always cleanup temp files when the test completes. 113 ALWAYS = 'always' 114 # Only cleanup temp file if the test passes. This allows easier inspection 115 # of tempfile contents on test failure. absltest.TEST_TMPDIR.value determines 116 # where tempfiles are created. 117 SUCCESS = 'success' 118 # Never cleanup temp files. 119 OFF = 'never' 120 121 122# Many of the methods in this module have names like assertSameElements. 123# This kind of name does not comply with PEP8 style, 124# but it is consistent with the naming of methods in unittest.py. 125# pylint: disable=invalid-name 126 127 128def _get_default_test_random_seed(): 129 # type: () -> int 130 random_seed = 301 131 value = os.environ.get('TEST_RANDOM_SEED', '') 132 try: 133 random_seed = int(value) 134 except ValueError: 135 pass 136 return random_seed 137 138 139def get_default_test_srcdir(): 140 # type: () -> Text 141 """Returns default test source dir.""" 142 return os.environ.get('TEST_SRCDIR', '') 143 144 145def get_default_test_tmpdir(): 146 # type: () -> Text 147 """Returns default test temp dir.""" 148 tmpdir = os.environ.get('TEST_TMPDIR', '') 149 if not tmpdir: 150 tmpdir = os.path.join(tempfile.gettempdir(), 'absl_testing') 151 152 return tmpdir 153 154 155def _get_default_randomize_ordering_seed(): 156 # type: () -> int 157 """Returns default seed to use for randomizing test order. 158 159 This function first checks the --test_randomize_ordering_seed flag, and then 160 the TEST_RANDOMIZE_ORDERING_SEED environment variable. If the first value 161 we find is: 162 * (not set): disable test randomization 163 * 0: disable test randomization 164 * 'random': choose a random seed in [1, 4294967295] for test order 165 randomization 166 * positive integer: use this seed for test order randomization 167 168 (The values used are patterned after 169 https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED). 170 171 In principle, it would be simpler to return None if no override is provided; 172 however, the python random module has no `get_seed()`, only `getstate()`, 173 which returns far more data than we want to pass via an environment variable 174 or flag. 175 176 Returns: 177 A default value for test case randomization (int). 0 means do not randomize. 178 179 Raises: 180 ValueError: Raised when the flag or env value is not one of the options 181 above. 182 """ 183 if FLAGS['test_randomize_ordering_seed'].present: 184 randomize = FLAGS.test_randomize_ordering_seed 185 elif 'TEST_RANDOMIZE_ORDERING_SEED' in os.environ: 186 randomize = os.environ['TEST_RANDOMIZE_ORDERING_SEED'] 187 else: 188 randomize = '' 189 if not randomize: 190 return 0 191 if randomize == 'random': 192 return random.Random().randint(1, 4294967295) 193 if randomize == '0': 194 return 0 195 try: 196 seed = int(randomize) 197 if seed > 0: 198 return seed 199 except ValueError: 200 pass 201 raise ValueError( 202 'Unknown test randomization seed value: {}'.format(randomize)) 203 204 205TEST_SRCDIR = flags.DEFINE_string( 206 'test_srcdir', 207 get_default_test_srcdir(), 208 'Root of directory tree where source files live', 209 allow_override_cpp=True) 210TEST_TMPDIR = flags.DEFINE_string( 211 'test_tmpdir', 212 get_default_test_tmpdir(), 213 'Directory for temporary testing files', 214 allow_override_cpp=True) 215 216flags.DEFINE_integer( 217 'test_random_seed', 218 _get_default_test_random_seed(), 219 'Random seed for testing. Some test frameworks may ' 220 'change the default value of this flag between runs, so ' 221 'it is not appropriate for seeding probabilistic tests.', 222 allow_override_cpp=True) 223flags.DEFINE_string( 224 'test_randomize_ordering_seed', 225 '', 226 'If positive, use this as a seed to randomize the ' 227 'execution order for test cases. If "random", pick a ' 228 'random seed to use. If 0 or not set, do not randomize ' 229 'test case execution order. This flag also overrides ' 230 'the TEST_RANDOMIZE_ORDERING_SEED environment variable.', 231 allow_override_cpp=True) 232flags.DEFINE_string('xml_output_file', '', 'File to store XML test results') 233 234 235# We might need to monkey-patch TestResult so that it stops considering an 236# unexpected pass as a as a "successful result". For details, see 237# http://bugs.python.org/issue20165 238def _monkey_patch_test_result_for_unexpected_passes(): 239 # type: () -> None 240 """Workaround for <http://bugs.python.org/issue20165>.""" 241 242 def wasSuccessful(self): 243 # type: () -> bool 244 """Tells whether or not this result was a success. 245 246 Any unexpected pass is to be counted as a non-success. 247 248 Args: 249 self: The TestResult instance. 250 251 Returns: 252 Whether or not this result was a success. 253 """ 254 return (len(self.failures) == len(self.errors) == 255 len(self.unexpectedSuccesses) == 0) 256 257 test_result = unittest.TestResult() 258 test_result.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None)) 259 if test_result.wasSuccessful(): # The bug is present. 260 unittest.TestResult.wasSuccessful = wasSuccessful 261 if test_result.wasSuccessful(): # Warn the user if our hot-fix failed. 262 sys.stderr.write('unittest.result.TestResult monkey patch to report' 263 ' unexpected passes as failures did not work.\n') 264 265 266_monkey_patch_test_result_for_unexpected_passes() 267 268 269def _open(filepath, mode, _open_func=open): 270 # type: (Text, Text, Callable[..., IO]) -> IO 271 """Opens a file. 272 273 Like open(), but ensure that we can open real files even if tests stub out 274 open(). 275 276 Args: 277 filepath: A filepath. 278 mode: A mode. 279 _open_func: A built-in open() function. 280 281 Returns: 282 The opened file object. 283 """ 284 return _open_func(filepath, mode, encoding='utf-8') 285 286 287class _TempDir(object): 288 """Represents a temporary directory for tests. 289 290 Creation of this class is internal. Using its public methods is OK. 291 292 This class implements the `os.PathLike` interface (specifically, 293 `os.PathLike[str]`). This means, in Python 3, it can be directly passed 294 to e.g. `os.path.join()`. 295 """ 296 297 def __init__(self, path): 298 # type: (Text) -> None 299 """Module-private: do not instantiate outside module.""" 300 self._path = path 301 302 @property 303 def full_path(self): 304 # type: () -> Text 305 """Returns the path, as a string, for the directory. 306 307 TIP: Instead of e.g. `os.path.join(temp_dir.full_path)`, you can simply 308 do `os.path.join(temp_dir)` because `__fspath__()` is implemented. 309 """ 310 return self._path 311 312 def __fspath__(self): 313 # type: () -> Text 314 """See os.PathLike.""" 315 return self.full_path 316 317 def create_file(self, file_path=None, content=None, mode='w', encoding='utf8', 318 errors='strict'): 319 # type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile 320 """Create a file in the directory. 321 322 NOTE: If the file already exists, it will be made writable and overwritten. 323 324 Args: 325 file_path: Optional file path for the temp file. If not given, a unique 326 file name will be generated and used. Slashes are allowed in the name; 327 any missing intermediate directories will be created. NOTE: This path 328 is the path that will be cleaned up, including any directories in the 329 path, e.g., 'foo/bar/baz.txt' will `rm -r foo` 330 content: Optional string or bytes to initially write to the file. If not 331 specified, then an empty file is created. 332 mode: Mode string to use when writing content. Only used if `content` is 333 non-empty. 334 encoding: Encoding to use when writing string content. Only used if 335 `content` is text. 336 errors: How to handle text to bytes encoding errors. Only used if 337 `content` is text. 338 339 Returns: 340 A _TempFile representing the created file. 341 """ 342 tf, _ = _TempFile._create(self._path, file_path, content, mode, encoding, 343 errors) 344 return tf 345 346 def mkdir(self, dir_path=None): 347 # type: (Optional[Text]) -> _TempDir 348 """Create a directory in the directory. 349 350 Args: 351 dir_path: Optional path to the directory to create. If not given, 352 a unique name will be generated and used. 353 354 Returns: 355 A _TempDir representing the created directory. 356 """ 357 if dir_path: 358 path = os.path.join(self._path, dir_path) 359 else: 360 path = tempfile.mkdtemp(dir=self._path) 361 362 # Note: there's no need to clear the directory since the containing 363 # dir was cleared by the tempdir() function. 364 os.makedirs(path, exist_ok=True) 365 return _TempDir(path) 366 367 368class _TempFile(object): 369 """Represents a tempfile for tests. 370 371 Creation of this class is internal. Using its public methods is OK. 372 373 This class implements the `os.PathLike` interface (specifically, 374 `os.PathLike[str]`). This means, in Python 3, it can be directly passed 375 to e.g. `os.path.join()`. 376 """ 377 378 def __init__(self, path): 379 # type: (Text) -> None 380 """Private: use _create instead.""" 381 self._path = path 382 383 # pylint: disable=line-too-long 384 @classmethod 385 def _create(cls, base_path, file_path, content, mode, encoding, errors): 386 # type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text] 387 # pylint: enable=line-too-long 388 """Module-private: create a tempfile instance.""" 389 if file_path: 390 cleanup_path = os.path.join(base_path, _get_first_part(file_path)) 391 path = os.path.join(base_path, file_path) 392 os.makedirs(os.path.dirname(path), exist_ok=True) 393 # The file may already exist, in which case, ensure it's writable so that 394 # it can be truncated. 395 if os.path.exists(path) and not os.access(path, os.W_OK): 396 stat_info = os.stat(path) 397 os.chmod(path, stat_info.st_mode | stat.S_IWUSR) 398 else: 399 os.makedirs(base_path, exist_ok=True) 400 fd, path = tempfile.mkstemp(dir=str(base_path)) 401 os.close(fd) 402 cleanup_path = path 403 404 tf = cls(path) 405 406 if content: 407 if isinstance(content, str): 408 tf.write_text(content, mode=mode, encoding=encoding, errors=errors) 409 else: 410 tf.write_bytes(content, mode) 411 412 else: 413 tf.write_bytes(b'') 414 415 return tf, cleanup_path 416 417 @property 418 def full_path(self): 419 # type: () -> Text 420 """Returns the path, as a string, for the file. 421 422 TIP: Instead of e.g. `os.path.join(temp_file.full_path)`, you can simply 423 do `os.path.join(temp_file)` because `__fspath__()` is implemented. 424 """ 425 return self._path 426 427 def __fspath__(self): 428 # type: () -> Text 429 """See os.PathLike.""" 430 return self.full_path 431 432 def read_text(self, encoding='utf8', errors='strict'): 433 # type: (Text, Text) -> Text 434 """Return the contents of the file as text.""" 435 with self.open_text(encoding=encoding, errors=errors) as fp: 436 return fp.read() 437 438 def read_bytes(self): 439 # type: () -> bytes 440 """Return the content of the file as bytes.""" 441 with self.open_bytes() as fp: 442 return fp.read() 443 444 def write_text(self, text, mode='w', encoding='utf8', errors='strict'): 445 # type: (Text, Text, Text, Text) -> None 446 """Write text to the file. 447 448 Args: 449 text: Text to write. In Python 2, it can be bytes, which will be 450 decoded using the `encoding` arg (this is as an aid for code that 451 is 2 and 3 compatible). 452 mode: The mode to open the file for writing. 453 encoding: The encoding to use when writing the text to the file. 454 errors: The error handling strategy to use when converting text to bytes. 455 """ 456 with self.open_text(mode, encoding=encoding, errors=errors) as fp: 457 fp.write(text) 458 459 def write_bytes(self, data, mode='wb'): 460 # type: (bytes, Text) -> None 461 """Write bytes to the file. 462 463 Args: 464 data: bytes to write. 465 mode: Mode to open the file for writing. The "b" flag is implicit if 466 not already present. It must not have the "t" flag. 467 """ 468 with self.open_bytes(mode) as fp: 469 fp.write(data) 470 471 def open_text(self, mode='rt', encoding='utf8', errors='strict'): 472 # type: (Text, Text, Text) -> ContextManager[TextIO] 473 """Return a context manager for opening the file in text mode. 474 475 Args: 476 mode: The mode to open the file in. The "t" flag is implicit if not 477 already present. It must not have the "b" flag. 478 encoding: The encoding to use when opening the file. 479 errors: How to handle decoding errors. 480 481 Returns: 482 Context manager that yields an open file. 483 484 Raises: 485 ValueError: if invalid inputs are provided. 486 """ 487 if 'b' in mode: 488 raise ValueError('Invalid mode {!r}: "b" flag not allowed when opening ' 489 'file in text mode'.format(mode)) 490 if 't' not in mode: 491 mode += 't' 492 cm = self._open(mode, encoding, errors) 493 return cm 494 495 def open_bytes(self, mode='rb'): 496 # type: (Text) -> ContextManager[BinaryIO] 497 """Return a context manager for opening the file in binary mode. 498 499 Args: 500 mode: The mode to open the file in. The "b" mode is implicit if not 501 already present. It must not have the "t" flag. 502 503 Returns: 504 Context manager that yields an open file. 505 506 Raises: 507 ValueError: if invalid inputs are provided. 508 """ 509 if 't' in mode: 510 raise ValueError('Invalid mode {!r}: "t" flag not allowed when opening ' 511 'file in binary mode'.format(mode)) 512 if 'b' not in mode: 513 mode += 'b' 514 cm = self._open(mode, encoding=None, errors=None) 515 return cm 516 517 # TODO(b/123775699): Once pytype supports typing.Literal, use overload and 518 # Literal to express more precise return types. The contained type is 519 # currently `Any` to avoid [bad-return-type] errors in the open_* methods. 520 @contextlib.contextmanager 521 def _open( 522 self, 523 mode: str, 524 encoding: Optional[str] = 'utf8', 525 errors: Optional[str] = 'strict', 526 ) -> Iterator[Any]: 527 with io.open( 528 self.full_path, mode=mode, encoding=encoding, errors=errors) as fp: 529 yield fp 530 531 532class _method(object): 533 """A decorator that supports both instance and classmethod invocations. 534 535 Using similar semantics to the @property builtin, this decorator can augment 536 an instance method to support conditional logic when invoked on a class 537 object. This breaks support for invoking an instance method via the class 538 (e.g. Cls.method(self, ...)) but is still situationally useful. 539 """ 540 541 def __init__(self, finstancemethod): 542 # type: (Callable[..., Any]) -> None 543 self._finstancemethod = finstancemethod 544 self._fclassmethod = None 545 546 def classmethod(self, fclassmethod): 547 # type: (Callable[..., Any]) -> _method 548 self._fclassmethod = classmethod(fclassmethod) 549 return self 550 551 def __doc__(self): 552 # type: () -> str 553 if getattr(self._finstancemethod, '__doc__'): 554 return self._finstancemethod.__doc__ 555 elif getattr(self._fclassmethod, '__doc__'): 556 return self._fclassmethod.__doc__ 557 return '' 558 559 def __get__(self, obj, type_): 560 # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any] 561 func = self._fclassmethod if obj is None else self._finstancemethod 562 return func.__get__(obj, type_) # pytype: disable=attribute-error 563 564 565class TestCase(unittest.TestCase): 566 """Extension of unittest.TestCase providing more power.""" 567 568 # When to cleanup files/directories created by our `create_tempfile()` and 569 # `create_tempdir()` methods after each test case completes. This does *not* 570 # affect e.g., files created outside of those methods, e.g., using the stdlib 571 # tempfile module. This can be overridden at the class level, instance level, 572 # or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See 573 # `TempFileCleanup` for details on the different values. 574 # TODO(b/70517332): Remove the type comment and the disable once pytype has 575 # better support for enums. 576 tempfile_cleanup = TempFileCleanup.ALWAYS # type: TempFileCleanup # pytype: disable=annotation-type-mismatch 577 578 maxDiff = 80 * 20 579 longMessage = True 580 581 # Exit stacks for per-test and per-class scopes. 582 if sys.version_info < (3, 11): 583 _exit_stack = None 584 _cls_exit_stack = None 585 586 def __init__(self, *args, **kwargs): 587 super(TestCase, self).__init__(*args, **kwargs) 588 # This is to work around missing type stubs in unittest.pyi 589 self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType] 590 591 def setUp(self): 592 super(TestCase, self).setUp() 593 # NOTE: Only Python 3 contextlib has ExitStack and 594 # Python 3.11+ already has enterContext. 595 if hasattr(contextlib, 'ExitStack') and sys.version_info < (3, 11): 596 self._exit_stack = contextlib.ExitStack() 597 self.addCleanup(self._exit_stack.close) 598 599 @classmethod 600 def setUpClass(cls): 601 super(TestCase, cls).setUpClass() 602 # NOTE: Only Python 3 contextlib has ExitStack, only Python 3.8+ has 603 # addClassCleanup and Python 3.11+ already has enterClassContext. 604 if ( 605 hasattr(contextlib, 'ExitStack') 606 and hasattr(cls, 'addClassCleanup') 607 and sys.version_info < (3, 11) 608 ): 609 cls._cls_exit_stack = contextlib.ExitStack() 610 cls.addClassCleanup(cls._cls_exit_stack.close) 611 612 def create_tempdir(self, name=None, cleanup=None): 613 # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir 614 """Create a temporary directory specific to the test. 615 616 NOTE: The directory and its contents will be recursively cleared before 617 creation. This ensures that there is no pre-existing state. 618 619 This creates a named directory on disk that is isolated to this test, and 620 will be properly cleaned up by the test. This avoids several pitfalls of 621 creating temporary directories for test purposes, as well as makes it easier 622 to setup directories and verify their contents. For example:: 623 624 def test_foo(self): 625 out_dir = self.create_tempdir() 626 out_log = out_dir.create_file('output.log') 627 expected_outputs = [ 628 os.path.join(out_dir, 'data-0.txt'), 629 os.path.join(out_dir, 'data-1.txt'), 630 ] 631 code_under_test(out_dir) 632 self.assertTrue(os.path.exists(expected_paths[0])) 633 self.assertTrue(os.path.exists(expected_paths[1])) 634 self.assertEqual('foo', out_log.read_text()) 635 636 See also: :meth:`create_tempfile` for creating temporary files. 637 638 Args: 639 name: Optional name of the directory. If not given, a unique 640 name will be generated and used. 641 cleanup: Optional cleanup policy on when/if to remove the directory (and 642 all its contents) at the end of the test. If None, then uses 643 :attr:`tempfile_cleanup`. 644 645 Returns: 646 A _TempDir representing the created directory; see _TempDir class docs 647 for usage. 648 """ 649 test_path = self._get_tempdir_path_test() 650 651 if name: 652 path = os.path.join(test_path, name) 653 cleanup_path = os.path.join(test_path, _get_first_part(name)) 654 else: 655 os.makedirs(test_path, exist_ok=True) 656 path = tempfile.mkdtemp(dir=test_path) 657 cleanup_path = path 658 659 _rmtree_ignore_errors(cleanup_path) 660 os.makedirs(path, exist_ok=True) 661 662 self._maybe_add_temp_path_cleanup(cleanup_path, cleanup) 663 664 return _TempDir(path) 665 666 # pylint: disable=line-too-long 667 def create_tempfile(self, file_path=None, content=None, mode='w', 668 encoding='utf8', errors='strict', cleanup=None): 669 # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile 670 # pylint: enable=line-too-long 671 """Create a temporary file specific to the test. 672 673 This creates a named file on disk that is isolated to this test, and will 674 be properly cleaned up by the test. This avoids several pitfalls of 675 creating temporary files for test purposes, as well as makes it easier 676 to setup files, their data, read them back, and inspect them when 677 a test fails. For example:: 678 679 def test_foo(self): 680 output = self.create_tempfile() 681 code_under_test(output) 682 self.assertGreater(os.path.getsize(output), 0) 683 self.assertEqual('foo', output.read_text()) 684 685 NOTE: This will zero-out the file. This ensures there is no pre-existing 686 state. 687 NOTE: If the file already exists, it will be made writable and overwritten. 688 689 See also: :meth:`create_tempdir` for creating temporary directories, and 690 ``_TempDir.create_file`` for creating files within a temporary directory. 691 692 Args: 693 file_path: Optional file path for the temp file. If not given, a unique 694 file name will be generated and used. Slashes are allowed in the name; 695 any missing intermediate directories will be created. NOTE: This path is 696 the path that will be cleaned up, including any directories in the path, 697 e.g., ``'foo/bar/baz.txt'`` will ``rm -r foo``. 698 content: Optional string or 699 bytes to initially write to the file. If not 700 specified, then an empty file is created. 701 mode: Mode string to use when writing content. Only used if `content` is 702 non-empty. 703 encoding: Encoding to use when writing string content. Only used if 704 `content` is text. 705 errors: How to handle text to bytes encoding errors. Only used if 706 `content` is text. 707 cleanup: Optional cleanup policy on when/if to remove the directory (and 708 all its contents) at the end of the test. If None, then uses 709 :attr:`tempfile_cleanup`. 710 711 Returns: 712 A _TempFile representing the created file; see _TempFile class docs for 713 usage. 714 """ 715 test_path = self._get_tempdir_path_test() 716 tf, cleanup_path = _TempFile._create(test_path, file_path, content=content, 717 mode=mode, encoding=encoding, 718 errors=errors) 719 self._maybe_add_temp_path_cleanup(cleanup_path, cleanup) 720 return tf 721 722 @_method 723 def enter_context(self, manager): 724 # type: (ContextManager[_T]) -> _T 725 """Returns the CM's value after registering it with the exit stack. 726 727 Entering a context pushes it onto a stack of contexts. When `enter_context` 728 is called on the test instance (e.g. `self.enter_context`), the context is 729 exited after the test case's tearDown call. When called on the test class 730 (e.g. `TestCase.enter_context`), the context is exited after the test 731 class's tearDownClass call. 732 733 Contexts are exited in the reverse order of entering. They will always 734 be exited, regardless of test failure/success. 735 736 This is useful to eliminate per-test boilerplate when context managers 737 are used. For example, instead of decorating every test with `@mock.patch`, 738 simply do `self.foo = self.enter_context(mock.patch(...))' in `setUp()`. 739 740 NOTE: The context managers will always be exited without any error 741 information. This is an unfortunate implementation detail due to some 742 internals of how unittest runs tests. 743 744 Args: 745 manager: The context manager to enter. 746 """ 747 if sys.version_info >= (3, 11): 748 return self.enterContext(manager) 749 750 if not self._exit_stack: 751 raise AssertionError( 752 'self._exit_stack is not set: enter_context is Py3-only; also make ' 753 'sure that AbslTest.setUp() is called.') 754 return self._exit_stack.enter_context(manager) 755 756 @enter_context.classmethod 757 def enter_context(cls, manager): # pylint: disable=no-self-argument 758 # type: (ContextManager[_T]) -> _T 759 if sys.version_info >= (3, 11): 760 return cls.enterClassContext(manager) 761 762 if not cls._cls_exit_stack: 763 raise AssertionError( 764 'cls._cls_exit_stack is not set: cls.enter_context requires ' 765 'Python 3.8+; also make sure that AbslTest.setUpClass() is called.') 766 return cls._cls_exit_stack.enter_context(manager) 767 768 @classmethod 769 def _get_tempdir_path_cls(cls): 770 # type: () -> Text 771 return os.path.join(TEST_TMPDIR.value, 772 cls.__qualname__.replace('__main__.', '')) 773 774 def _get_tempdir_path_test(self): 775 # type: () -> Text 776 return os.path.join(self._get_tempdir_path_cls(), self._testMethodName) 777 778 def _get_tempfile_cleanup(self, override): 779 # type: (Optional[TempFileCleanup]) -> TempFileCleanup 780 if override is not None: 781 return override 782 return self.tempfile_cleanup 783 784 def _maybe_add_temp_path_cleanup(self, path, cleanup): 785 # type: (Text, Optional[TempFileCleanup]) -> None 786 cleanup = self._get_tempfile_cleanup(cleanup) 787 if cleanup == TempFileCleanup.OFF: 788 return 789 elif cleanup == TempFileCleanup.ALWAYS: 790 self.addCleanup(_rmtree_ignore_errors, path) 791 elif cleanup == TempFileCleanup.SUCCESS: 792 self._internal_add_cleanup_on_success(_rmtree_ignore_errors, path) 793 else: 794 raise AssertionError('Unexpected cleanup value: {}'.format(cleanup)) 795 796 def _internal_add_cleanup_on_success( 797 self, 798 function: Callable[..., Any], 799 *args: Any, 800 **kwargs: Any, 801 ) -> None: 802 """Adds `function` as cleanup when the test case succeeds.""" 803 outcome = self._outcome 804 assert outcome is not None 805 previous_failure_count = ( 806 len(outcome.result.failures) 807 + len(outcome.result.errors) 808 + len(outcome.result.unexpectedSuccesses) 809 ) 810 def _call_cleaner_on_success(*args, **kwargs): 811 if not self._internal_ran_and_passed_when_called_during_cleanup( 812 previous_failure_count): 813 return 814 function(*args, **kwargs) 815 self.addCleanup(_call_cleaner_on_success, *args, **kwargs) 816 817 def _internal_ran_and_passed_when_called_during_cleanup( 818 self, 819 previous_failure_count: int, 820 ) -> bool: 821 """Returns whether test is passed. Expected to be called during cleanup.""" 822 outcome = self._outcome 823 if sys.version_info[:2] >= (3, 11): 824 assert outcome is not None 825 current_failure_count = ( 826 len(outcome.result.failures) 827 + len(outcome.result.errors) 828 + len(outcome.result.unexpectedSuccesses) 829 ) 830 return current_failure_count == previous_failure_count 831 else: 832 # Before Python 3.11 https://github.com/python/cpython/pull/28180, errors 833 # were bufferred in _Outcome before calling cleanup. 834 result = self.defaultTestResult() 835 self._feedErrorsToResult(result, outcome.errors) # pytype: disable=attribute-error 836 return result.wasSuccessful() 837 838 def shortDescription(self): 839 # type: () -> Text 840 """Formats both the test method name and the first line of its docstring. 841 842 If no docstring is given, only returns the method name. 843 844 This method overrides unittest.TestCase.shortDescription(), which 845 only returns the first line of the docstring, obscuring the name 846 of the test upon failure. 847 848 Returns: 849 desc: A short description of a test method. 850 """ 851 desc = self.id() 852 853 # Omit the main name so that test name can be directly copy/pasted to 854 # the command line. 855 if desc.startswith('__main__.'): 856 desc = desc[len('__main__.'):] 857 858 # NOTE: super() is used here instead of directly invoking 859 # unittest.TestCase.shortDescription(self), because of the 860 # following line that occurs later on: 861 # unittest.TestCase = TestCase 862 # Because of this, direct invocation of what we think is the 863 # superclass will actually cause infinite recursion. 864 doc_first_line = super(TestCase, self).shortDescription() 865 if doc_first_line is not None: 866 desc = '\n'.join((desc, doc_first_line)) 867 return desc 868 869 def assertStartsWith(self, actual, expected_start, msg=None): 870 """Asserts that actual.startswith(expected_start) is True. 871 872 Args: 873 actual: str 874 expected_start: str 875 msg: Optional message to report on failure. 876 """ 877 if not actual.startswith(expected_start): 878 self.fail('%r does not start with %r' % (actual, expected_start), msg) 879 880 def assertNotStartsWith(self, actual, unexpected_start, msg=None): 881 """Asserts that actual.startswith(unexpected_start) is False. 882 883 Args: 884 actual: str 885 unexpected_start: str 886 msg: Optional message to report on failure. 887 """ 888 if actual.startswith(unexpected_start): 889 self.fail('%r does start with %r' % (actual, unexpected_start), msg) 890 891 def assertEndsWith(self, actual, expected_end, msg=None): 892 """Asserts that actual.endswith(expected_end) is True. 893 894 Args: 895 actual: str 896 expected_end: str 897 msg: Optional message to report on failure. 898 """ 899 if not actual.endswith(expected_end): 900 self.fail('%r does not end with %r' % (actual, expected_end), msg) 901 902 def assertNotEndsWith(self, actual, unexpected_end, msg=None): 903 """Asserts that actual.endswith(unexpected_end) is False. 904 905 Args: 906 actual: str 907 unexpected_end: str 908 msg: Optional message to report on failure. 909 """ 910 if actual.endswith(unexpected_end): 911 self.fail('%r does end with %r' % (actual, unexpected_end), msg) 912 913 def assertSequenceStartsWith(self, prefix, whole, msg=None): 914 """An equality assertion for the beginning of ordered sequences. 915 916 If prefix is an empty sequence, it will raise an error unless whole is also 917 an empty sequence. 918 919 If prefix is not a sequence, it will raise an error if the first element of 920 whole does not match. 921 922 Args: 923 prefix: A sequence expected at the beginning of the whole parameter. 924 whole: The sequence in which to look for prefix. 925 msg: Optional message to report on failure. 926 """ 927 try: 928 prefix_len = len(prefix) 929 except (TypeError, NotImplementedError): 930 prefix = [prefix] 931 prefix_len = 1 932 933 if isinstance(whole, abc.Mapping) or isinstance(whole, abc.Set): 934 self.fail( 935 'For whole: Mapping or Set objects are not supported, found type: %s' 936 % type(whole), 937 msg, 938 ) 939 try: 940 whole_len = len(whole) 941 except (TypeError, NotImplementedError): 942 self.fail('For whole: len(%s) is not supported, it appears to be type: ' 943 '%s' % (whole, type(whole)), msg) 944 945 assert prefix_len <= whole_len, self._formatMessage( 946 msg, 947 'Prefix length (%d) is longer than whole length (%d).' % 948 (prefix_len, whole_len) 949 ) 950 951 if not prefix_len and whole_len: 952 self.fail('Prefix length is 0 but whole length is %d: %s' % 953 (len(whole), whole), msg) 954 955 try: 956 self.assertSequenceEqual(prefix, whole[:prefix_len], msg) 957 except AssertionError: 958 self.fail('prefix: %s not found at start of whole: %s.' % 959 (prefix, whole), msg) 960 961 def assertEmpty(self, container, msg=None): 962 """Asserts that an object has zero length. 963 964 Args: 965 container: Anything that implements the collections.abc.Sized interface. 966 msg: Optional message to report on failure. 967 """ 968 if not isinstance(container, abc.Sized): 969 self.fail('Expected a Sized object, got: ' 970 '{!r}'.format(type(container).__name__), msg) 971 972 # explicitly check the length since some Sized objects (e.g. numpy.ndarray) 973 # have strange __nonzero__/__bool__ behavior. 974 if len(container): # pylint: disable=g-explicit-length-test 975 self.fail('{!r} has length of {}.'.format(container, len(container)), msg) 976 977 def assertNotEmpty(self, container, msg=None): 978 """Asserts that an object has non-zero length. 979 980 Args: 981 container: Anything that implements the collections.abc.Sized interface. 982 msg: Optional message to report on failure. 983 """ 984 if not isinstance(container, abc.Sized): 985 self.fail('Expected a Sized object, got: ' 986 '{!r}'.format(type(container).__name__), msg) 987 988 # explicitly check the length since some Sized objects (e.g. numpy.ndarray) 989 # have strange __nonzero__/__bool__ behavior. 990 if not len(container): # pylint: disable=g-explicit-length-test 991 self.fail('{!r} has length of 0.'.format(container), msg) 992 993 def assertLen(self, container, expected_len, msg=None): 994 """Asserts that an object has the expected length. 995 996 Args: 997 container: Anything that implements the collections.abc.Sized interface. 998 expected_len: The expected length of the container. 999 msg: Optional message to report on failure. 1000 """ 1001 if not isinstance(container, abc.Sized): 1002 self.fail('Expected a Sized object, got: ' 1003 '{!r}'.format(type(container).__name__), msg) 1004 if len(container) != expected_len: 1005 container_repr = unittest.util.safe_repr(container) # pytype: disable=module-attr 1006 self.fail('{} has length of {}, expected {}.'.format( 1007 container_repr, len(container), expected_len), msg) 1008 1009 def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None, 1010 msg=None, delta=None): 1011 """An approximate equality assertion for ordered sequences. 1012 1013 Fail if the two sequences are unequal as determined by their value 1014 differences rounded to the given number of decimal places (default 7) and 1015 comparing to zero, or by comparing that the difference between each value 1016 in the two sequences is more than the given delta. 1017 1018 Note that decimal places (from zero) are usually not the same as significant 1019 digits (measured from the most significant digit). 1020 1021 If the two sequences compare equal then they will automatically compare 1022 almost equal. 1023 1024 Args: 1025 expected_seq: A sequence containing elements we are expecting. 1026 actual_seq: The sequence that we are testing. 1027 places: The number of decimal places to compare. 1028 msg: The message to be printed if the test fails. 1029 delta: The OK difference between compared values. 1030 """ 1031 if len(expected_seq) != len(actual_seq): 1032 self.fail('Sequence size mismatch: {} vs {}'.format( 1033 len(expected_seq), len(actual_seq)), msg) 1034 1035 err_list = [] 1036 for idx, (exp_elem, act_elem) in enumerate(zip(expected_seq, actual_seq)): 1037 try: 1038 # assertAlmostEqual should be called with at most one of `places` and 1039 # `delta`. However, it's okay for assertSequenceAlmostEqual to pass 1040 # both because we want the latter to fail if the former does. 1041 # pytype: disable=wrong-keyword-args 1042 self.assertAlmostEqual(exp_elem, act_elem, places=places, msg=msg, 1043 delta=delta) 1044 # pytype: enable=wrong-keyword-args 1045 except self.failureException as err: 1046 err_list.append('At index {}: {}'.format(idx, err)) 1047 1048 if err_list: 1049 if len(err_list) > 30: 1050 err_list = err_list[:30] + ['...'] 1051 msg = self._formatMessage(msg, '\n'.join(err_list)) 1052 self.fail(msg) 1053 1054 def assertContainsSubset(self, expected_subset, actual_set, msg=None): 1055 """Checks whether actual iterable is a superset of expected iterable.""" 1056 missing = set(expected_subset) - set(actual_set) 1057 if not missing: 1058 return 1059 1060 self.fail('Missing elements %s\nExpected: %s\nActual: %s' % ( 1061 missing, expected_subset, actual_set), msg) 1062 1063 def assertNoCommonElements(self, expected_seq, actual_seq, msg=None): 1064 """Checks whether actual iterable and expected iterable are disjoint.""" 1065 common = set(expected_seq) & set(actual_seq) 1066 if not common: 1067 return 1068 1069 self.fail('Common elements %s\nExpected: %s\nActual: %s' % ( 1070 common, expected_seq, actual_seq), msg) 1071 1072 def assertItemsEqual(self, expected_seq, actual_seq, msg=None): 1073 """Deprecated, please use assertCountEqual instead. 1074 1075 This is equivalent to assertCountEqual. 1076 1077 Args: 1078 expected_seq: A sequence containing elements we are expecting. 1079 actual_seq: The sequence that we are testing. 1080 msg: The message to be printed if the test fails. 1081 """ 1082 super().assertCountEqual(expected_seq, actual_seq, msg) 1083 1084 def assertSameElements(self, expected_seq, actual_seq, msg=None): 1085 """Asserts that two sequences have the same elements (in any order). 1086 1087 This method, unlike assertCountEqual, doesn't care about any 1088 duplicates in the expected and actual sequences:: 1089 1090 # Doesn't raise an AssertionError 1091 assertSameElements([1, 1, 1, 0, 0, 0], [0, 1]) 1092 1093 If possible, you should use assertCountEqual instead of 1094 assertSameElements. 1095 1096 Args: 1097 expected_seq: A sequence containing elements we are expecting. 1098 actual_seq: The sequence that we are testing. 1099 msg: The message to be printed if the test fails. 1100 """ 1101 # `unittest2.TestCase` used to have assertSameElements, but it was 1102 # removed in favor of assertItemsEqual. As there's a unit test 1103 # that explicitly checks this behavior, I am leaving this method 1104 # alone. 1105 # Fail on strings: empirically, passing strings to this test method 1106 # is almost always a bug. If comparing the character sets of two strings 1107 # is desired, cast the inputs to sets or lists explicitly. 1108 if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or 1109 isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)): 1110 self.fail('Passing string/bytes to assertSameElements is usually a bug. ' 1111 'Did you mean to use assertEqual?\n' 1112 'Expected: %s\nActual: %s' % (expected_seq, actual_seq)) 1113 try: 1114 expected = dict([(element, None) for element in expected_seq]) 1115 actual = dict([(element, None) for element in actual_seq]) 1116 missing = [element for element in expected if element not in actual] 1117 unexpected = [element for element in actual if element not in expected] 1118 missing.sort() 1119 unexpected.sort() 1120 except TypeError: 1121 # Fall back to slower list-compare if any of the objects are 1122 # not hashable. 1123 expected = list(expected_seq) 1124 actual = list(actual_seq) 1125 expected.sort() 1126 actual.sort() 1127 missing, unexpected = _sorted_list_difference(expected, actual) 1128 errors = [] 1129 if msg: 1130 errors.extend((msg, ':\n')) 1131 if missing: 1132 errors.append('Expected, but missing:\n %r\n' % missing) 1133 if unexpected: 1134 errors.append('Unexpected, but present:\n %r\n' % unexpected) 1135 if missing or unexpected: 1136 self.fail(''.join(errors)) 1137 1138 # unittest.TestCase.assertMultiLineEqual works very similarly, but it 1139 # has a different error format. However, I find this slightly more readable. 1140 def assertMultiLineEqual(self, first, second, msg=None, **kwargs): 1141 """Asserts that two multi-line strings are equal.""" 1142 assert isinstance(first, 1143 str), ('First argument is not a string: %r' % (first,)) 1144 assert isinstance(second, 1145 str), ('Second argument is not a string: %r' % (second,)) 1146 line_limit = kwargs.pop('line_limit', 0) 1147 if kwargs: 1148 raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs))) 1149 1150 if first == second: 1151 return 1152 if msg: 1153 failure_message = [msg + ':\n'] 1154 else: 1155 failure_message = ['\n'] 1156 if line_limit: 1157 line_limit += len(failure_message) 1158 for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)): 1159 failure_message.append(line) 1160 if not line.endswith('\n'): 1161 failure_message.append('\n') 1162 if line_limit and len(failure_message) > line_limit: 1163 n_omitted = len(failure_message) - line_limit 1164 failure_message = failure_message[:line_limit] 1165 failure_message.append( 1166 '(... and {} more delta lines omitted for brevity.)\n'.format( 1167 n_omitted)) 1168 1169 raise self.failureException(''.join(failure_message)) 1170 1171 def assertBetween(self, value, minv, maxv, msg=None): 1172 """Asserts that value is between minv and maxv (inclusive).""" 1173 msg = self._formatMessage(msg, 1174 '"%r" unexpectedly not between "%r" and "%r"' % 1175 (value, minv, maxv)) 1176 self.assertTrue(minv <= value, msg) 1177 self.assertTrue(maxv >= value, msg) 1178 1179 def assertRegexMatch(self, actual_str, regexes, message=None): 1180 r"""Asserts that at least one regex in regexes matches str. 1181 1182 If possible you should use `assertRegex`, which is a simpler 1183 version of this method. `assertRegex` takes a single regular 1184 expression (a string or re compiled object) instead of a list. 1185 1186 Notes: 1187 1188 1. This function uses substring matching, i.e. the matching 1189 succeeds if *any* substring of the error message matches *any* 1190 regex in the list. This is more convenient for the user than 1191 full-string matching. 1192 1193 2. If regexes is the empty list, the matching will always fail. 1194 1195 3. Use regexes=[''] for a regex that will always pass. 1196 1197 4. '.' matches any single character *except* the newline. To 1198 match any character, use '(.|\n)'. 1199 1200 5. '^' matches the beginning of each line, not just the beginning 1201 of the string. Similarly, '$' matches the end of each line. 1202 1203 6. An exception will be thrown if regexes contains an invalid 1204 regex. 1205 1206 Args: 1207 actual_str: The string we try to match with the items in regexes. 1208 regexes: The regular expressions we want to match against str. 1209 See "Notes" above for detailed notes on how this is interpreted. 1210 message: The message to be printed if the test fails. 1211 """ 1212 if isinstance(regexes, _TEXT_OR_BINARY_TYPES): 1213 self.fail('regexes is string or bytes; use assertRegex instead.', 1214 message) 1215 if not regexes: 1216 self.fail('No regexes specified.', message) 1217 1218 regex_type = type(regexes[0]) 1219 for regex in regexes[1:]: 1220 if type(regex) is not regex_type: # pylint: disable=unidiomatic-typecheck 1221 self.fail('regexes list must all be the same type.', message) 1222 1223 if regex_type is bytes and isinstance(actual_str, str): 1224 regexes = [regex.decode('utf-8') for regex in regexes] 1225 regex_type = str 1226 elif regex_type is str and isinstance(actual_str, bytes): 1227 regexes = [regex.encode('utf-8') for regex in regexes] 1228 regex_type = bytes 1229 1230 if regex_type is str: 1231 regex = u'(?:%s)' % u')|(?:'.join(regexes) 1232 elif regex_type is bytes: 1233 regex = b'(?:' + (b')|(?:'.join(regexes)) + b')' 1234 else: 1235 self.fail('Only know how to deal with unicode str or bytes regexes.', 1236 message) 1237 1238 if not re.search(regex, actual_str, re.MULTILINE): 1239 self.fail('"%s" does not contain any of these regexes: %s.' % 1240 (actual_str, regexes), message) 1241 1242 def assertCommandSucceeds(self, command, regexes=(b'',), env=None, 1243 close_fds=True, msg=None): 1244 """Asserts that a shell command succeeds (i.e. exits with code 0). 1245 1246 Args: 1247 command: List or string representing the command to run. 1248 regexes: List of regular expression byte strings that match success. 1249 env: Dictionary of environment variable settings. If None, no environment 1250 variables will be set for the child process. This is to make tests 1251 more hermetic. NOTE: this behavior is different than the standard 1252 subprocess module. 1253 close_fds: Whether or not to close all open fd's in the child after 1254 forking. 1255 msg: Optional message to report on failure. 1256 """ 1257 (ret_code, err) = get_command_stderr(command, env, close_fds) 1258 1259 # We need bytes regexes here because `err` is bytes. 1260 # Accommodate code which listed their output regexes w/o the b'' prefix by 1261 # converting them to bytes for the user. 1262 if isinstance(regexes[0], str): 1263 regexes = [regex.encode('utf-8') for regex in regexes] 1264 1265 command_string = get_command_string(command) 1266 self.assertEqual( 1267 ret_code, 0, 1268 self._formatMessage(msg, 1269 'Running command\n' 1270 '%s failed with error code %s and message\n' 1271 '%s' % (_quote_long_string(command_string), 1272 ret_code, 1273 _quote_long_string(err))) 1274 ) 1275 self.assertRegexMatch( 1276 err, 1277 regexes, 1278 message=self._formatMessage( 1279 msg, 1280 'Running command\n' 1281 '%s failed with error code %s and message\n' 1282 '%s which matches no regex in %s' % ( 1283 _quote_long_string(command_string), 1284 ret_code, 1285 _quote_long_string(err), 1286 regexes))) 1287 1288 def assertCommandFails(self, command, regexes, env=None, close_fds=True, 1289 msg=None): 1290 """Asserts a shell command fails and the error matches a regex in a list. 1291 1292 Args: 1293 command: List or string representing the command to run. 1294 regexes: the list of regular expression strings. 1295 env: Dictionary of environment variable settings. If None, no environment 1296 variables will be set for the child process. This is to make tests 1297 more hermetic. NOTE: this behavior is different than the standard 1298 subprocess module. 1299 close_fds: Whether or not to close all open fd's in the child after 1300 forking. 1301 msg: Optional message to report on failure. 1302 """ 1303 (ret_code, err) = get_command_stderr(command, env, close_fds) 1304 1305 # We need bytes regexes here because `err` is bytes. 1306 # Accommodate code which listed their output regexes w/o the b'' prefix by 1307 # converting them to bytes for the user. 1308 if isinstance(regexes[0], str): 1309 regexes = [regex.encode('utf-8') for regex in regexes] 1310 1311 command_string = get_command_string(command) 1312 self.assertNotEqual( 1313 ret_code, 0, 1314 self._formatMessage(msg, 'The following command succeeded ' 1315 'while expected to fail:\n%s' % 1316 _quote_long_string(command_string))) 1317 self.assertRegexMatch( 1318 err, 1319 regexes, 1320 message=self._formatMessage( 1321 msg, 1322 'Running command\n' 1323 '%s failed with error code %s and message\n' 1324 '%s which matches no regex in %s' % ( 1325 _quote_long_string(command_string), 1326 ret_code, 1327 _quote_long_string(err), 1328 regexes))) 1329 1330 class _AssertRaisesContext(object): 1331 1332 def __init__(self, expected_exception, test_case, test_func, msg=None): 1333 self.expected_exception = expected_exception 1334 self.test_case = test_case 1335 self.test_func = test_func 1336 self.msg = msg 1337 1338 def __enter__(self): 1339 return self 1340 1341 def __exit__(self, exc_type, exc_value, tb): 1342 if exc_type is None: 1343 self.test_case.fail(self.expected_exception.__name__ + ' not raised', 1344 self.msg) 1345 if not issubclass(exc_type, self.expected_exception): 1346 return False 1347 self.test_func(exc_value) 1348 if exc_value: 1349 self.exception = exc_value.with_traceback(None) 1350 return True 1351 1352 @typing.overload 1353 def assertRaisesWithPredicateMatch( 1354 self, expected_exception, predicate) -> _AssertRaisesContext: 1355 # The purpose of this return statement is to work around 1356 # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. 1357 return self._AssertRaisesContext(None, None, None) 1358 1359 @typing.overload 1360 def assertRaisesWithPredicateMatch( 1361 self, expected_exception, predicate, callable_obj: Callable[..., Any], 1362 *args, **kwargs) -> None: 1363 # The purpose of this return statement is to work around 1364 # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. 1365 return self._AssertRaisesContext(None, None, None) 1366 1367 def assertRaisesWithPredicateMatch(self, expected_exception, predicate, 1368 callable_obj=None, *args, **kwargs): 1369 """Asserts that exception is thrown and predicate(exception) is true. 1370 1371 Args: 1372 expected_exception: Exception class expected to be raised. 1373 predicate: Function of one argument that inspects the passed-in exception 1374 and returns True (success) or False (please fail the test). 1375 callable_obj: Function to be called. 1376 *args: Extra args. 1377 **kwargs: Extra keyword args. 1378 1379 Returns: 1380 A context manager if callable_obj is None. Otherwise, None. 1381 1382 Raises: 1383 self.failureException if callable_obj does not raise a matching exception. 1384 """ 1385 def Check(err): 1386 self.assertTrue(predicate(err), 1387 '%r does not match predicate %r' % (err, predicate)) 1388 1389 context = self._AssertRaisesContext(expected_exception, self, Check) 1390 if callable_obj is None: 1391 return context 1392 with context: 1393 callable_obj(*args, **kwargs) 1394 1395 @typing.overload 1396 def assertRaisesWithLiteralMatch( 1397 self, expected_exception, expected_exception_message 1398 ) -> _AssertRaisesContext: 1399 # The purpose of this return statement is to work around 1400 # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. 1401 return self._AssertRaisesContext(None, None, None) 1402 1403 @typing.overload 1404 def assertRaisesWithLiteralMatch( 1405 self, expected_exception, expected_exception_message, 1406 callable_obj: Callable[..., Any], *args, **kwargs) -> None: 1407 # The purpose of this return statement is to work around 1408 # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored. 1409 return self._AssertRaisesContext(None, None, None) 1410 1411 def assertRaisesWithLiteralMatch(self, expected_exception, 1412 expected_exception_message, 1413 callable_obj=None, *args, **kwargs): 1414 """Asserts that the message in a raised exception equals the given string. 1415 1416 Unlike assertRaisesRegex, this method takes a literal string, not 1417 a regular expression. 1418 1419 with self.assertRaisesWithLiteralMatch(ExType, 'message'): 1420 DoSomething() 1421 1422 Args: 1423 expected_exception: Exception class expected to be raised. 1424 expected_exception_message: String message expected in the raised 1425 exception. For a raise exception e, expected_exception_message must 1426 equal str(e). 1427 callable_obj: Function to be called, or None to return a context. 1428 *args: Extra args. 1429 **kwargs: Extra kwargs. 1430 1431 Returns: 1432 A context manager if callable_obj is None. Otherwise, None. 1433 1434 Raises: 1435 self.failureException if callable_obj does not raise a matching exception. 1436 """ 1437 def Check(err): 1438 actual_exception_message = str(err) 1439 self.assertTrue(expected_exception_message == actual_exception_message, 1440 'Exception message does not match.\n' 1441 'Expected: %r\n' 1442 'Actual: %r' % (expected_exception_message, 1443 actual_exception_message)) 1444 1445 context = self._AssertRaisesContext(expected_exception, self, Check) 1446 if callable_obj is None: 1447 return context 1448 with context: 1449 callable_obj(*args, **kwargs) 1450 1451 def assertContainsInOrder(self, strings, target, msg=None): 1452 """Asserts that the strings provided are found in the target in order. 1453 1454 This may be useful for checking HTML output. 1455 1456 Args: 1457 strings: A list of strings, such as [ 'fox', 'dog' ] 1458 target: A target string in which to look for the strings, such as 1459 'The quick brown fox jumped over the lazy dog'. 1460 msg: Optional message to report on failure. 1461 """ 1462 if isinstance(strings, (bytes, unicode if str is bytes else str)): 1463 strings = (strings,) 1464 1465 current_index = 0 1466 last_string = None 1467 for string in strings: 1468 index = target.find(str(string), current_index) 1469 if index == -1 and current_index == 0: 1470 self.fail("Did not find '%s' in '%s'" % 1471 (string, target), msg) 1472 elif index == -1: 1473 self.fail("Did not find '%s' after '%s' in '%s'" % 1474 (string, last_string, target), msg) 1475 last_string = string 1476 current_index = index 1477 1478 def assertContainsSubsequence(self, container, subsequence, msg=None): 1479 """Asserts that "container" contains "subsequence" as a subsequence. 1480 1481 Asserts that "container" contains all the elements of "subsequence", in 1482 order, but possibly with other elements interspersed. For example, [1, 2, 3] 1483 is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0]. 1484 1485 Args: 1486 container: the list we're testing for subsequence inclusion. 1487 subsequence: the list we hope will be a subsequence of container. 1488 msg: Optional message to report on failure. 1489 """ 1490 first_nonmatching = None 1491 reversed_container = list(reversed(container)) 1492 subsequence = list(subsequence) 1493 1494 for e in subsequence: 1495 if e not in reversed_container: 1496 first_nonmatching = e 1497 break 1498 while e != reversed_container.pop(): 1499 pass 1500 1501 if first_nonmatching is not None: 1502 self.fail('%s not a subsequence of %s. First non-matching element: %s' % 1503 (subsequence, container, first_nonmatching), msg) 1504 1505 def assertContainsExactSubsequence(self, container, subsequence, msg=None): 1506 """Asserts that "container" contains "subsequence" as an exact subsequence. 1507 1508 Asserts that "container" contains all the elements of "subsequence", in 1509 order, and without other elements interspersed. For example, [1, 2, 3] is an 1510 exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0]. 1511 1512 Args: 1513 container: the list we're testing for subsequence inclusion. 1514 subsequence: the list we hope will be an exact subsequence of container. 1515 msg: Optional message to report on failure. 1516 """ 1517 container = list(container) 1518 subsequence = list(subsequence) 1519 longest_match = 0 1520 1521 for start in range(1 + len(container) - len(subsequence)): 1522 if longest_match == len(subsequence): 1523 break 1524 index = 0 1525 while (index < len(subsequence) and 1526 subsequence[index] == container[start + index]): 1527 index += 1 1528 longest_match = max(longest_match, index) 1529 1530 if longest_match < len(subsequence): 1531 self.fail('%s not an exact subsequence of %s. ' 1532 'Longest matching prefix: %s' % 1533 (subsequence, container, subsequence[:longest_match]), msg) 1534 1535 def assertTotallyOrdered(self, *groups, **kwargs): 1536 """Asserts that total ordering has been implemented correctly. 1537 1538 For example, say you have a class A that compares only on its attribute x. 1539 Comparators other than ``__lt__`` are omitted for brevity:: 1540 1541 class A(object): 1542 def __init__(self, x, y): 1543 self.x = x 1544 self.y = y 1545 1546 def __hash__(self): 1547 return hash(self.x) 1548 1549 def __lt__(self, other): 1550 try: 1551 return self.x < other.x 1552 except AttributeError: 1553 return NotImplemented 1554 1555 assertTotallyOrdered will check that instances can be ordered correctly. 1556 For example:: 1557 1558 self.assertTotallyOrdered( 1559 [None], # None should come before everything else. 1560 [1], # Integers sort earlier. 1561 [A(1, 'a')], 1562 [A(2, 'b')], # 2 is after 1. 1563 [A(3, 'c'), A(3, 'd')], # The second argument is irrelevant. 1564 [A(4, 'z')], 1565 ['foo']) # Strings sort last. 1566 1567 Args: 1568 *groups: A list of groups of elements. Each group of elements is a list 1569 of objects that are equal. The elements in each group must be less 1570 than the elements in the group after it. For example, these groups are 1571 totally ordered: ``[None]``, ``[1]``, ``[2, 2]``, ``[3]``. 1572 **kwargs: optional msg keyword argument can be passed. 1573 """ 1574 1575 def CheckOrder(small, big): 1576 """Ensures small is ordered before big.""" 1577 self.assertFalse(small == big, 1578 self._formatMessage(msg, '%r unexpectedly equals %r' % 1579 (small, big))) 1580 self.assertTrue(small != big, 1581 self._formatMessage(msg, '%r unexpectedly equals %r' % 1582 (small, big))) 1583 self.assertLess(small, big, msg) 1584 self.assertFalse(big < small, 1585 self._formatMessage(msg, 1586 '%r unexpectedly less than %r' % 1587 (big, small))) 1588 self.assertLessEqual(small, big, msg) 1589 self.assertFalse(big <= small, self._formatMessage( 1590 '%r unexpectedly less than or equal to %r' % (big, small), msg 1591 )) 1592 self.assertGreater(big, small, msg) 1593 self.assertFalse(small > big, 1594 self._formatMessage(msg, 1595 '%r unexpectedly greater than %r' % 1596 (small, big))) 1597 self.assertGreaterEqual(big, small) 1598 self.assertFalse(small >= big, self._formatMessage( 1599 msg, 1600 '%r unexpectedly greater than or equal to %r' % (small, big))) 1601 1602 def CheckEqual(a, b): 1603 """Ensures that a and b are equal.""" 1604 self.assertEqual(a, b, msg) 1605 self.assertFalse(a != b, 1606 self._formatMessage(msg, '%r unexpectedly unequals %r' % 1607 (a, b))) 1608 1609 # Objects that compare equal must hash to the same value, but this only 1610 # applies if both objects are hashable. 1611 if (isinstance(a, abc.Hashable) and 1612 isinstance(b, abc.Hashable)): 1613 self.assertEqual( 1614 hash(a), hash(b), 1615 self._formatMessage( 1616 msg, 'hash %d of %r unexpectedly not equal to hash %d of %r' % 1617 (hash(a), a, hash(b), b))) 1618 1619 self.assertFalse(a < b, 1620 self._formatMessage(msg, 1621 '%r unexpectedly less than %r' % 1622 (a, b))) 1623 self.assertFalse(b < a, 1624 self._formatMessage(msg, 1625 '%r unexpectedly less than %r' % 1626 (b, a))) 1627 self.assertLessEqual(a, b, msg) 1628 self.assertLessEqual(b, a, msg) # pylint: disable=arguments-out-of-order 1629 self.assertFalse(a > b, 1630 self._formatMessage(msg, 1631 '%r unexpectedly greater than %r' % 1632 (a, b))) 1633 self.assertFalse(b > a, 1634 self._formatMessage(msg, 1635 '%r unexpectedly greater than %r' % 1636 (b, a))) 1637 self.assertGreaterEqual(a, b, msg) 1638 self.assertGreaterEqual(b, a, msg) # pylint: disable=arguments-out-of-order 1639 1640 msg = kwargs.get('msg') 1641 1642 # For every combination of elements, check the order of every pair of 1643 # elements. 1644 for elements in itertools.product(*groups): 1645 elements = list(elements) 1646 for index, small in enumerate(elements[:-1]): 1647 for big in elements[index + 1:]: 1648 CheckOrder(small, big) 1649 1650 # Check that every element in each group is equal. 1651 for group in groups: 1652 for a in group: 1653 CheckEqual(a, a) 1654 for a, b in itertools.product(group, group): 1655 CheckEqual(a, b) 1656 1657 def assertDictEqual(self, a, b, msg=None): 1658 """Raises AssertionError if a and b are not equal dictionaries. 1659 1660 Args: 1661 a: A dict, the expected value. 1662 b: A dict, the actual value. 1663 msg: An optional str, the associated message. 1664 1665 Raises: 1666 AssertionError: if the dictionaries are not equal. 1667 """ 1668 self.assertIsInstance(a, dict, self._formatMessage( 1669 msg, 1670 'First argument is not a dictionary' 1671 )) 1672 self.assertIsInstance(b, dict, self._formatMessage( 1673 msg, 1674 'Second argument is not a dictionary' 1675 )) 1676 1677 def Sorted(list_of_items): 1678 try: 1679 return sorted(list_of_items) # In 3.3, unordered are possible. 1680 except TypeError: 1681 return list_of_items 1682 1683 if a == b: 1684 return 1685 a_items = Sorted(list(a.items())) 1686 b_items = Sorted(list(b.items())) 1687 1688 unexpected = [] 1689 missing = [] 1690 different = [] 1691 1692 safe_repr = unittest.util.safe_repr # pytype: disable=module-attr 1693 1694 def Repr(dikt): 1695 """Deterministic repr for dict.""" 1696 # Sort the entries based on their repr, not based on their sort order, 1697 # which will be non-deterministic across executions, for many types. 1698 entries = sorted((safe_repr(k), safe_repr(v)) for k, v in dikt.items()) 1699 return '{%s}' % (', '.join('%s: %s' % pair for pair in entries)) 1700 1701 message = ['%s != %s%s' % (Repr(a), Repr(b), ' (%s)' % msg if msg else '')] 1702 1703 # The standard library default output confounds lexical difference with 1704 # value difference; treat them separately. 1705 for a_key, a_value in a_items: 1706 if a_key not in b: 1707 missing.append((a_key, a_value)) 1708 elif a_value != b[a_key]: 1709 different.append((a_key, a_value, b[a_key])) 1710 1711 for b_key, b_value in b_items: 1712 if b_key not in a: 1713 unexpected.append((b_key, b_value)) 1714 1715 if unexpected: 1716 message.append( 1717 'Unexpected, but present entries:\n%s' % ''.join( 1718 '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in unexpected)) 1719 1720 if different: 1721 message.append( 1722 'repr() of differing entries:\n%s' % ''.join( 1723 '%s: %s != %s\n' % (safe_repr(k), safe_repr(a_value), 1724 safe_repr(b_value)) 1725 for k, a_value, b_value in different)) 1726 1727 if missing: 1728 message.append( 1729 'Missing entries:\n%s' % ''.join( 1730 ('%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in missing))) 1731 1732 raise self.failureException('\n'.join(message)) 1733 1734 def assertDataclassEqual(self, first, second, msg=None): 1735 """Asserts two dataclasses are equal with more informative errors. 1736 1737 Arguments must both be dataclasses. This compares equality of individual 1738 fields and takes care to not compare fields that are marked as 1739 non-comparable. It gives per field differences, which are easier to parse 1740 than the comparison of the string representations from assertEqual. 1741 1742 In cases where the dataclass has a custom __eq__, and it is defined in a 1743 way that is inconsistent with equality of comparable fields, we raise an 1744 exception without further trying to figure out how they are different. 1745 1746 Args: 1747 first: A dataclass, the first value. 1748 second: A dataclass, the second value. 1749 msg: An optional str, the associated message. 1750 1751 Raises: 1752 AssertionError: if the dataclasses are not equal. 1753 """ 1754 1755 if not dataclasses.is_dataclass(first) or isinstance(first, type): 1756 raise self.failureException('First argument is not a dataclass instance.') 1757 if not dataclasses.is_dataclass(second) or isinstance(second, type): 1758 raise self.failureException( 1759 'Second argument is not a dataclass instance.' 1760 ) 1761 1762 if first == second: 1763 return 1764 1765 if type(first) is not type(second): 1766 self.fail( 1767 'Found different dataclass types: %s != %s' 1768 % (type(first), type(second)), 1769 msg, 1770 ) 1771 1772 # Make sure to skip fields that are marked compare=False. 1773 different = [ 1774 (f.name, getattr(first, f.name), getattr(second, f.name)) 1775 for f in dataclasses.fields(first) 1776 if f.compare and getattr(first, f.name) != getattr(second, f.name) 1777 ] 1778 1779 safe_repr = unittest.util.safe_repr # pytype: disable=module-attr 1780 message = ['%s != %s' % (safe_repr(first), safe_repr(second))] 1781 if different: 1782 message.append('Fields that differ:') 1783 message.extend( 1784 '%s: %s != %s' % (k, safe_repr(first_v), safe_repr(second_v)) 1785 for k, first_v, second_v in different 1786 ) 1787 else: 1788 message.append( 1789 'Cannot detect difference by examining the fields of the dataclass.' 1790 ) 1791 1792 raise self.fail('\n'.join(message), msg) 1793 1794 def assertUrlEqual(self, a, b, msg=None): 1795 """Asserts that urls are equal, ignoring ordering of query params.""" 1796 parsed_a = parse.urlparse(a) 1797 parsed_b = parse.urlparse(b) 1798 self.assertEqual(parsed_a.scheme, parsed_b.scheme, msg) 1799 self.assertEqual(parsed_a.netloc, parsed_b.netloc, msg) 1800 self.assertEqual(parsed_a.path, parsed_b.path, msg) 1801 self.assertEqual(parsed_a.fragment, parsed_b.fragment, msg) 1802 self.assertEqual(sorted(parsed_a.params.split(';')), 1803 sorted(parsed_b.params.split(';')), msg) 1804 self.assertDictEqual( 1805 parse.parse_qs(parsed_a.query, keep_blank_values=True), 1806 parse.parse_qs(parsed_b.query, keep_blank_values=True), msg) 1807 1808 def assertSameStructure(self, a, b, aname='a', bname='b', msg=None): 1809 """Asserts that two values contain the same structural content. 1810 1811 The two arguments should be data trees consisting of trees of dicts and 1812 lists. They will be deeply compared by walking into the contents of dicts 1813 and lists; other items will be compared using the == operator. 1814 If the two structures differ in content, the failure message will indicate 1815 the location within the structures where the first difference is found. 1816 This may be helpful when comparing large structures. 1817 1818 Mixed Sequence and Set types are supported. Mixed Mapping types are 1819 supported, but the order of the keys will not be considered in the 1820 comparison. 1821 1822 Args: 1823 a: The first structure to compare. 1824 b: The second structure to compare. 1825 aname: Variable name to use for the first structure in assertion messages. 1826 bname: Variable name to use for the second structure. 1827 msg: Additional text to include in the failure message. 1828 """ 1829 1830 # Accumulate all the problems found so we can report all of them at once 1831 # rather than just stopping at the first 1832 problems = [] 1833 1834 _walk_structure_for_problems(a, b, aname, bname, problems, 1835 self.assertEqual, self.failureException) 1836 1837 # Avoid spamming the user toooo much 1838 if self.maxDiff is not None: 1839 max_problems_to_show = self.maxDiff // 80 1840 if len(problems) > max_problems_to_show: 1841 problems = problems[0:max_problems_to_show-1] + ['...'] 1842 1843 if problems: 1844 self.fail('; '.join(problems), msg) 1845 1846 def assertJsonEqual(self, first, second, msg=None): 1847 """Asserts that the JSON objects defined in two strings are equal. 1848 1849 A summary of the differences will be included in the failure message 1850 using assertSameStructure. 1851 1852 Args: 1853 first: A string containing JSON to decode and compare to second. 1854 second: A string containing JSON to decode and compare to first. 1855 msg: Additional text to include in the failure message. 1856 """ 1857 try: 1858 first_structured = json.loads(first) 1859 except ValueError as e: 1860 raise ValueError(self._formatMessage( 1861 msg, 1862 'could not decode first JSON value %s: %s' % (first, e))) 1863 1864 try: 1865 second_structured = json.loads(second) 1866 except ValueError as e: 1867 raise ValueError(self._formatMessage( 1868 msg, 1869 'could not decode second JSON value %s: %s' % (second, e))) 1870 1871 self.assertSameStructure(first_structured, second_structured, 1872 aname='first', bname='second', msg=msg) 1873 1874 def _getAssertEqualityFunc(self, first, second): 1875 # type: (Any, Any) -> Callable[..., None] 1876 try: 1877 return super(TestCase, self)._getAssertEqualityFunc(first, second) 1878 except AttributeError: 1879 # This is a workaround if unittest.TestCase.__init__ was never run. 1880 # It usually means that somebody created a subclass just for the 1881 # assertions and has overridden __init__. "assertTrue" is a safe 1882 # value that will not make __init__ raise a ValueError. 1883 test_method = getattr(self, '_testMethodName', 'assertTrue') 1884 super(TestCase, self).__init__(test_method) 1885 1886 return super(TestCase, self)._getAssertEqualityFunc(first, second) 1887 1888 def fail(self, msg=None, user_msg=None) -> NoReturn: 1889 """Fail immediately with the given standard message and user message.""" 1890 return super(TestCase, self).fail(self._formatMessage(user_msg, msg)) 1891 1892 1893def _sorted_list_difference(expected, actual): 1894 # type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]] 1895 """Finds elements in only one or the other of two, sorted input lists. 1896 1897 Returns a two-element tuple of lists. The first list contains those 1898 elements in the "expected" list but not in the "actual" list, and the 1899 second contains those elements in the "actual" list but not in the 1900 "expected" list. Duplicate elements in either input list are ignored. 1901 1902 Args: 1903 expected: The list we expected. 1904 actual: The list we actually got. 1905 Returns: 1906 (missing, unexpected) 1907 missing: items in expected that are not in actual. 1908 unexpected: items in actual that are not in expected. 1909 """ 1910 i = j = 0 1911 missing = [] 1912 unexpected = [] 1913 while True: 1914 try: 1915 e = expected[i] 1916 a = actual[j] 1917 if e < a: 1918 missing.append(e) 1919 i += 1 1920 while expected[i] == e: 1921 i += 1 1922 elif e > a: 1923 unexpected.append(a) 1924 j += 1 1925 while actual[j] == a: 1926 j += 1 1927 else: 1928 i += 1 1929 try: 1930 while expected[i] == e: 1931 i += 1 1932 finally: 1933 j += 1 1934 while actual[j] == a: 1935 j += 1 1936 except IndexError: 1937 missing.extend(expected[i:]) 1938 unexpected.extend(actual[j:]) 1939 break 1940 return missing, unexpected 1941 1942 1943def _are_both_of_integer_type(a, b): 1944 # type: (object, object) -> bool 1945 return isinstance(a, int) and isinstance(b, int) 1946 1947 1948def _are_both_of_sequence_type(a, b): 1949 # type: (object, object) -> bool 1950 return isinstance(a, abc.Sequence) and isinstance( 1951 b, abc.Sequence) and not isinstance( 1952 a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES) 1953 1954 1955def _are_both_of_set_type(a, b): 1956 # type: (object, object) -> bool 1957 return isinstance(a, abc.Set) and isinstance(b, abc.Set) 1958 1959 1960def _are_both_of_mapping_type(a, b): 1961 # type: (object, object) -> bool 1962 return isinstance(a, abc.Mapping) and isinstance( 1963 b, abc.Mapping) 1964 1965 1966def _walk_structure_for_problems( 1967 a, b, aname, bname, problem_list, leaf_assert_equal_func, failure_exception 1968): 1969 """The recursive comparison behind assertSameStructure.""" 1970 if type(a) != type(b) and not ( # pylint: disable=unidiomatic-typecheck 1971 _are_both_of_integer_type(a, b) or _are_both_of_sequence_type(a, b) or 1972 _are_both_of_set_type(a, b) or _are_both_of_mapping_type(a, b)): 1973 # We do not distinguish between int and long types as 99.99% of Python 2 1974 # code should never care. They collapse into a single type in Python 3. 1975 problem_list.append('%s is a %r but %s is a %r' % 1976 (aname, type(a), bname, type(b))) 1977 # If they have different types there's no point continuing 1978 return 1979 1980 if isinstance(a, abc.Set): 1981 for k in a: 1982 if k not in b: 1983 problem_list.append( 1984 '%s has %r but %s does not' % (aname, k, bname)) 1985 for k in b: 1986 if k not in a: 1987 problem_list.append('%s lacks %r but %s has it' % (aname, k, bname)) 1988 1989 # NOTE: a or b could be a defaultdict, so we must take care that the traversal 1990 # doesn't modify the data. 1991 elif isinstance(a, abc.Mapping): 1992 for k in a: 1993 if k in b: 1994 _walk_structure_for_problems( 1995 a[k], b[k], '%s[%r]' % (aname, k), '%s[%r]' % (bname, k), 1996 problem_list, leaf_assert_equal_func, failure_exception) 1997 else: 1998 problem_list.append( 1999 "%s has [%r] with value %r but it's missing in %s" % 2000 (aname, k, a[k], bname)) 2001 for k in b: 2002 if k not in a: 2003 problem_list.append( 2004 '%s lacks [%r] but %s has it with value %r' % 2005 (aname, k, bname, b[k])) 2006 2007 # Strings/bytes are Sequences but we'll just do those with regular != 2008 elif (isinstance(a, abc.Sequence) and 2009 not isinstance(a, _TEXT_OR_BINARY_TYPES)): 2010 minlen = min(len(a), len(b)) 2011 for i in range(minlen): 2012 _walk_structure_for_problems( 2013 a[i], b[i], '%s[%d]' % (aname, i), '%s[%d]' % (bname, i), 2014 problem_list, leaf_assert_equal_func, failure_exception) 2015 for i in range(minlen, len(a)): 2016 problem_list.append('%s has [%i] with value %r but %s does not' % 2017 (aname, i, a[i], bname)) 2018 for i in range(minlen, len(b)): 2019 problem_list.append('%s lacks [%i] but %s has it with value %r' % 2020 (aname, i, bname, b[i])) 2021 2022 else: 2023 try: 2024 leaf_assert_equal_func(a, b) 2025 except failure_exception: 2026 problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b)) 2027 2028 2029def get_command_string(command): 2030 """Returns an escaped string that can be used as a shell command. 2031 2032 Args: 2033 command: List or string representing the command to run. 2034 Returns: 2035 A string suitable for use as a shell command. 2036 """ 2037 if isinstance(command, str): 2038 return command 2039 else: 2040 if os.name == 'nt': 2041 return ' '.join(command) 2042 else: 2043 # The following is identical to Python 3's shlex.quote function. 2044 command_string = '' 2045 for word in command: 2046 # Single quote word, and replace each ' in word with '"'"' 2047 command_string += "'" + word.replace("'", "'\"'\"'") + "' " 2048 return command_string[:-1] 2049 2050 2051def get_command_stderr(command, env=None, close_fds=True): 2052 """Runs the given shell command and returns a tuple. 2053 2054 Args: 2055 command: List or string representing the command to run. 2056 env: Dictionary of environment variable settings. If None, no environment 2057 variables will be set for the child process. This is to make tests 2058 more hermetic. NOTE: this behavior is different than the standard 2059 subprocess module. 2060 close_fds: Whether or not to close all open fd's in the child after forking. 2061 On Windows, this is ignored and close_fds is always False. 2062 2063 Returns: 2064 Tuple of (exit status, text printed to stdout and stderr by the command). 2065 """ 2066 if env is None: env = {} 2067 if os.name == 'nt': 2068 # Windows does not support setting close_fds to True while also redirecting 2069 # standard handles. 2070 close_fds = False 2071 2072 use_shell = isinstance(command, str) 2073 process = subprocess.Popen( 2074 command, 2075 close_fds=close_fds, 2076 env=env, 2077 shell=use_shell, 2078 stderr=subprocess.STDOUT, 2079 stdout=subprocess.PIPE) 2080 output = process.communicate()[0] 2081 exit_status = process.wait() 2082 return (exit_status, output) 2083 2084 2085def _quote_long_string(s): 2086 # type: (Union[Text, bytes, bytearray]) -> Text 2087 """Quotes a potentially multi-line string to make the start and end obvious. 2088 2089 Args: 2090 s: A string. 2091 2092 Returns: 2093 The quoted string. 2094 """ 2095 if isinstance(s, (bytes, bytearray)): 2096 try: 2097 s = s.decode('utf-8') 2098 except UnicodeDecodeError: 2099 s = str(s) 2100 return ('8<-----------\n' + 2101 s + '\n' + 2102 '----------->8\n') 2103 2104 2105def print_python_version(): 2106 # type: () -> None 2107 # Having this in the test output logs by default helps debugging when all 2108 # you've got is the log and no other idea of which Python was used. 2109 sys.stderr.write('Running tests under Python {0[0]}.{0[1]}.{0[2]}: ' 2110 '{1}\n'.format( 2111 sys.version_info, 2112 sys.executable if sys.executable else 'embedded.')) 2113 2114 2115def main(*args, **kwargs): 2116 # type: (Text, Any) -> None 2117 """Executes a set of Python unit tests. 2118 2119 Usually this function is called without arguments, so the 2120 unittest.TestProgram instance will get created with the default settings, 2121 so it will run all test methods of all TestCase classes in the ``__main__`` 2122 module. 2123 2124 Args: 2125 *args: Positional arguments passed through to 2126 ``unittest.TestProgram.__init__``. 2127 **kwargs: Keyword arguments passed through to 2128 ``unittest.TestProgram.__init__``. 2129 """ 2130 print_python_version() 2131 _run_in_app(run_tests, args, kwargs) 2132 2133 2134def _is_in_app_main(): 2135 # type: () -> bool 2136 """Returns True iff app.run is active.""" 2137 f = sys._getframe().f_back # pylint: disable=protected-access 2138 while f: 2139 if f.f_code == app.run.__code__: 2140 return True 2141 f = f.f_back 2142 return False 2143 2144 2145def _register_sigterm_with_faulthandler(): 2146 # type: () -> None 2147 """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts.""" 2148 if getattr(faulthandler, 'register', None): 2149 # faulthandler.register is not available on Windows. 2150 # faulthandler.enable() is already called by app.run. 2151 try: 2152 faulthandler.register(signal.SIGTERM, chain=True) # pytype: disable=module-attr 2153 except Exception as e: # pylint: disable=broad-except 2154 sys.stderr.write('faulthandler.register(SIGTERM) failed ' 2155 '%r; ignoring.\n' % e) 2156 2157 2158def _run_in_app(function, args, kwargs): 2159 # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None 2160 """Executes a set of Python unit tests, ensuring app.run. 2161 2162 This is a private function, users should call absltest.main(). 2163 2164 _run_in_app calculates argv to be the command-line arguments of this program 2165 (without the flags), sets the default of FLAGS.alsologtostderr to True, 2166 then it calls function(argv, args, kwargs), making sure that `function' 2167 will get called within app.run(). _run_in_app does this by checking whether 2168 it is called by app.run(), or by calling app.run() explicitly. 2169 2170 The reason why app.run has to be ensured is to make sure that 2171 flags are parsed and stripped properly, and other initializations done by 2172 the app module are also carried out, no matter if absltest.run() is called 2173 from within or outside app.run(). 2174 2175 If _run_in_app is called from within app.run(), then it will reparse 2176 sys.argv and pass the result without command-line flags into the argv 2177 argument of `function'. The reason why this parsing is needed is that 2178 __main__.main() calls absltest.main() without passing its argv. So the 2179 only way _run_in_app could get to know the argv without the flags is that 2180 it reparses sys.argv. 2181 2182 _run_in_app changes the default of FLAGS.alsologtostderr to True so that the 2183 test program's stderr will contain all the log messages unless otherwise 2184 specified on the command-line. This overrides any explicit assignment to 2185 FLAGS.alsologtostderr by the test program prior to the call to _run_in_app() 2186 (e.g. in __main__.main). 2187 2188 Please note that _run_in_app (and the function it calls) is allowed to make 2189 changes to kwargs. 2190 2191 Args: 2192 function: absltest.run_tests or a similar function. It will be called as 2193 function(argv, args, kwargs) where argv is a list containing the 2194 elements of sys.argv without the command-line flags. 2195 args: Positional arguments passed through to unittest.TestProgram.__init__. 2196 kwargs: Keyword arguments passed through to unittest.TestProgram.__init__. 2197 """ 2198 if _is_in_app_main(): 2199 _register_sigterm_with_faulthandler() 2200 2201 # Change the default of alsologtostderr from False to True, so the test 2202 # programs's stderr will contain all the log messages. 2203 # If --alsologtostderr=false is specified in the command-line, or user 2204 # has called FLAGS.alsologtostderr = False before, then the value is kept 2205 # False. 2206 FLAGS.set_default('alsologtostderr', True) 2207 2208 # Here we only want to get the `argv` without the flags. To avoid any 2209 # side effects of parsing flags, we temporarily stub out the `parse` method 2210 stored_parse_methods = {} 2211 noop_parse = lambda _: None 2212 for name in FLAGS: 2213 # Avoid any side effects of parsing flags. 2214 stored_parse_methods[name] = FLAGS[name].parse 2215 # This must be a separate loop since multiple flag names (short_name=) can 2216 # point to the same flag object. 2217 for name in FLAGS: 2218 FLAGS[name].parse = noop_parse 2219 try: 2220 argv = FLAGS(sys.argv) 2221 finally: 2222 for name in FLAGS: 2223 FLAGS[name].parse = stored_parse_methods[name] 2224 sys.stdout.flush() 2225 2226 function(argv, args, kwargs) 2227 else: 2228 # Send logging to stderr. Use --alsologtostderr instead of --logtostderr 2229 # in case tests are reading their own logs. 2230 FLAGS.set_default('alsologtostderr', True) 2231 2232 def main_function(argv): 2233 _register_sigterm_with_faulthandler() 2234 function(argv, args, kwargs) 2235 2236 app.run(main=main_function) 2237 2238 2239def _is_suspicious_attribute(testCaseClass, name): 2240 # type: (Type, Text) -> bool 2241 """Returns True if an attribute is a method named like a test method.""" 2242 if name.startswith('Test') and len(name) > 4 and name[4].isupper(): 2243 attr = getattr(testCaseClass, name) 2244 if inspect.isfunction(attr) or inspect.ismethod(attr): 2245 args = inspect.getfullargspec(attr) 2246 return (len(args.args) == 1 and args.args[0] == 'self' and 2247 args.varargs is None and args.varkw is None and 2248 not args.kwonlyargs) 2249 return False 2250 2251 2252def skipThisClass(reason): 2253 # type: (Text) -> Callable[[_T], _T] 2254 """Skip tests in the decorated TestCase, but not any of its subclasses. 2255 2256 This decorator indicates that this class should skip all its tests, but not 2257 any of its subclasses. Useful for if you want to share testMethod or setUp 2258 implementations between a number of concrete testcase classes. 2259 2260 Example usage, showing how you can share some common test methods between 2261 subclasses. In this example, only ``BaseTest`` will be marked as skipped, and 2262 not RealTest or SecondRealTest:: 2263 2264 @absltest.skipThisClass("Shared functionality") 2265 class BaseTest(absltest.TestCase): 2266 def test_simple_functionality(self): 2267 self.assertEqual(self.system_under_test.method(), 1) 2268 2269 class RealTest(BaseTest): 2270 def setUp(self): 2271 super().setUp() 2272 self.system_under_test = MakeSystem(argument) 2273 2274 def test_specific_behavior(self): 2275 ... 2276 2277 class SecondRealTest(BaseTest): 2278 def setUp(self): 2279 super().setUp() 2280 self.system_under_test = MakeSystem(other_arguments) 2281 2282 def test_other_behavior(self): 2283 ... 2284 2285 Args: 2286 reason: The reason we have a skip in place. For instance: 'shared test 2287 methods' or 'shared assertion methods'. 2288 2289 Returns: 2290 Decorator function that will cause a class to be skipped. 2291 """ 2292 if isinstance(reason, type): 2293 raise TypeError('Got {!r}, expected reason as string'.format(reason)) 2294 2295 def _skip_class(test_case_class): 2296 if not issubclass(test_case_class, unittest.TestCase): 2297 raise TypeError( 2298 'Decorating {!r}, expected TestCase subclass'.format(test_case_class)) 2299 2300 # Only shadow the setUpClass method if it is directly defined. If it is 2301 # in the parent class we invoke it via a super() call instead of holding 2302 # a reference to it. 2303 shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None) 2304 2305 @classmethod 2306 def replacement_setupclass(cls, *args, **kwargs): 2307 # Skip this class if it is the one that was decorated with @skipThisClass 2308 if cls is test_case_class: 2309 raise SkipTest(reason) 2310 if shadowed_setupclass: 2311 # Pass along `cls` so the MRO chain doesn't break. 2312 # The original method is a `classmethod` descriptor, which can't 2313 # be directly called, but `__func__` has the underlying function. 2314 return shadowed_setupclass.__func__(cls, *args, **kwargs) 2315 else: 2316 # Because there's no setUpClass() defined directly on test_case_class, 2317 # we call super() ourselves to continue execution of the inheritance 2318 # chain. 2319 return super(test_case_class, cls).setUpClass(*args, **kwargs) 2320 2321 test_case_class.setUpClass = replacement_setupclass 2322 return test_case_class 2323 2324 return _skip_class 2325 2326 2327class TestLoader(unittest.TestLoader): 2328 """A test loader which supports common test features. 2329 2330 Supported features include: 2331 * Banning untested methods with test-like names: methods attached to this 2332 testCase with names starting with `Test` are ignored by the test runner, 2333 and often represent mistakenly-omitted test cases. This loader will raise 2334 a TypeError when attempting to load a TestCase with such methods. 2335 * Randomization of test case execution order (optional). 2336 """ 2337 2338 _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but 2339 is not one. This is often a bug. If you want it to be a test method, 2340 name it with 'test' in lowercase. If not, rename the method to not begin 2341 with 'Test'.""") 2342 2343 def __init__(self, *args, **kwds): 2344 super(TestLoader, self).__init__(*args, **kwds) 2345 seed = _get_default_randomize_ordering_seed() 2346 if seed: 2347 self._randomize_ordering_seed = seed 2348 self._random = random.Random(self._randomize_ordering_seed) 2349 else: 2350 self._randomize_ordering_seed = None 2351 self._random = None 2352 2353 def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name 2354 """Validates and returns a (possibly randomized) list of test case names.""" 2355 for name in dir(testCaseClass): 2356 if _is_suspicious_attribute(testCaseClass, name): 2357 raise TypeError(TestLoader._ERROR_MSG % name) 2358 names = list(super(TestLoader, self).getTestCaseNames(testCaseClass)) 2359 if self._randomize_ordering_seed is not None: 2360 logging.info( 2361 'Randomizing test order with seed: %d', self._randomize_ordering_seed) 2362 logging.info( 2363 'To reproduce this order, re-run with ' 2364 '--test_randomize_ordering_seed=%d', self._randomize_ordering_seed) 2365 self._random.shuffle(names) 2366 return names 2367 2368 2369def get_default_xml_output_filename(): 2370 # type: () -> Optional[Text] 2371 if os.environ.get('XML_OUTPUT_FILE'): 2372 return os.environ['XML_OUTPUT_FILE'] 2373 elif os.environ.get('RUNNING_UNDER_TEST_DAEMON'): 2374 return os.path.join(os.path.dirname(TEST_TMPDIR.value), 'test_detail.xml') 2375 elif os.environ.get('TEST_XMLOUTPUTDIR'): 2376 return os.path.join( 2377 os.environ['TEST_XMLOUTPUTDIR'], 2378 os.path.splitext(os.path.basename(sys.argv[0]))[0] + '.xml') 2379 2380 2381def _setup_filtering(argv: MutableSequence[str]) -> bool: 2382 """Implements the bazel test filtering protocol. 2383 2384 The following environment variable is used in this method: 2385 2386 TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest 2387 framework to use as a test filter. Its value is split with shlex, then: 2388 1. On Python 3.6 and before, split values are passed as positional 2389 arguments on argv. 2390 2. On Python 3.7+, split values are passed to unittest's `-k` flag. Tests 2391 are matched by glob patterns or substring. See 2392 https://docs.python.org/3/library/unittest.html#cmdoption-unittest-k 2393 2394 Args: 2395 argv: the argv to mutate in-place. 2396 2397 Returns: 2398 Whether test filtering is requested. 2399 """ 2400 test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY') 2401 if argv is None or not test_filter: 2402 return False 2403 2404 filters = shlex.split(test_filter) 2405 if sys.version_info[:2] >= (3, 7): 2406 filters = ['-k=' + test_filter for test_filter in filters] 2407 2408 argv[1:1] = filters 2409 return True 2410 2411 2412def _setup_test_runner_fail_fast(argv): 2413 # type: (MutableSequence[Text]) -> None 2414 """Implements the bazel test fail fast protocol. 2415 2416 The following environment variable is used in this method: 2417 2418 TESTBRIDGE_TEST_RUNNER_FAIL_FAST=<1|0> 2419 2420 If set to 1, --failfast is passed to the unittest framework to return upon 2421 first failure. 2422 2423 Args: 2424 argv: the argv to mutate in-place. 2425 """ 2426 2427 if argv is None: 2428 return 2429 2430 if os.environ.get('TESTBRIDGE_TEST_RUNNER_FAIL_FAST') != '1': 2431 return 2432 2433 argv[1:1] = ['--failfast'] 2434 2435 2436def _setup_sharding( 2437 custom_loader: Optional[unittest.TestLoader] = None, 2438) -> Tuple[unittest.TestLoader, Optional[int]]: 2439 """Implements the bazel sharding protocol. 2440 2441 The following environment variables are used in this method: 2442 2443 TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank 2444 file to tell the test runner that this test implements the test sharding 2445 protocol. 2446 2447 TEST_TOTAL_SHARDS: int, if set, sharding is requested. 2448 2449 TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies 2450 the shard index for this instance of the test process. Must satisfy: 2451 0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS. 2452 2453 Args: 2454 custom_loader: A TestLoader to be made sharded. 2455 2456 Returns: 2457 A tuple of ``(test_loader, shard_index)``. ``test_loader`` is for 2458 shard-filtering or the standard test loader depending on the sharding 2459 environment variables. ``shard_index`` is the shard index, or ``None`` when 2460 sharding is not used. 2461 """ 2462 2463 # It may be useful to write the shard file even if the other sharding 2464 # environment variables are not set. Test runners may use this functionality 2465 # to query whether a test binary implements the test sharding protocol. 2466 if 'TEST_SHARD_STATUS_FILE' in os.environ: 2467 try: 2468 with open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') as f: 2469 f.write('') 2470 except IOError: 2471 sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.' 2472 % os.environ['TEST_SHARD_STATUS_FILE']) 2473 sys.exit(1) 2474 2475 base_loader = custom_loader or TestLoader() 2476 if 'TEST_TOTAL_SHARDS' not in os.environ: 2477 # Not using sharding, use the expected test loader. 2478 return base_loader, None 2479 2480 total_shards = int(os.environ['TEST_TOTAL_SHARDS']) 2481 shard_index = int(os.environ['TEST_SHARD_INDEX']) 2482 2483 if shard_index < 0 or shard_index >= total_shards: 2484 sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' % 2485 (shard_index, total_shards)) 2486 sys.exit(1) 2487 2488 # Replace the original getTestCaseNames with one that returns 2489 # the test case names for this shard. 2490 delegate_get_names = base_loader.getTestCaseNames 2491 2492 bucket_iterator = itertools.cycle(range(total_shards)) 2493 2494 def getShardedTestCaseNames(testCaseClass): 2495 filtered_names = [] 2496 # We need to sort the list of tests in order to determine which tests this 2497 # shard is responsible for; however, it's important to preserve the order 2498 # returned by the base loader, e.g. in the case of randomized test ordering. 2499 ordered_names = delegate_get_names(testCaseClass) 2500 for testcase in sorted(ordered_names): 2501 bucket = next(bucket_iterator) 2502 if bucket == shard_index: 2503 filtered_names.append(testcase) 2504 return [x for x in ordered_names if x in filtered_names] 2505 2506 base_loader.getTestCaseNames = getShardedTestCaseNames 2507 return base_loader, shard_index 2508 2509 2510def _run_and_get_tests_result( 2511 argv: MutableSequence[str], 2512 args: Sequence[Any], 2513 kwargs: MutableMapping[str, Any], 2514 xml_test_runner_class: Type[unittest.TextTestRunner], 2515) -> Tuple[unittest.TestResult, bool]: 2516 """Same as run_tests, but it doesn't exit. 2517 2518 Args: 2519 argv: sys.argv with the command-line flags removed from the front, i.e. the 2520 argv with which :func:`app.run()<absl.app.run>` has called 2521 ``__main__.main``. It is passed to 2522 ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing. 2523 It is ignored if kwargs contains an argv entry. 2524 args: Positional arguments passed through to 2525 ``unittest.TestProgram.__init__``. 2526 kwargs: Keyword arguments passed through to 2527 ``unittest.TestProgram.__init__``. 2528 xml_test_runner_class: The type of the test runner class. 2529 2530 Returns: 2531 A tuple of ``(test_result, fail_when_no_tests_ran)``. 2532 ``fail_when_no_tests_ran`` indicates whether the test should fail when 2533 no tests ran. 2534 """ 2535 2536 # The entry from kwargs overrides argv. 2537 argv = kwargs.pop('argv', argv) 2538 2539 if sys.version_info[:2] >= (3, 12): 2540 # Python 3.12 unittest changed the behavior from PASS to FAIL in 2541 # https://github.com/python/cpython/pull/102051. absltest follows this. 2542 fail_when_no_tests_ran = True 2543 else: 2544 # Historically, absltest and unittest before Python 3.12 passes if no tests 2545 # ran. 2546 fail_when_no_tests_ran = False 2547 2548 # Set up test filtering if requested in environment. 2549 if _setup_filtering(argv): 2550 # When test filtering is requested, ideally we also want to fail when no 2551 # tests ran. However, the test filters are usually done when running bazel. 2552 # When you run multiple targets, e.g. `bazel test //my_dir/... 2553 # --test_filter=MyTest`, you don't necessarily want individual tests to fail 2554 # because no tests match in that particular target. 2555 # Due to this use case, we don't fail when test filtering is requested via 2556 # the environment variable from bazel. 2557 fail_when_no_tests_ran = False 2558 2559 # Set up --failfast as requested in environment 2560 _setup_test_runner_fail_fast(argv) 2561 2562 # Shard the (default or custom) loader if sharding is turned on. 2563 kwargs['testLoader'], shard_index = _setup_sharding( 2564 kwargs.get('testLoader', None) 2565 ) 2566 if shard_index is not None and shard_index > 0: 2567 # When sharding is requested, all the shards except the first one shall not 2568 # fail when no tests ran. This happens when the shard count is greater than 2569 # the test case count. 2570 fail_when_no_tests_ran = False 2571 2572 # XML file name is based upon (sorted by priority): 2573 # --xml_output_file flag, XML_OUTPUT_FILE variable, 2574 # TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable. 2575 if not FLAGS.xml_output_file: 2576 FLAGS.xml_output_file = get_default_xml_output_filename() 2577 xml_output_file = FLAGS.xml_output_file 2578 2579 xml_buffer = None 2580 if xml_output_file: 2581 xml_output_dir = os.path.dirname(xml_output_file) 2582 if xml_output_dir and not os.path.isdir(xml_output_dir): 2583 try: 2584 os.makedirs(xml_output_dir) 2585 except OSError as e: 2586 # File exists error can occur with concurrent tests 2587 if e.errno != errno.EEXIST: 2588 raise 2589 # Fail early if we can't write to the XML output file. This is so that we 2590 # don't waste people's time running tests that will just fail anyways. 2591 with _open(xml_output_file, 'w'): 2592 pass 2593 2594 # We can reuse testRunner if it supports XML output (e. g. by inheriting 2595 # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use 2596 # xml_reporter.TextAndXMLTestRunner. 2597 if (kwargs.get('testRunner') is not None 2598 and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')): 2599 sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting ' 2600 'overrides testRunner=%r setting (possibly from --pdb)' 2601 % (kwargs['testRunner'])) 2602 # Passing a class object here allows TestProgram to initialize 2603 # instances based on its kwargs and/or parsed command-line args. 2604 kwargs['testRunner'] = xml_test_runner_class 2605 if kwargs.get('testRunner') is None: 2606 kwargs['testRunner'] = xml_test_runner_class 2607 # Use an in-memory buffer (not backed by the actual file) to store the XML 2608 # report, because some tools modify the file (e.g., create a placeholder 2609 # with partial information, in case the test process crashes). 2610 xml_buffer = io.StringIO() 2611 kwargs['testRunner'].set_default_xml_stream(xml_buffer) # pytype: disable=attribute-error 2612 2613 # If we've used a seed to randomize test case ordering, we want to record it 2614 # as a top-level attribute in the `testsuites` section of the XML output. 2615 randomize_ordering_seed = getattr( 2616 kwargs['testLoader'], '_randomize_ordering_seed', None) 2617 setter = getattr(kwargs['testRunner'], 'set_testsuites_property', None) 2618 if randomize_ordering_seed and setter: 2619 setter('test_randomize_ordering_seed', randomize_ordering_seed) 2620 elif kwargs.get('testRunner') is None: 2621 kwargs['testRunner'] = _pretty_print_reporter.TextTestRunner 2622 2623 if FLAGS.pdb_post_mortem: 2624 runner = kwargs['testRunner'] 2625 # testRunner can be a class or an instance, which must be tested for 2626 # differently. 2627 # Overriding testRunner isn't uncommon, so only enable the debugging 2628 # integration if the runner claims it does; we don't want to accidentally 2629 # clobber something on the runner. 2630 if ((isinstance(runner, type) and 2631 issubclass(runner, _pretty_print_reporter.TextTestRunner)) or 2632 isinstance(runner, _pretty_print_reporter.TextTestRunner)): 2633 runner.run_for_debugging = True 2634 2635 # Make sure tmpdir exists. 2636 if not os.path.isdir(TEST_TMPDIR.value): 2637 try: 2638 os.makedirs(TEST_TMPDIR.value) 2639 except OSError as e: 2640 # Concurrent test might have created the directory. 2641 if e.errno != errno.EEXIST: 2642 raise 2643 2644 # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v', 2645 # on argv, which is sys.argv without the command-line flags. 2646 kwargs['argv'] = argv 2647 2648 # Request unittest.TestProgram to not exit. The exit will be handled by 2649 # `absltest.run_tests`. 2650 kwargs['exit'] = False 2651 2652 try: 2653 test_program = unittest.TestProgram(*args, **kwargs) 2654 return test_program.result, fail_when_no_tests_ran 2655 finally: 2656 if xml_buffer: 2657 try: 2658 with _open(xml_output_file, 'w') as f: 2659 f.write(xml_buffer.getvalue()) 2660 finally: 2661 xml_buffer.close() 2662 2663 2664def run_tests( 2665 argv: MutableSequence[Text], 2666 args: Sequence[Any], 2667 kwargs: MutableMapping[Text, Any], 2668) -> None: 2669 """Executes a set of Python unit tests. 2670 2671 Most users should call absltest.main() instead of run_tests. 2672 2673 Please note that run_tests should be called from app.run. 2674 Calling absltest.main() would ensure that. 2675 2676 Please note that run_tests is allowed to make changes to kwargs. 2677 2678 Args: 2679 argv: sys.argv with the command-line flags removed from the front, i.e. the 2680 argv with which :func:`app.run()<absl.app.run>` has called 2681 ``__main__.main``. It is passed to 2682 ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing. 2683 It is ignored if kwargs contains an argv entry. 2684 args: Positional arguments passed through to 2685 ``unittest.TestProgram.__init__``. 2686 kwargs: Keyword arguments passed through to 2687 ``unittest.TestProgram.__init__``. 2688 """ 2689 result, fail_when_no_tests_ran = _run_and_get_tests_result( 2690 argv, args, kwargs, xml_reporter.TextAndXMLTestRunner 2691 ) 2692 if fail_when_no_tests_ran and result.testsRun == 0 and not result.skipped: 2693 # Python 3.12 unittest exits with 5 when no tests ran. The exit code 5 comes 2694 # from pytest which does the same thing. 2695 sys.exit(5) 2696 sys.exit(not result.wasSuccessful()) 2697 2698 2699def _rmtree_ignore_errors(path): 2700 # type: (Text) -> None 2701 if os.path.isfile(path): 2702 try: 2703 os.unlink(path) 2704 except OSError: 2705 pass 2706 else: 2707 shutil.rmtree(path, ignore_errors=True) 2708 2709 2710def _get_first_part(path): 2711 # type: (Text) -> Text 2712 parts = path.split(os.sep, 1) 2713 return parts[0] 2714