1"""Test suite for statistics module, including helper NumericTestCase and
2approx_equal function.
3
4"""
5
6import bisect
7import collections
8import collections.abc
9import copy
10import decimal
11import doctest
12import itertools
13import math
14import pickle
15import random
16import sys
17import unittest
18from test import support
19from test.support import import_helper, requires_IEEE_754
20
21from decimal import Decimal
22from fractions import Fraction
23
24
25# Module to be tested.
26import statistics
27
28
29# === Helper functions and class ===
30
def sign(x):
    """Return the sign of x as a float: -1.0 or +1.0.

    Only the sign of x is consulted, so -0.0 gives -1.0 while 0.0
    gives +1.0.
    """
    unit = 1
    return math.copysign(unit, x)
34
35def _nan_equal(a, b):
36    """Return True if a and b are both the same kind of NAN.
37
38    >>> _nan_equal(Decimal('NAN'), Decimal('NAN'))
39    True
40    >>> _nan_equal(Decimal('sNAN'), Decimal('sNAN'))
41    True
42    >>> _nan_equal(Decimal('NAN'), Decimal('sNAN'))
43    False
44    >>> _nan_equal(Decimal(42), Decimal('NAN'))
45    False
46
47    >>> _nan_equal(float('NAN'), float('NAN'))
48    True
49    >>> _nan_equal(float('NAN'), 0.5)
50    False
51
52    >>> _nan_equal(float('NAN'), Decimal('NAN'))
53    False
54
55    NAN payloads are not compared.
56    """
57    if type(a) is not type(b):
58        return False
59    if isinstance(a, float):
60        return math.isnan(a) and math.isnan(b)
61    aexp = a.as_tuple()[2]
62    bexp = b.as_tuple()[2]
63    return (aexp == bexp) and (aexp in ('n', 'N'))  # Both NAN or both sNAN.
64
65
66def _calc_errors(actual, expected):
67    """Return the absolute and relative errors between two numbers.
68
69    >>> _calc_errors(100, 75)
70    (25, 0.25)
71    >>> _calc_errors(100, 100)
72    (0, 0.0)
73
74    Returns the (absolute error, relative error) between the two arguments.
75    """
76    base = max(abs(actual), abs(expected))
77    abs_err = abs(actual - expected)
78    rel_err = abs_err/base if base else float('inf')
79    return (abs_err, rel_err)
80
81
def approx_equal(x, y, tol=1e-12, rel=1e-7):
    """approx_equal(x, y [, tol [, rel]]) => True|False

    Decide whether the numbers x and y are equal to within a margin of
    error, returning True if so and False otherwise. Numbers which compare
    equal always compare approximately equal as well.

    x and y are approximately equal when the difference between them is no
    more than the absolute error tol or the relative error rel, whichever
    allows the larger difference.

    Both tol and rel, if supplied, must be finite, non-negative numbers.
    The defaults are tol=1e-12 and rel=1e-7.

    >>> approx_equal(1.2589, 1.2587, tol=0.0003, rel=0)
    True
    >>> approx_equal(1.2589, 1.2587, tol=0.0001, rel=0)
    False

    The absolute error is abs(x-y); x and y are approximately equal if it
    is less than or equal to tol.

    The relative error is the smaller of abs((x-y)/x) and abs((x-y)/y),
    provided x or y are not zero; x and y are approximately equal if it is
    less than or equal to rel.

    Complex numbers are not directly supported. To compare two complex
    numbers, extract their real and imaginary parts and compare those
    individually.

    NANs always compare unequal, even with themselves. Two infinities
    compare approximately equal only when they have the same sign; an
    infinity never compares equal to an infinity of the opposite sign or
    to a finite number.

    Raises ValueError if tol or rel is negative.
    """
    if tol < 0 or rel < 0:
        raise ValueError('error tolerances must be non-negative')
    # A NAN is never equal to anything, approximately or otherwise.
    if any(map(math.isnan, (x, y))):
        return False
    # Exact equality, which also covers two infinities of the same sign.
    if x == y:
        return True
    # Any infinity left is paired with its opposite or with a finite number.
    if any(map(math.isinf, (x, y))):
        return False
    # From here on both values are finite.
    gap = abs(x - y)
    magnitude = max(abs(x), abs(y))
    return gap <= max(tol, rel*magnitude)
133
134
# This class exists only as somewhere to stick a docstring containing
# doctests. The following docstring and tests were originally in a separate
# module. Now that it has been merged in here, I need somewhere to hang the
# docstring. Ultimately, this class will die, and the information below will
# either become redundant, or be moved into more appropriate places.
140class _DoNothing:
141    """
142    When doing numeric work, especially with floats, exact equality is often
143    not what you want. Due to round-off error, it is often a bad idea to try
144    to compare floats with equality. Instead the usual procedure is to test
145    them with some (hopefully small!) allowance for error.
146
147    The ``approx_equal`` function allows you to specify either an absolute
148    error tolerance, or a relative error, or both.
149
150    Absolute error tolerances are simple, but you need to know the magnitude
151    of the quantities being compared:
152
153    >>> approx_equal(12.345, 12.346, tol=1e-3)
154    True
155    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3)  # tol is too small.
156    False
157
158    Relative errors are more suitable when the values you are comparing can
159    vary in magnitude:
160
161    >>> approx_equal(12.345, 12.346, rel=1e-4)
162    True
163    >>> approx_equal(12.345e6, 12.346e6, rel=1e-4)
164    True
165
166    but a naive implementation of relative error testing can run into trouble
167    around zero.
168
169    If you supply both an absolute tolerance and a relative error, the
170    comparison succeeds if either individual test succeeds:
171
172    >>> approx_equal(12.345e6, 12.346e6, tol=1e-3, rel=1e-4)
173    True
174
175    """
176    pass
177
178
179
# We prefer NumericTestCase.assertApproxEqual (defined below) for testing
# numeric values that may not be exactly equal, and avoid using
# TestCase.assertAlmostEqual, because it sucks :-)
182
# Import two separate copies of the statistics module for TestModules below:
# py_statistics is imported with '_statistics' passed as a blocked module,
# c_statistics with '_statistics' passed as a module to import fresh.
# NOTE(review): c_statistics is presumably None when '_statistics' cannot be
# imported -- TestModules.test_c_functions is skipped when it is falsy.
# Confirm against the import_fresh_module documentation.
py_statistics = import_helper.import_fresh_module('statistics',
                                                  blocked=['_statistics'])
c_statistics = import_helper.import_fresh_module('statistics',
                                                 fresh=['_statistics'])
187
188
class TestModules(unittest.TestCase):
    # Check which module supplies each function named in func_names, for
    # both the py_statistics and the c_statistics copy of the module.
    func_names = ['_normal_dist_inv_cdf']

    def test_py_functions(self):
        # In py_statistics the functions must be defined by 'statistics'.
        for name in self.func_names:
            func = getattr(py_statistics, name)
            self.assertEqual(func.__module__, 'statistics')

    @unittest.skipUnless(c_statistics, 'requires _statistics')
    def test_c_functions(self):
        # In c_statistics the functions must be defined by '_statistics'.
        for name in self.func_names:
            func = getattr(c_statistics, name)
            self.assertEqual(func.__module__, '_statistics')
200
201
class NumericTestCase(unittest.TestCase):
    """TestCase subclass for numeric work.

    Besides everything inherited from TestCase, including the standard
    ``TestCase.assertAlmostEqual``, it provides ``assertApproxEqual``.
    """
    # Exact equality is expected unless a subclass or a call overrides these.
    tol = 0
    rel = 0

    def assertApproxEqual(
            self, first, second, tol=None, rel=None, msg=None
            ):
        """Test passes if ``first`` and ``second`` are approximately equal.

        The test passes when ``first`` and ``second`` agree to within
        ``tol``, an absolute error, or ``rel``, a relative error.

        A ``tol`` or ``rel`` that is None or not given falls back to the
        test attribute of the same name (by default, 0).

        The objects compared may be numbers or sequences of numbers.
        Sequences are compared element-by-element.

        >>> class MyTest(NumericTestCase):
        ...     def test_number(self):
        ...         x = 1.0/6
        ...         y = sum([x]*6)
        ...         self.assertApproxEqual(y, 1.0, tol=1e-15)
        ...     def test_sequence(self):
        ...         a = [1.001, 1.001e-10, 1.001e10]
        ...         b = [1.0, 1e-10, 1e10]
        ...         self.assertApproxEqual(a, b, rel=1e-3)
        ...
        >>> import unittest
        >>> from io import StringIO  # Suppress test runner output.
        >>> suite = unittest.TestLoader().loadTestsFromTestCase(MyTest)
        >>> unittest.TextTestRunner(stream=StringIO()).run(suite)
        <unittest.runner.TextTestResult run=2 errors=0 failures=0>

        """
        tol = self.tol if tol is None else tol
        rel = self.rel if rel is None else rel
        both_sequences = (
            isinstance(first, collections.abc.Sequence)
            and isinstance(second, collections.abc.Sequence)
        )
        if both_sequences:
            self._check_approx_seq(first, second, tol, rel, msg)
        else:
            self._check_approx_num(first, second, tol, rel, msg)

    def _check_approx_seq(self, first, second, tol, rel, msg):
        # Compare two sequences: the lengths must match, then each pair of
        # items is compared in order, stopping at the first failure.
        n_first, n_second = len(first), len(second)
        if n_first != n_second:
            standardMsg = (
                "sequences differ in length: %d items != %d items"
                % (n_first, n_second)
                )
            raise self.failureException(self._formatMessage(msg, standardMsg))
        for position, (actual, expected) in enumerate(zip(first, second)):
            self._check_approx_num(actual, expected, tol, rel, msg, position)

    def _check_approx_num(self, first, second, tol, rel, msg, idx=None):
        # Compare two numbers; idx is the position within a sequence, or
        # None when the numbers were not taken from sequences.
        if not approx_equal(first, second, tol, rel):
            standardMsg = self._make_std_err_msg(first, second, tol, rel, idx)
            raise self.failureException(self._formatMessage(msg, standardMsg))
        return None

    @staticmethod
    def _make_std_err_msg(first, second, tol, rel, idx):
        # Build the standard error message for approx_equal failures.
        assert first != second
        parts = [
            '  %r != %r',
            '  values differ by more than tol=%r and rel=%r',
            '  -> absolute error = %r',
            '  -> relative error = %r',
            ]
        if idx is not None:
            # Lead with the position when the values came from a sequence.
            parts.insert(0, 'numeric sequences first differ at index %d.' % idx)
        template = '\n'.join(parts)
        # Work out the actual errors to report.
        abs_err, rel_err = _calc_errors(first, second)
        return template % (first, second, tol, rel, abs_err, rel_err)
291
292
293# ========================
294# === Test the helpers ===
295# ========================
296
class TestSign(unittest.TestCase):
    """Check that the sign() helper function behaves correctly."""
    def testZeroes(self):
        # Positive and negative zero must each report their own sign.
        positive_zero, negative_zero = 0.0, -0.0
        self.assertEqual(sign(positive_zero), +1)
        self.assertEqual(sign(negative_zero), -1)
303
304
305# --- Tests for approx_equal ---
306
class ApproxEqualSymmetryTest(unittest.TestCase):
    # Test symmetry of approx_equal: swapping the two values being compared
    # must never change the result.

    def test_relative_symmetry(self):
        # Check that approx_equal treats relative error symmetrically.
        # (a-b)/a is usually not equal to (a-b)/b. Ensure that this
        # doesn't matter.
        #
        #   Note: the reason for this test is that an early version
        #   of approx_equal was not symmetric. A relative error test
        #   would pass, or fail, depending on which value was passed
        #   as the first argument.
        #
        args1 = [2456, 37.8, -12.45, Decimal('2.54'), Fraction(17, 54)]
        args2 = [2459, 37.2, -12.41, Decimal('2.59'), Fraction(15, 54)]
        assert len(args1) == len(args2)
        for a, b in zip(args1, args2):
            self.do_relative_symmetry(a, b)

    def do_relative_symmetry(self, a, b):
        """Check that unequal a and b compare approx equal in either order.

        The relative error allowed lies halfway between the relative error
        measured against a and the one measured against b.
        """
        a, b = min(a, b), max(a, b)
        assert a < b
        delta = b - a  # The absolute difference between the values.
        rel_err1, rel_err2 = abs(delta/a), abs(delta/b)
        # Choose an error margin halfway between the two.
        rel = (rel_err1 + rel_err2)/2
        # Now see that values a and b compare approx equal regardless of
        # which is given first.
        self.assertTrue(approx_equal(a, b, tol=0, rel=rel))
        self.assertTrue(approx_equal(b, a, tol=0, rel=rel))

    def test_symmetry(self):
        # Test that approx_equal(a, b) == approx_equal(b, a)
        args = [-23, -2, 5, 107, 93568]
        delta = 2
        for a in args:
            for type_ in (int, float, Decimal, Fraction):
                x = type_(a)*100
                y = x + delta
                r = abs(delta/max(x, y))
                # There are five cases to check:
                # 1) actual error <= tol, <= rel
                self.do_symmetry_test(x, y, tol=delta, rel=r)
                self.do_symmetry_test(x, y, tol=delta+1, rel=2*r)
                # 2) actual error > tol, > rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r/2)
                # 3) actual error <= tol, > rel
                self.do_symmetry_test(x, y, tol=delta, rel=r/2)
                # 4) actual error > tol, <= rel
                self.do_symmetry_test(x, y, tol=delta-1, rel=r)
                self.do_symmetry_test(x, y, tol=delta-1, rel=2*r)
                # 5) exact equality test
                self.do_symmetry_test(x, x, tol=0, rel=0)
                self.do_symmetry_test(x, y, tol=0, rel=0)

    def do_symmetry_test(self, a, b, tol, rel):
        """Check that comparing (a, b) and (b, a) gives the same result.

        On failure the message shows the (a, b, tol, rel) tuple that was
        being tested.
        """
        # The template is filled in with str.format, so it needs a {!r}
        # replacement field. A printf-style %r would be left in the message
        # as literal text and the failing values would never be shown.
        template = "approx_equal comparisons don't match for {!r}"
        flag1 = approx_equal(a, b, tol, rel)
        flag2 = approx_equal(b, a, tol, rel)
        self.assertEqual(flag1, flag2, template.format((a, b, tol, rel)))
367
368
class ApproxEqualExactTest(unittest.TestCase):
    # Test approx_equal with exactly equal values. A value must compare
    # approximately equal to itself whatever error tolerances are given.

    def do_exactly_equal_test(self, x, tol, rel):
        # Check both x and its negation against themselves.
        for value in (x, -x):
            outcome = approx_equal(value, value, tol=tol, rel=rel)
            self.assertTrue(outcome, 'equality failure for x=%r' % value)

    def test_exactly_equal_ints(self):
        # Equal ints are exactly equal.
        numbers = [42, 19740, 14974, 230, 1795, 700245, 36587]
        for number in numbers:
            self.do_exactly_equal_test(number, 0, 0)

    def test_exactly_equal_floats(self):
        # Equal floats are exactly equal.
        numbers = [0.42, 1.9740, 1497.4, 23.0, 179.5, 70.0245, 36.587]
        for number in numbers:
            self.do_exactly_equal_test(number, 0, 0)

    def test_exactly_equal_fractions(self):
        # Equal Fractions are exactly equal.
        pairs = [(1, 2), (0, 1), (5, 3), (9, 7), (35, 36), (3, 7)]
        for numerator, denominator in pairs:
            self.do_exactly_equal_test(Fraction(numerator, denominator), 0, 0)

    def test_exactly_equal_decimals(self):
        # Equal Decimals are exactly equal.
        for text in "8.2 31.274 912.04 16.745 1.2047".split():
            self.do_exactly_equal_test(Decimal(text), 0, 0)

    def test_exactly_equal_absolute(self):
        # Equal values are exactly equal with an absolute error.
        for n in [16, 1013, 1372, 1198, 971, 4]:
            # Test n as an int, as a float and as a Fraction, in that order.
            for value in (n, n/10, Fraction(n, 1234)):
                self.do_exactly_equal_test(value, 0.01, 0)

    def test_exactly_equal_absolute_decimals(self):
        # Equal Decimals are exactly equal with an absolute error.
        tolerance = Decimal("0.01")
        self.do_exactly_equal_test(Decimal("3.571"), tolerance, 0)
        self.do_exactly_equal_test(-Decimal("81.3971"), tolerance, 0)

    def test_exactly_equal_relative(self):
        # Equal values are exactly equal with a relative error.
        for value in [8347, 101.3, -7910.28, Fraction(5, 21)]:
            self.do_exactly_equal_test(value, 0, 0.01)
        self.do_exactly_equal_test(Decimal("11.68"), 0, Decimal("0.01"))

    def test_exactly_equal_both(self):
        # Equal values are equal when both tol and rel are given.
        for value in [41017, 16.742, -813.02, Fraction(3, 8)]:
            self.do_exactly_equal_test(value, 0.1, 0.01)
        self.do_exactly_equal_test(
            Decimal("7.2"), Decimal("0.1"), Decimal("0.01")
            )
431
432
class ApproxEqualUnequalTest(unittest.TestCase):
    # Test approx_equal with unequal values and zero error tolerances,
    # which amounts to an exact equality test: they must compare unequal.

    def do_exactly_unequal_test(self, x):
        # Neither x nor -x may match the value one greater than itself.
        for value in (x, -x):
            matched = approx_equal(value, value+1, tol=0, rel=0)
            self.assertFalse(matched, 'inequality failure for x=%r' % value)

    def test_exactly_unequal_ints(self):
        # Unequal ints are unequal with zero error tolerance.
        for number in [951, 572305, 478, 917, 17240]:
            self.do_exactly_unequal_test(number)

    def test_exactly_unequal_floats(self):
        # Unequal floats are unequal with zero error tolerance.
        for number in [9.51, 5723.05, 47.8, 9.17, 17.24]:
            self.do_exactly_unequal_test(number)

    def test_exactly_unequal_fractions(self):
        # Unequal Fractions are unequal with zero error tolerance.
        pairs = [(1, 5), (7, 9), (12, 11), (101, 99023)]
        for numerator, denominator in pairs:
            self.do_exactly_unequal_test(Fraction(numerator, denominator))

    def test_exactly_unequal_decimals(self):
        # Unequal Decimals are unequal with zero error tolerance.
        for text in "3.1415 298.12 3.47 18.996 0.00245".split():
            self.do_exactly_unequal_test(Decimal(text))
462
463
class ApproxEqualInexactTest(unittest.TestCase):
    # Inexact test cases for approx_equal: the two values compared are not
    # exactly equal, so the outcome depends on the tolerances supplied.

    # === Absolute error tests ===

    def do_approx_equal_abs_test(self, x, delta):
        # x is compared with the values delta above and delta below it. An
        # absolute tolerance of twice delta must accept them, and one of
        # half delta must reject them.
        wide, narrow = 2*delta, delta/2
        for y in (x + delta, x - delta):
            msg = "Test failure for x={!r}, y={!r}".format(x, y)
            self.assertTrue(approx_equal(x, y, tol=wide, rel=0), msg)
            self.assertFalse(approx_equal(x, y, tol=narrow, rel=0), msg)

    def test_approx_equal_absolute_ints(self):
        # Approximate equality of ints with an absolute error.
        for n in [-10737, -1975, -7, -2, 0, 1, 9, 37, 423, 9874, 23789110]:
            for delta in (10, 2):
                self.do_approx_equal_abs_test(n, delta)

    def test_approx_equal_absolute_floats(self):
        # Approximate equality of floats with an absolute error.
        for x in [-284.126, -97.1, -3.4, -2.15, 0.5, 1.0, 7.8, 4.23, 3817.4]:
            for delta in (1.5, 0.01, 0.0001):
                self.do_approx_equal_abs_test(x, delta)

    def test_approx_equal_absolute_fractions(self):
        # Approximate equality of Fractions with an absolute error.
        delta = Fraction(1, 29)
        for numerator in [-84, -15, -2, -1, 0, 1, 5, 17, 23, 34, 71]:
            f = Fraction(numerator, 29)
            self.do_approx_equal_abs_test(f, delta)
            self.do_approx_equal_abs_test(f, float(delta))

    def test_approx_equal_absolute_decimals(self):
        # Approximate equality of Decimals with an absolute error.
        delta = Decimal("0.01")
        for text in "1.0 3.5 36.08 61.79 7912.3648".split():
            d = Decimal(text)
            self.do_approx_equal_abs_test(d, delta)
            self.do_approx_equal_abs_test(-d, delta)

    def test_cross_zero(self):
        # The two values may have opposite signs.
        self.assertTrue(approx_equal(1e-5, -1e-5, tol=1e-4, rel=0))

    # === Relative error tests ===

    def do_approx_equal_rel_test(self, x, delta):
        # x is compared with x scaled up and scaled down by the fraction
        # delta. A relative error of twice delta must accept them, and one
        # of half delta must reject them.
        wide, narrow = 2*delta, delta/2
        for y in (x*(1+delta), x*(1-delta)):
            msg = "Test failure for x={!r}, y={!r}".format(x, y)
            self.assertTrue(approx_equal(x, y, tol=0, rel=wide), msg)
            self.assertFalse(approx_equal(x, y, tol=0, rel=narrow), msg)

    def test_approx_equal_relative_ints(self):
        # Approximate equality of ints with a relative error.
        for rel in (0.36, 0.37):
            self.assertTrue(approx_equal(64, 47, tol=0, rel=rel))
        # ---
        for n in (449, 448):
            self.assertTrue(approx_equal(n, 512, tol=0, rel=0.125))
        self.assertFalse(approx_equal(447, 512, tol=0, rel=0.125))

    def test_approx_equal_relative_floats(self):
        # Approximate equality of floats with a relative error.
        for x in [-178.34, -0.1, 0.1, 1.0, 36.97, 2847.136, 9145.074]:
            for delta in (0.02, 0.0001):
                self.do_approx_equal_rel_test(x, delta)

    def test_approx_equal_relative_fractions(self):
        # Approximate equality of Fractions with a relative error.
        delta = Fraction(3, 8)
        pairs = [(3, 84), (17, 30), (49, 50), (92, 85)]
        for numerator, denominator in pairs:
            f = Fraction(numerator, denominator)
            for d in (delta, float(delta)):
                self.do_approx_equal_rel_test(f, d)
                self.do_approx_equal_rel_test(-f, d)

    def test_approx_equal_relative_decimals(self):
        # Approximate equality of Decimals with a relative error.
        for text in "0.02 1.0 5.7 13.67 94.138 91027.9321".split():
            d = Decimal(text)
            self.do_approx_equal_rel_test(d, Decimal("0.001"))
            self.do_approx_equal_rel_test(-d, Decimal("0.05"))

    # === Both absolute and relative error tests ===

    # There are four cases to consider:
    #   1) actual error <= both absolute and relative error
    #   2) actual error <= absolute error but > relative error
    #   3) actual error <= relative error but > absolute error
    #   4) actual error > both absolute and relative error

    def do_check_both(self, a, b, tol, rel, tol_flag, rel_flag):
        # tol_flag and rel_flag say whether the comparison should succeed
        # using tol alone and using rel alone. With both tolerances given
        # it should succeed if either one would.
        cases = (
            (tol, 0, tol_flag),
            (0, rel, rel_flag),
            (tol, rel, tol_flag or rel_flag),
            )
        for tol_used, rel_used, should_match in cases:
            verify = self.assertTrue if should_match else self.assertFalse
            verify(approx_equal(a, b, tol=tol_used, rel=rel_used))

    def test_approx_equal_both1(self):
        # Actual error <= both absolute and relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.8e-4, True, True)
        self.do_check_both(-7.387, -7.386, 0.002, 0.0002, True, True)

    def test_approx_equal_both2(self):
        # Actual error <= absolute error but > relative error.
        self.do_check_both(7.955, 7.952, 0.004, 3.7e-4, True, False)

    def test_approx_equal_both3(self):
        # Actual error <= relative error but > absolute error.
        self.do_check_both(7.955, 7.952, 0.001, 3.8e-4, False, True)

    def test_approx_equal_both4(self):
        # Actual error > both absolute and relative error.
        self.do_check_both(2.78, 2.75, 0.01, 0.001, False, False)
        self.do_check_both(971.44, 971.47, 0.02, 3e-5, False, False)
581
582
class ApproxEqualSpecialsTest(unittest.TestCase):
    # Test approx_equal with the special values: INFs, NANs and zeroes.

    def test_inf(self):
        # An infinity matches only an infinity of the same sign, whatever
        # tolerances are given.
        for number_type in (float, Decimal):
            inf = number_type('inf')
            for tolerances in ((), (0, 0), (1, 0.01)):
                self.assertTrue(approx_equal(inf, inf, *tolerances))
            self.assertTrue(approx_equal(-inf, -inf))
            self.assertFalse(approx_equal(inf, -inf))
            self.assertFalse(approx_equal(inf, 1000))

    def test_nan(self):
        # A NAN matches nothing: not itself, not an INF, not a finite number.
        for number_type in (float, Decimal):
            nan = number_type('nan')
            others = (nan, number_type('inf'), 1000)
            for other in others:
                self.assertFalse(approx_equal(nan, other))

    def test_float_zeroes(self):
        # Negative and positive float zero are approximately equal.
        negative_zero = math.copysign(0.0, -1)
        self.assertTrue(approx_equal(negative_zero, 0.0, tol=0.1, rel=0.1))

    def test_decimal_zeroes(self):
        # Negative and positive Decimal zero are approximately equal.
        negative_zero = Decimal("-0.0")
        self.assertTrue(
            approx_equal(negative_zero, Decimal(0), tol=0.1, rel=0.1)
            )
609
610
class TestApproxEqualErrors(unittest.TestCase):
    # Test the error conditions of approx_equal.

    def test_bad_tol(self):
        # A negative tol must raise ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, -1, 0.1)

    def test_bad_rel(self):
        # A negative rel must raise ValueError.
        with self.assertRaises(ValueError):
            approx_equal(100, 100, 1, -0.1)
621
622
623# --- Tests for NumericTestCase ---
624
625# The formatting routine that generates the error messages is complex enough
626# that it too needs testing.
627
class TestNumericTestCase(unittest.TestCase):
    # The exact wording of NumericTestCase error messages is *not*
    # guaranteed, yet the messages still need some sort of test to show
    # that they are generated correctly. As a compromise, we only look for
    # specific substrings that should survive a change to the overall
    # wording.

    def do_test(self, args):
        # args holds the five arguments taken by _make_std_err_msg.
        message = NumericTestCase._make_std_err_msg(*args)
        for fragment in self.generate_substrings(*args):
            self.assertIn(fragment, message)

    def test_numerictestcase_is_testcase(self):
        # NumericTestCase must actually be a TestCase.
        self.assertTrue(issubclass(NumericTestCase, unittest.TestCase))

    def test_error_msg_numeric(self):
        # Error message generated for numeric comparisons (no index).
        self.do_test((2.5, 4.0, 0.5, 0.25, None))

    def test_error_msg_sequence(self):
        # Error message generated for sequence comparisons (with an index).
        self.do_test((3.75, 8.25, 1.25, 0.5, 7))

    def generate_substrings(self, first, second, tol, rel, idx):
        """Return the substrings an error message is expected to contain."""
        abs_err, rel_err = _calc_errors(first, second)
        labelled = (
            ('tol=%r', tol),
            ('rel=%r', rel),
            ('absolute error = %r', abs_err),
            ('relative error = %r', rel_err),
            )
        substrings = [template % value for template, value in labelled]
        if idx is not None:
            substrings.append('differ at index %d' % idx)
        return substrings
666
667
668# =======================================
669# === Tests for the statistics module ===
670# =======================================
671
672
class GlobalsTest(unittest.TestCase):
    # Checks on the module-level names of the module under test.
    module = statistics
    expected_metadata = ["__doc__", "__all__"]

    def test_meta(self):
        # Every expected piece of metadata must be present on the module.
        for attribute in self.expected_metadata:
            present = hasattr(self.module, attribute)
            self.assertTrue(present, "%s not present" % attribute)

    def test_check_all(self):
        # Each name listed in __all__ must be public and must exist.
        target = self.module
        for public_name in target.__all__:
            # __all__ may not list private names.
            is_private = public_name.startswith("_")
            self.assertFalse(is_private,
                             'private name "%s" in __all__' % public_name)
            # Whatever __all__ lists has to be an attribute of the module.
            exists = hasattr(target, public_name)
            self.assertTrue(exists,
                            'missing name "%s" in __all__' % public_name)
693
694
class DocTests(unittest.TestCase):
    # Run the doctests embedded in the statistics module.

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -OO and above")
    def test_doc_tests(self):
        # At least one example must have been tried, and none may fail.
        results = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
        self.assertGreater(results.attempted, 0)
        self.assertEqual(results.failed, 0)
702
class StatisticsErrorTest(unittest.TestCase):
    # Checks on the exception class exported by the statistics module.

    def test_has_exception(self):
        # StatisticsError must exist and must be a subclass of ValueError.
        self.assertTrue(hasattr(statistics, 'StatisticsError'))
        exception_class = statistics.StatisticsError
        errmsg = (
                "Expected StatisticsError to be a ValueError, but got a"
                " subclass of %r instead."
                )
        is_value_error = issubclass(exception_class, ValueError)
        self.assertTrue(is_value_error, errmsg % exception_class.__base__)
714
715
716# === Tests for private utility functions ===
717
class ExactRatioTest(unittest.TestCase):
    # Test the private _exact_ratio utility of the statistics module.

    def test_int(self):
        # An int n is expected back as the pair (n, 1).
        for n in (-20, -3, 0, 5, 99, 10**20):
            self.assertEqual(statistics._exact_ratio(n), (n, 1))

    def test_fraction(self):
        # A Fraction n/37 is expected back as the pair (n, 37).
        for numerator in (-5, 1, 12, 38):
            fraction = Fraction(numerator, 37)
            self.assertEqual(
                statistics._exact_ratio(fraction), (numerator, 37)
                )

    def test_float(self):
        # Two fixed floats, then random floats whose ratio must divide
        # back to the original value.
        self.assertEqual(statistics._exact_ratio(0.125), (1, 8))
        self.assertEqual(statistics._exact_ratio(1.125), (9, 8))
        samples = [random.uniform(-100, 100) for _ in range(100)]
        for sample in samples:
            numerator, denominator = statistics._exact_ratio(sample)
            self.assertEqual(sample, numerator/denominator)

    def test_decimal(self):
        # Finite Decimals are expected back as a ratio in lowest terms.
        exact_ratio = statistics._exact_ratio
        self.assertEqual(exact_ratio(Decimal("0.125")), (1, 8))
        self.assertEqual(exact_ratio(Decimal("12.345")), (2469, 200))
        self.assertEqual(exact_ratio(Decimal("-1.98")), (-99, 50))

    def test_inf(self):
        # An infinity is expected back as (value, None), with the type of
        # the value preserved, including subclasses.
        class MyFloat(float):
            pass
        class MyDecimal(Decimal):
            pass
        INF = float("INF")
        number_types = (float, MyFloat, Decimal, MyDecimal)
        for inf, number_type in itertools.product((INF, -INF), number_types):
            x = number_type(inf)
            ratio = statistics._exact_ratio(x)
            self.assertEqual(ratio, (x, None))
            self.assertEqual(type(ratio[0]), number_type)
            self.assertTrue(math.isinf(ratio[0]))

    def test_float_nan(self):
        # A float NAN is expected back as (NAN, None), type preserved.
        class MyFloat(float):
            pass
        NAN = float("NAN")
        for nan in (NAN, MyFloat(NAN)):
            numerator, denominator = statistics._exact_ratio(nan)
            self.assertTrue(math.isnan(numerator))
            self.assertIs(denominator, None)
            self.assertEqual(type(numerator), type(nan))

    def test_decimal_nan(self):
        # A Decimal NAN or sNAN is expected back as (the same kind of NAN,
        # None), type preserved.
        class MyDecimal(Decimal):
            pass
        NAN = Decimal("NAN")
        sNAN = Decimal("sNAN")
        for nan in (NAN, MyDecimal(NAN), sNAN, MyDecimal(sNAN)):
            numerator, denominator = statistics._exact_ratio(nan)
            self.assertTrue(_nan_equal(numerator, nan))
            self.assertIs(denominator, None)
            self.assertEqual(type(numerator), type(nan))
780
781
class DecimalToRatioTest(unittest.TestCase):
    """Decimal-specific checks for the private ``statistics._exact_ratio``."""

    def test_infinity(self):
        # A Decimal INF of either sign is expected back paired with None.
        positive = Decimal('INF')
        for inf in (positive, -positive):
            self.assertEqual(statistics._exact_ratio(inf), (inf, None))

    def test_nan(self):
        # Both quiet and signalling NANs must be handled.
        for nan in (Decimal('NAN'), Decimal('sNAN')):
            num, den = statistics._exact_ratio(nan)
            # assertEqual is useless here because NANs always compare
            # non-equal, and an identity test is out too since nothing is
            # guaranteed about object identity.
            self.assertTrue(_nan_equal(num, nan))
            self.assertIs(den, None)

    def test_sign(self):
        # The sign belongs on the numerator; the denominator stays positive.
        for d in (Decimal("9.8765e12"), Decimal("9.8765e-12")):
            assert d > 0  # Sanity check on the test data itself.
            # Positive input first ...
            num, den = statistics._exact_ratio(d)
            self.assertGreaterEqual(num, 0)
            self.assertGreater(den, 0)
            # ... then the same magnitude negated.
            num, den = statistics._exact_ratio(-d)
            self.assertLessEqual(num, 0)
            self.assertGreater(den, 0)

    def test_negative_exponent(self):
        # A Decimal whose exponent is negative.
        ratio = statistics._exact_ratio(Decimal("0.1234"))
        self.assertEqual(ratio, (617, 5000))

    def test_positive_exponent(self):
        # A Decimal whose exponent is positive.
        ratio = statistics._exact_ratio(Decimal("1.234e7"))
        self.assertEqual(ratio, (12340000, 1))

    def test_regression_20536(self):
        # Regression test for issue 20536.
        # See http://bugs.python.org/issue20536
        cases = (
            ("1e2", (100, 1)),
            ("1.47e5", (147000, 1)),
        )
        for text, expected in cases:
            ratio = statistics._exact_ratio(Decimal(text))
            self.assertEqual(ratio, expected)
832
833
class IsFiniteTest(unittest.TestCase):
    """Checks for the private ``statistics._isfinite`` predicate."""

    def test_finite(self):
        # Ordinary int, Fraction, float and Decimal values are finite.
        samples = (5, Fraction(1, 3), 2.5, Decimal("5.5"))
        for value in samples:
            self.assertTrue(statistics._isfinite(value))

    def test_infinity(self):
        # A float or Decimal INF must not be reported as finite.
        samples = (float("inf"), Decimal("inf"))
        for value in samples:
            self.assertFalse(statistics._isfinite(value))

    def test_nan(self):
        # NANs (float, quiet Decimal, signalling Decimal) are not finite.
        samples = (float("nan"), Decimal("NAN"), Decimal("sNAN"))
        for value in samples:
            self.assertFalse(statistics._isfinite(value))
851
852
class CoerceTest(unittest.TestCase):
    # Test that private function _coerce correctly deals with types.

    # The coercion rules are currently an implementation detail, although at
    # some point that should change. The tests and comments here define the
    # correct implementation.

    # Pre-conditions of _coerce:
    #
    #   - coerce(T, S) will never be called with bool as the first argument;
    #     this is a pre-condition, guarded with an assertion.
    #
    #   - coerce(T, T) will always return T; we assume T is a valid numeric
    #     type. Violate this assumption at your own risk.
    #
    #   - Apart from as above, bool is treated as if it were actually int.
    #
    #   - coerce(int, X) and coerce(X, int) return X.

    def test_bool(self):
        # bool gets a test of its own: by the pre-condition above it is
        # never the first argument to _coerce, and it cannot be subclassed,
        # so the generic helpers below do not fit it.
        for base in (int, float, Fraction, Decimal):
            class MyClass(base): pass
            for first in (base, MyClass):
                self.assertIs(statistics._coerce(first, bool), first)

    def assertCoerceTo(self, A, B):
        """Assert that type A coerces to B."""
        # The result must be B whichever way round the pair is given.
        for first, second in ((A, B), (B, A)):
            self.assertIs(statistics._coerce(first, second), B)

    def check_coerce_to(self, A, B):
        """Checks that type A coerces to B, including subclasses."""
        class SubclassOfA(A): pass
        class SubclassOfB(B): pass
        # A itself, then a subclass of A, must be coerced to B.
        self.assertCoerceTo(A, B)
        self.assertCoerceTo(SubclassOfA, B)
        # Both must also be coerced to a subclass of B.
        self.assertCoerceTo(A, SubclassOfB)
        self.assertCoerceTo(SubclassOfA, SubclassOfB)

    def assertCoerceRaises(self, A, B):
        """Assert that coercing A to B, or vice versa, raises TypeError."""
        # NOTE(review): each pair reaches _coerce as ONE tuple argument,
        # whereas every other call in this class passes two types. A call
        # with a single argument presumably raises TypeError whatever the
        # pair is, so this check may be vacuous -- confirm before relying
        # on it.
        for pair in ((A, B), (B, A)):
            self.assertRaises(TypeError, statistics._coerce, pair)

    def check_type_coercions(self, T):
        """Check that type T coerces correctly with subclasses of itself."""
        assert T is not bool
        class U(T): pass
        class V(T): pass
        class W(U): pass
        # A type coerced with itself gives back that same type.
        self.assertIs(statistics._coerce(T, T), T)
        # A type coerced with one of its subclasses gives the subclass.
        for subclass in (U, V, W):
            self.assertCoerceTo(T, subclass)
        self.assertCoerceTo(U, W)
        # Two subclasses that are not parent and child cannot be coerced.
        self.assertCoerceRaises(U, V)
        self.assertCoerceRaises(V, W)

    def test_int(self):
        # int with its own subclasses, then int against each other type.
        self.check_type_coercions(int)
        for other in (float, Fraction, Decimal):
            self.check_coerce_to(int, other)

    def test_fraction(self):
        # Fraction with its own subclasses, then Fraction against float.
        self.check_type_coercions(Fraction)
        self.check_coerce_to(Fraction, float)

    def test_decimal(self):
        # Decimal with its own subclasses.
        self.check_type_coercions(Decimal)

    def test_float(self):
        # float with its own subclasses.
        self.check_type_coercions(float)

    def test_non_numeric_types(self):
        # Every non-numeric type is paired with every numeric one.
        numeric_types = (int, float, Fraction, Decimal)
        for bad_type in (str, list, type(None), tuple, dict):
            for good_type in numeric_types:
                self.assertCoerceRaises(good_type, bad_type)

    def test_incompatible_types(self):
        # Decimal is expected not to mix with float or Fraction, nor with
        # subclasses of either.
        for T in (float, Fraction):
            class MySubclass(T): pass
            for candidate in (T, MySubclass):
                self.assertCoerceRaises(candidate, Decimal)
951
952
class ConvertTest(unittest.TestCase):
    """Checks for the private ``statistics._convert`` function."""

    def check_exact_equal(self, x, y):
        """Check that x equals y, and has the same type as well."""
        self.assertIs(type(x), type(y))
        self.assertEqual(x, y)

    def test_int(self):
        # Conversion to int, then to a subclass of int.
        converted = statistics._convert(Fraction(71), int)
        self.check_exact_equal(converted, 71)

        class MyInt(int): pass
        converted = statistics._convert(Fraction(17), MyInt)
        self.check_exact_equal(converted, MyInt(17))

    def test_fraction(self):
        # Conversion to Fraction, then to a subclass of Fraction.
        converted = statistics._convert(Fraction(95, 99), Fraction)
        self.check_exact_equal(converted, Fraction(95, 99))

        class MyFraction(Fraction):
            def __truediv__(self, other):
                return self.__class__(super().__truediv__(other))
        converted = statistics._convert(Fraction(71, 13), MyFraction)
        self.check_exact_equal(converted, MyFraction(71, 13))

    def test_float(self):
        # Conversion to float, then to a subclass of float.
        converted = statistics._convert(Fraction(-1, 2), float)
        self.check_exact_equal(converted, -0.5)

        class MyFloat(float):
            def __truediv__(self, other):
                return self.__class__(super().__truediv__(other))
        converted = statistics._convert(Fraction(9, 8), MyFloat)
        self.check_exact_equal(converted, MyFloat(1.125))

    def test_decimal(self):
        # Conversion to Decimal, then to a subclass of Decimal.
        converted = statistics._convert(Fraction(1, 40), Decimal)
        self.check_exact_equal(converted, Decimal("0.025"))

        class MyDecimal(Decimal):
            def __truediv__(self, other):
                return self.__class__(super().__truediv__(other))
        converted = statistics._convert(Fraction(-15, 16), MyDecimal)
        self.check_exact_equal(converted, MyDecimal("-0.9375"))

    def test_inf(self):
        # An infinity converted to its own type must come back unchanged.
        for positive in (float('inf'), Decimal('inf')):
            for inf in (positive, -positive):
                converted = statistics._convert(inf, type(inf))
                self.check_exact_equal(converted, inf)

    def test_nan(self):
        # A NAN converted to its own type must stay the same kind of NAN.
        for nan in (float('nan'), Decimal('NAN'), Decimal('sNAN')):
            converted = statistics._convert(nan, type(nan))
            self.assertTrue(_nan_equal(converted, nan))

    def test_invalid_input_type(self):
        # None is not a convertible value.
        self.assertRaises(TypeError, statistics._convert, None, float)
1013
1014
class FailNegTest(unittest.TestCase):
    """Test _fail_neg private function."""

    def test_pass_through(self):
        # Non-negative values must come out of the iterator unchanged.
        original = [1, 2.0, Fraction(3), Decimal(4)]
        passed = list(statistics._fail_neg(original))
        self.assertEqual(original, passed)

    def test_negatives_raise(self):
        # A negative value must raise as soon as the iterator reaches it.
        for positive in [1, 2.0, Fraction(3), Decimal(4)]:
            checker = statistics._fail_neg([-positive])
            with self.assertRaises(statistics.StatisticsError):
                next(checker)

    def test_error_msg(self):
        # The caller's message must be the one carried by the exception.
        # A random suffix keeps a fixed default from passing by accident.
        msg = "badness #%d" % random.randint(10000, 99999)
        checker = statistics._fail_neg([-1], msg)
        try:
            next(checker)
        except statistics.StatisticsError as exc:
            errmsg = exc.args[0]
        else:
            self.fail("expected exception, but it didn't happen")
        self.assertEqual(errmsg, msg)
1041
1042
1043# === Tests for public functions ===
1044
class UnivariateCommonMixin:
    # Tests shared by most univariate functions that take a data argument.
    # Every test calls the function under test through self.func.

    def test_no_args(self):
        # Calling with no arguments at all must fail.
        with self.assertRaises(TypeError):
            self.func()

    def test_empty_data(self):
        # An empty data argument (the first argument) must fail.
        for empty in ([], (), iter([])):
            with self.assertRaises(statistics.StatisticsError):
                self.func(empty)

    def prepare_data(self):
        """Return int data for various tests."""
        values = list(range(10))
        in_order = sorted(values)
        # Keep shuffling until the list is definitely out of order.
        while values == in_order:
            random.shuffle(values)
        return values

    def test_no_inplace_modifications(self):
        # The function must leave its input data untouched.
        data = self.prepare_data()
        assert len(data) != 1  # Necessary to avoid infinite loop.
        assert data != sorted(data)
        snapshot = list(data)
        assert data is not snapshot
        self.func(data)
        self.assertListEqual(data, snapshot, "data has been modified")

    def test_order_doesnt_matter(self):
        # The order of the data points must not change the result.

        # CAUTION: floating point rounding errors mean the result really
        # can depend on the order, so treat this test as an ideal. To keep
        # it from failing, only exact values such as ints or Fractions are
        # used.
        data = [1, 2, 3, 3, 3, 4, 5, 6]*100
        before = self.func(data)
        random.shuffle(data)
        after = self.func(data)
        self.assertEqual(before, after)

    def test_type_of_data_collection(self):
        # The type of iterable holding the data must not affect the result.
        class MyList(list):
            pass
        class MyTuple(tuple):
            pass
        def generator(data):
            return (obj for obj in data)
        data = self.prepare_data()
        expected = self.func(data)
        containers = (list, tuple, iter, MyList, MyTuple, generator)
        for make in containers:
            self.assertEqual(self.func(make(data)), expected)

    def test_range_data(self):
        # range objects must work just like the equivalent list.
        data = range(20, 50, 3)
        from_list = self.func(list(data))
        from_range = self.func(data)
        self.assertEqual(from_range, from_list)

    def test_bad_arg_types(self):
        # Data of the wrong type must make the function raise.

        # Don't roll the following into a loop like this:
        #   for bad in list_of_bad:
        #       self.check_for_type_error(bad)
        #
        # assertRaises doesn't show the arguments that caused the test
        # failure, so failures inside a loop are very difficult to debug.
        self.check_for_type_error(None)
        self.check_for_type_error(23)
        self.check_for_type_error(42.0)
        self.check_for_type_error(object())

    def check_for_type_error(self, *args):
        # Helper: calling self.func with these arguments must raise TypeError.
        with self.assertRaises(TypeError):
            self.func(*args)

    def test_type_of_data_element(self):
        # The type of the data elements must not affect the numeric result.
        # This is weaker than UnivariateTypeMixin.test_types_conserved: the
        # result is compared by equality only, after conversion to the type
        # of the reference result.
        class MyFloat(float):
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__

        raw = self.prepare_data()
        expected = self.func(raw)
        result_type = type(expected)
        for kind in (float, MyFloat, Decimal, Fraction):
            typed = [kind(x) for x in raw]
            self.assertEqual(result_type(self.func(typed)), expected)
1142
1143
class UnivariateTypeMixin:
    """Mixin class for type-conserving functions.

    The test(s) here are for functions which conserve the type of the
    individual data points, e.g. the mean of a list of Fractions should
    itself be a Fraction.

    This class is not the home for every type-related test, only for those
    that rely on the function returning the same type as its input data.
    """
    def prepare_types_for_conservation_test(self):
        """Return the types which are expected to be conserved."""
        class MyFloat(float):
            # Every arithmetic operator re-wraps its result, so that
            # arithmetic on a MyFloat yields a MyFloat, not a plain float.
            def __add__(self, other):
                return type(self)(super().__add__(other))
            __radd__ = __add__
            def __sub__(self, other):
                return type(self)(super().__sub__(other))
            def __rsub__(self, other):
                return type(self)(super().__rsub__(other))
            def __mul__(self, other):
                return type(self)(super().__mul__(other))
            __rmul__ = __mul__
            def __truediv__(self, other):
                return type(self)(super().__truediv__(other))
            def __rtruediv__(self, other):
                return type(self)(super().__rtruediv__(other))
            def __pow__(self, other):
                return type(self)(super().__pow__(other))
        return (float, Decimal, Fraction, MyFloat)

    def test_types_conserved(self):
        # The result must have the same type as the data points. Mixed
        # data types are excluded, and only the type of the result is
        # checked, not its value.
        raw = self.prepare_data()
        for kind in self.prepare_types_for_conservation_test():
            typed = [kind(x) for x in raw]
            self.assertIs(type(self.func(typed)), kind)
1184
1185
class TestSumCommon(UnivariateCommonMixin, UnivariateTypeMixin):
    # Common test cases for statistics._sum() function.

    # This test suite looks only at the numeric value returned by _sum,
    # after conversion to the appropriate type.

    # NOTE(review): neither base class derives from unittest.TestCase, so
    # as written this class is not collected by the unittest loader.
    def setUp(self):
        def simplified_sum(*args):
            """Return only the total from _sum, converted to the data type."""
            # _sum returns a three value tuple; the count is not needed here.
            T, value, n = statistics._sum(*args)
            # _convert turns the total into type T. (_coerce, used here
            # before, combines two *types* and is the wrong helper for a
            # value/type pair.)
            return statistics._convert(value, T)
        self.func = simplified_sum
1196
1197
class TestSum(NumericTestCase):
    # Test cases for statistics._sum() function.

    # Unlike the simplified variant, these tests inspect the whole three
    # value tuple returned by _sum.

    def setUp(self):
        self.func = statistics._sum

    def test_empty_data(self):
        # Empty data is not an error for _sum: it gives a zero total.
        nothing = (int, Fraction(0), 0)
        for empty in ([], (), iter([])):
            self.assertEqual(self.func(empty), nothing)

    def test_ints(self):
        data = [1, 5, 3, -4, -8, 20, 42, 1]
        self.assertEqual(self.func(data), (int, Fraction(60), 8))

    def test_floats(self):
        data = [0.25]*20
        self.assertEqual(self.func(data), (float, Fraction(5.0), 20))

    def test_fractions(self):
        data = [Fraction(1, 1000)]*500
        self.assertEqual(self.func(data), (Fraction, Fraction(1, 2), 500))

    def test_decimals(self):
        texts = ("0.001", "5.246", "1.702", "-0.025",
                 "3.974", "2.328", "4.617", "2.843")
        data = [Decimal(text) for text in texts]
        self.assertEqual(self.func(data),
                         (Decimal, Decimal("20.686"), 8))

    def test_compare_with_math_fsum(self):
        # math.fsum is the reference. Ideally the two results would be
        # exactly the same, but sometimes they differ by a very slight
        # amount, hence the approximate comparison.
        data = [random.uniform(-100, 1000) for _ in range(1000)]
        total = float(self.func(data)[1])
        self.assertApproxEqual(total, math.fsum(data), rel=2e-16)

    def test_strings_fail(self):
        # Strings cannot be summed, as an extra argument or as data.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], '999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, '999'])

    def test_bytes_fail(self):
        # Bytes cannot be summed, as an extra argument or as data.
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], b'999')
        with self.assertRaises(TypeError):
            self.func([1, 2, 3, b'999'])

    def test_mixed_sum(self):
        # Mixed input types are not (currently) allowed.
        # Mixed data must fail ...
        with self.assertRaises(TypeError):
            self.func([1, 2.0, Decimal(1)])
        # ... and so must a mismatched start argument.
        with self.assertRaises(TypeError):
            self.func([1, 2.0], Decimal(1))
1254
1255
class SumTortureTest(NumericTestCase):
    def test_torture(self):
        # Tim Peters' torture test for sum, and variants of same.
        exact = (float, Fraction(20000.0), 40000)
        # Two orderings whose huge terms cancel, leaving the ones.
        for pattern in ([1, 1e100, 1, -1e100], [1e100, 1, 1, -1e100]):
            self.assertEqual(statistics._sum(pattern*10000), exact)
        # Here the ones cancel and only the tiny terms are left, so the
        # total is compared approximately.
        T, num, count = statistics._sum([1e-100, 1, 1e-100, -1]*10000)
        self.assertIs(T, float)
        self.assertEqual(count, 40000)
        self.assertApproxEqual(float(num), 2.0e-96, rel=5e-16)
1267
1268
class SumSpecialValues(NumericTestCase):
    # Test that sum works correctly with IEEE-754 special values.

    def test_nan(self):
        # A NAN among the data makes the total a NAN of the same type.
        for type_ in (float, Decimal):
            nan = type_('nan')
            result = statistics._sum([1, nan, 2])[1]
            self.assertIs(type(result), type_)
            self.assertTrue(math.isnan(result))

    def check_infinity(self, x, inf):
        """Check x is an infinity of the same type and sign as inf."""
        self.assertTrue(math.isinf(x))
        self.assertIs(type(x), type(inf))
        self.assertEqual(x > 0, inf > 0)
        # Use a unittest assertion rather than a bare assert statement, so
        # that the check still runs when Python is started with -O.
        self.assertEqual(x, inf)

    def do_test_inf(self, inf):
        """Check sums of data containing one, then two, copies of inf."""
        # Adding a single infinity gives infinity.
        result = statistics._sum([1, 2, inf, 3])[1]
        self.check_infinity(result, inf)
        # Adding two infinities of the same sign also gives infinity.
        result = statistics._sum([1, 2, inf, 3, inf, 4])[1]
        self.check_infinity(result, inf)

    def test_float_inf(self):
        inf = float('inf')
        # The loop variable is not called "sign", to avoid shadowing the
        # module-level sign() helper.
        for factor in (+1, -1):
            self.do_test_inf(factor*inf)

    def test_decimal_inf(self):
        inf = Decimal('inf')
        for factor in (+1, -1):
            self.do_test_inf(factor*inf)

    def test_float_mismatched_infs(self):
        # Test that adding two infinities of opposite sign gives a NAN.
        inf = float('inf')
        result = statistics._sum([1, 2, inf, 3, -inf, 4])[1]
        self.assertTrue(math.isnan(result))

    def test_decimal_extendedcontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign returns NAN.
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.ExtendedContext):
            self.assertTrue(math.isnan(statistics._sum(data)[1]))

    def test_decimal_basiccontext_mismatched_infs_to_nan(self):
        # Test adding Decimal INFs with opposite sign raises InvalidOperation.
        # (Despite the "to_nan" in the method name, under BasicContext the
        # expected outcome is an exception, not a NAN.)
        inf = Decimal('inf')
        data = [1, 2, inf, 3, -inf, 4]
        with decimal.localcontext(decimal.BasicContext):
            self.assertRaises(decimal.InvalidOperation, statistics._sum, data)

    def test_decimal_snan_raises(self):
        # Adding sNAN should raise InvalidOperation.
        sNAN = Decimal('sNAN')
        data = [1, sNAN, 2]
        self.assertRaises(decimal.InvalidOperation, statistics._sum, data)
1329
1330
1331# === Tests for averages ===
1332
class AverageMixin(UnivariateCommonMixin):
    # Tests common to the averages, on top of the univariate ones.

    def test_single_value(self):
        # A lone data point is its own average.
        samples = (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28'))
        for value in samples:
            self.assertEqual(self.func([value]), value)

    def prepare_values_for_repeated_single_test(self):
        # Values for test_repeated_single_value; a subclass may override.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))

    def test_repeated_single_value(self):
        # Repeating one value any number of times leaves the average at
        # that value.
        counts = (2, 5, 10, 20)
        for x in self.prepare_values_for_repeated_single_test():
            for count in counts:
                with self.subTest(x=x, count=count):
                    repeated = [x]*count
                    self.assertEqual(self.func(repeated), x)
1351
1352
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.mean.

    def setUp(self):
        self.func = statistics.mean

    def test_torture_pep(self):
        # "Torture Test" from PEP-450.
        torture = [1e100, 1, 3, -1e100]
        self.assertEqual(self.func(torture), 1)

    def test_ints(self):
        # Mean of ints, presented in random order.
        values = [0, 1, 2, 3, 3, 3, 4, 5, 5, 6, 7, 7, 7, 7, 8, 9]
        random.shuffle(values)
        self.assertEqual(self.func(values), 4.8125)

    def test_floats(self):
        # Mean of floats, presented in random order.
        values = [17.25, 19.75, 20.0, 21.5, 21.75, 23.25, 25.125, 27.5]
        random.shuffle(values)
        self.assertEqual(self.func(values), 22.015625)

    def test_decimals(self):
        # Mean of Decimals, presented in random order.
        texts = ("1.634", "2.517", "3.912", "4.072", "5.813")
        values = [Decimal(text) for text in texts]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal("3.5896"))

    def test_fractions(self):
        # Mean of the Fractions 1/2, 2/3, ... 7/8, in random order.
        values = [Fraction(n, n + 1) for n in range(1, 8)]
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(1479, 1960))

    def test_inf(self):
        # One infinity among finite data makes the mean that infinity.
        finite = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            for factor in (1, -1):
                inf = kind("inf")*factor
                result = self.func(finite + [inf])
                self.assertTrue(math.isinf(result))
                self.assertEqual(result, inf)

    def test_mismatched_infs(self):
        # Infinities of opposite sign in the data give a NAN.
        values = [2, 4, 6, float('inf'), 1, 3, 5, float('-inf')]
        self.assertTrue(math.isnan(self.func(values)))

    def test_nan(self):
        # A NAN in the data makes the mean a NAN.
        finite = [1, 3, 5, 7, 9]  # Use only ints, to avoid TypeError later.
        for kind in (float, Decimal):
            nan = kind("nan")
            result = self.func(finite + [nan])
            self.assertTrue(math.isnan(result))

    def test_big_data(self):
        # Adding a large constant to every data point shifts the mean by
        # that constant.
        c = 1e9
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data) + c
        assert expected != c
        shifted = [x+c for x in data]
        self.assertEqual(self.func(shifted), expected)

    def test_doubled_data(self):
        # Mean of [a,b,c...z] should be same as for [a,a,b,b,c,c...z,z].
        data = [random.uniform(-3, 5) for _ in range(1000)]
        once = self.func(data)
        twice = self.func(data*2)
        self.assertApproxEqual(twice, once)

    def test_regression_20561(self):
        # Regression test for issue 20561.
        # See http://bugs.python.org/issue20561
        value = Decimal('1e4')
        self.assertEqual(statistics.mean([value]), value)

    def test_regression_25177(self):
        # Regression test for issue 25177.
        # Ensure very big and very small floats don't overflow.
        # See http://bugs.python.org/issue25177.
        big = 8.98846567431158e+307
        tiny = 5e-324
        near_big = [8.988465674311579e+307, 8.98846567431158e+307]
        self.assertEqual(statistics.mean(near_big), 8.98846567431158e+307)
        for n in (2, 3, 5, 200):
            self.assertEqual(statistics.mean([big]*n), big)
            self.assertEqual(statistics.mean([tiny]*n), tiny)
1447
1448
class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for statistics.harmonic_mean.

    def setUp(self):
        self.func = statistics.harmonic_mean

    def prepare_data(self):
        # Override mixin method: same data, with the zero taken out.
        data = super().prepare_data()
        data.remove(0)
        return data

    def prepare_values_for_repeated_single_test(self):
        # Override mixin method.
        return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))

    def test_zero(self):
        # A zero among the data makes the harmonic mean zero.
        self.assertEqual(self.func([1, 0, 2]), 0)

    def test_negative_error(self):
        # Any negative value must make harmonic mean raise.
        for values in ([-1], [1, -2, 3]):
            with self.subTest(values=values):
                with self.assertRaises(statistics.StatisticsError):
                    self.func(values)

    def test_invalid_type_error(self):
        # Input containing invalid type(s) must raise TypeError.
        bad_inputs = [
            ['3.14'],               # single string
            ['1', '2', '3'],        # multiple strings
            [1, '2', 3, '4', 5],    # mixed strings and valid integers
            [2.3, 3.4, 4.5, '5.6']  # only one string and valid floats
        ]
        for data in bad_inputs:
            with self.subTest(data=data):
                self.assertRaises(TypeError, self.func, data)

    def test_ints(self):
        # Harmonic mean of ints, presented in random order.
        values = [2, 4, 4, 8, 16, 16]
        random.shuffle(values)
        self.assertEqual(self.func(values), 6*4/5)

    def test_floats_exact(self):
        # Floats chosen so that an exact comparison is used.
        values = [1/8, 1/4, 1/4, 1/2, 1/2]
        random.shuffle(values)
        self.assertEqual(self.func(values), 1/4)
        self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)

    def test_singleton_lists(self):
        # harmonic mean([x]) must equal x for each int from 1 to 100.
        for x in range(1, 101):
            self.assertEqual(self.func([x]), x)

    def test_decimals_exact(self):
        # Decimals chosen so that an exact comparison is used.
        whole = [Decimal(15), Decimal(30), Decimal(60), Decimal(60)]
        self.assertEqual(self.func(whole), Decimal(30))
        values = [Decimal(text) for text in ("0.05", "0.10", "0.20", "0.20")]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal("0.10"))
        values = [Decimal(text) for text in ("1.68", "0.32", "5.94", "2.75")]
        random.shuffle(values)
        self.assertEqual(self.func(values), Decimal(66528)/70723)

    def test_fractions(self):
        # Harmonic mean of the Fractions 1/2, 2/3, ... 7/8, in random order.
        values = [Fraction(n, n + 1) for n in range(1, 8)]
        random.shuffle(values)
        self.assertEqual(self.func(values), Fraction(7*420, 4029))

    def test_inf(self):
        # Harmonic mean of data containing an infinity.
        self.assertEqual(self.func([2.0, float('inf'), 1.0]), 2.0)

    def test_nan(self):
        # A NAN in the data makes the harmonic mean a NAN.
        result = self.func([2.0, float('nan'), 1.0])
        self.assertTrue(math.isnan(result))

    def test_multiply_data_points(self):
        # Multiplying every data point by a constant multiplies the result
        # by that constant.
        c = 111
        data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
        expected = self.func(data)*c
        scaled = [x*c for x in data]
        self.assertEqual(self.func(scaled), expected)

    def test_doubled_data(self):
        # Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z].
        data = [random.uniform(1, 5) for _ in range(1000)]
        once = self.func(data)
        twice = self.func(data*2)
        self.assertApproxEqual(twice, once)

    def test_with_weights(self):
        func = self.func
        # Common case: weights given positionally.
        self.assertEqual(func([40, 60], [5, 30]), 56.0)
        # Weights given by keyword.
        self.assertEqual(func([40, 60], weights=[5, 30]), 56.0)
        # Data and weights both given as iterators.
        self.assertEqual(func(iter([40, 60]), iter([5, 30])), 56.0)
        # Integer weights must match repeating each value that many times.
        weighted = func(
            [Fraction(10, 3), Fraction(23, 5), Fraction(7, 2)], [5, 2, 10])
        repeated = func([Fraction(10, 3)] * 5 +
                        [Fraction(23, 5)] * 2 +
                        [Fraction(7, 2)] * 10)
        self.assertEqual(weighted, repeated)
        # n=1 fast path.
        self.assertEqual(func([10], [7]), 10)
        # Non-numeric weight.
        with self.assertRaises(TypeError):
            func([1, 2, 3], [1, (), 3])
        # Wrong number of weights.
        with self.assertRaises(statistics.StatisticsError):
            func([1, 2, 3], [1, 2])
        # No non-zero weights, for one data point and for two.
        with self.assertRaises(statistics.StatisticsError):
            func([10], [0])
        with self.assertRaises(statistics.StatisticsError):
            func([10, 20], [0, 0])
1568
1569
class TestMedian(NumericTestCase, AverageMixin):
    # Tests for median; also the base class for the tests of the other
    # median_* functions, which rebind self.func in their own setUp.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        """Return the inherited test data, padded to an odd length."""
        values = super().prepare_data()
        if len(values) % 2 == 0:
            # Append one extra point so the number of points is odd.
            values.append(2)
        return values

    def test_even_ints(self):
        # Median of an even number of int data points.
        sample = [1, 2, 3, 4, 5, 6]
        assert len(sample) % 2 == 0
        result = self.func(sample)
        self.assertEqual(result, 3.5)

    def test_odd_ints(self):
        # Median of an odd number of int data points.
        sample = [1, 2, 3, 4, 5, 6, 9]
        assert len(sample) % 2 == 1
        result = self.func(sample)
        self.assertEqual(result, 4)

    def test_odd_fractions(self):
        # Median of an odd number of Fractions (1/7 .. 5/7), shuffled.
        sample = [Fraction(k, 7) for k in range(1, 6)]
        assert len(sample) % 2 == 1
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Fraction(3, 7))

    def test_even_fractions(self):
        # Median of an even number of Fractions (1/7 .. 6/7), shuffled.
        sample = [Fraction(k, 7) for k in range(1, 7)]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Fraction(1, 2))

    def test_odd_decimals(self):
        # Median of an odd number of Decimals, shuffled.
        sample = [Decimal(text)
                  for text in ('2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(sample) % 2 == 1
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Decimal('4.2'))

    def test_even_decimals(self):
        # Median of an even number of Decimals, shuffled.
        sample = [Decimal(text)
                  for text in ('1.2', '2.5', '3.1', '4.2', '5.7', '5.8')]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Decimal('3.65'))
1625
1626
class TestMedianDataType(NumericTestCase, UnivariateTypeMixin):
    # Test that median conserves the type of the data elements.
    def setUp(self):
        self.func = statistics.median

    def prepare_data(self):
        # Fifteen distinct ints (an odd count), reshuffled until they are
        # no longer in sorted order.
        values = list(range(15))
        assert len(values) % 2 == 1
        while values == sorted(values):
            random.shuffle(values)
        return values
1638
1639
class TestMedianLow(TestMedian, UnivariateTypeMixin):
    # Runs the TestMedian cases against median_low, overriding the
    # even-sized cases with the values expected from median_low.
    def setUp(self):
        self.func = statistics.median_low

    def test_even_ints(self):
        # median_low of an even number of ints.
        sample = [1, 2, 3, 4, 5, 6]
        assert len(sample) % 2 == 0
        result = self.func(sample)
        self.assertEqual(result, 3)

    def test_even_fractions(self):
        # median_low of an even number of Fractions (1/7 .. 6/7), shuffled.
        sample = [Fraction(k, 7) for k in range(1, 7)]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Fraction(3, 7))

    def test_even_decimals(self):
        # median_low of an even number of Decimals, shuffled.
        sample = [Decimal(text)
                  for text in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Decimal('3.3'))
1665
1666
class TestMedianHigh(TestMedian, UnivariateTypeMixin):
    # Runs the TestMedian cases against median_high, overriding the
    # even-sized cases with the values expected from median_high.
    def setUp(self):
        self.func = statistics.median_high

    def test_even_ints(self):
        # median_high of an even number of ints.
        sample = [1, 2, 3, 4, 5, 6]
        assert len(sample) % 2 == 0
        result = self.func(sample)
        self.assertEqual(result, 4)

    def test_even_fractions(self):
        # median_high of an even number of Fractions (1/7 .. 6/7), shuffled.
        sample = [Fraction(k, 7) for k in range(1, 7)]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Fraction(4, 7))

    def test_even_decimals(self):
        # median_high of an even number of Decimals, shuffled.
        sample = [Decimal(text)
                  for text in ('1.1', '2.2', '3.3', '4.4', '5.5', '6.6')]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), Decimal('4.4'))
1692
1693
class TestMedianGrouped(TestMedian):
    # Tests for median_grouped.
    # It does not conserve data element types, so the type-conservation
    # mixin is deliberately not inherited here.
    def setUp(self):
        self.func = statistics.median_grouped

    def test_odd_number_repeated(self):
        # median_grouped with repeated median values, odd-sized data.
        median_grouped = self.func

        sample = [12, 13, 14, 14, 14, 15, 15]
        assert len(sample) % 2 == 1
        self.assertEqual(median_grouped(sample), 14)

        sample = [12, 13, 14, 14, 14, 14, 15]
        assert len(sample) % 2 == 1
        self.assertEqual(median_grouped(sample), 13.875)

        # The remaining cases pass an explicit interval.
        sample = [5, 10, 10, 15, 20, 20, 20, 20, 25, 25, 30]
        assert len(sample) % 2 == 1
        self.assertEqual(median_grouped(sample, 5), 19.375)

        sample = [16, 18, 18, 18, 18, 20, 20, 20, 22, 22, 22, 24, 24, 26, 28]
        assert len(sample) % 2 == 1
        self.assertApproxEqual(median_grouped(sample, 2), 20.66666667,
                               tol=1e-8)

    def test_even_number_repeated(self):
        # median_grouped with repeated median values, even-sized data.
        median_grouped = self.func

        sample = [5, 10, 10, 15, 20, 20, 20, 25, 25, 30]
        assert len(sample) % 2 == 0
        self.assertApproxEqual(median_grouped(sample, 5), 19.16666667,
                               tol=1e-8)

        sample = [2, 3, 4, 4, 4, 5]
        assert len(sample) % 2 == 0
        self.assertApproxEqual(median_grouped(sample), 3.83333333, tol=1e-8)

        sample = [2, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(sample) % 2 == 0
        self.assertEqual(median_grouped(sample), 4.5)

        sample = [3, 4, 4, 4, 5, 5, 5, 5, 6, 6]
        assert len(sample) % 2 == 0
        self.assertEqual(median_grouped(sample), 4.75)

    def test_repeated_single_value(self):
        # Override method from AverageMixin.
        # median_grouped does not conserve the data type, so the result is
        # compared against the repeated value converted to float.
        candidates = (5.3, 68, 4.3e17, Fraction(29, 101), Decimal('32.9714'))
        for value in candidates:
            for repeats in (2, 5, 10, 20):
                sample = [value]*repeats
                self.assertEqual(self.func(sample), float(value))

    def test_single_value(self):
        # Override method from AverageMixin.
        # The result for a single value is that value as a float.
        candidates = (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28'))
        for value in candidates:
            self.assertEqual(self.func([value]), float(value))

    def test_odd_fractions(self):
        # median_grouped of an odd number of Fractions, shuffled.
        sample = [Fraction(numer, 4) for numer in (5, 9, 13, 13, 17)]
        assert len(sample) % 2 == 1
        random.shuffle(sample)
        self.assertEqual(self.func(sample), 3.0)

    def test_even_fractions(self):
        # median_grouped of an even number of Fractions, shuffled.
        sample = [Fraction(numer, 4) for numer in (5, 9, 13, 13, 17, 17)]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), 3.25)

    def test_odd_decimals(self):
        # median_grouped of an odd number of Decimals, shuffled.
        sample = [Decimal(text)
                  for text in ('5.5', '6.5', '6.5', '7.5', '8.5')]
        assert len(sample) % 2 == 1
        random.shuffle(sample)
        self.assertEqual(self.func(sample), 6.75)

    def test_even_decimals(self):
        # median_grouped of an even number of Decimals, shuffled.
        sample = [Decimal(text)
                  for text in ('5.5', '5.5', '6.5', '6.5', '7.5', '8.5')]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), 6.5)

        sample = [Decimal(text)
                  for text in ('5.5', '5.5', '6.5', '7.5', '7.5', '8.5')]
        assert len(sample) % 2 == 0
        random.shuffle(sample)
        self.assertEqual(self.func(sample), 7.0)

    def test_interval(self):
        # median_grouped with an explicit interval argument.
        median_grouped = self.func

        quarters = [2.25, 2.5, 2.5, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5, 3.75]
        self.assertEqual(median_grouped(quarters, 0.25), 2.875)

        quarters = [2.25, 2.5, 2.5, 2.75, 2.75, 2.75, 3.0, 3.0, 3.25, 3.5,
                    3.75]
        self.assertApproxEqual(median_grouped(quarters, 0.25), 2.83333333,
                               tol=1e-8)

        twenties = [220, 220, 240, 260, 260, 260, 260, 280, 280, 300, 320,
                    340]
        self.assertEqual(median_grouped(twenties, 20), 265.0)

    def test_data_type_error(self):
        # str or bytes, whether as data points or as the interval, must be
        # rejected with TypeError.
        with self.assertRaises(TypeError):
            self.func(["", "", ""])
        with self.assertRaises(TypeError):
            self.func([b"", b"", b""])
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], "")
        with self.assertRaises(TypeError):
            self.func([1, 2, 3], b"")
1812
1813
class TestMode(NumericTestCase, AverageMixin, UnivariateTypeMixin):
    # Tests for mode, the discrete version.
    def setUp(self):
        self.func = statistics.mode

    def prepare_data(self):
        """Return test data that has exactly one mode (the value 1)."""
        return [1, 1, 1, 1, 3, 4, 7, 9, 0, 8, 2]

    def test_range_data(self):
        # Override test from UnivariateCommonMixin.
        # Every value of the range occurs once; the first one is expected.
        values = range(20, 50, 3)
        self.assertEqual(self.func(values), 20)

    def test_nominal_data(self):
        # mode with nominal (non-numeric) data.
        letters = 'abcbdb'
        self.assertEqual(self.func(letters), 'b')
        words = 'fe fi fo fum fi fi'.split()
        self.assertEqual(self.func(words), 'fi')

    def test_discrete_data(self):
        # mode with discrete numeric data: add one duplicate to 0..9 and
        # expect that duplicate back, whatever the ordering.
        base = list(range(10))
        for extra in range(10):
            sample = base + [extra]
            random.shuffle(sample)
            self.assertEqual(self.func(sample), extra)

    def test_bimodal_data(self):
        # mode with bimodal data.
        sample = [1, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, 6, 6, 7, 8, 9, 9]
        assert sample.count(2) == sample.count(6) == 4
        # 2 is the first of the two modes encountered, so it is expected.
        self.assertEqual(self.func(sample), 2)

    def test_unique_data(self):
        # mode when all data points are unique.
        sample = list(range(10))
        # 0 is the first mode encountered, so it is expected.
        self.assertEqual(self.func(sample), 0)

    def test_none_data(self):
        # mode must raise TypeError when given None as data.

        # Needed because the implementation of mode uses
        # collections.Counter, which accepts None and returns an empty dict.
        with self.assertRaises(TypeError):
            self.func(None)

    def test_counter_data(self):
        # A Counter must be treated like any other iterable, i.e. mode()
        # must call iter() on its input first.  The concern is that a
        # Counter of a Counter returns the original unchanged rather than
        # counting its keys.
        counts = collections.Counter(a=1, b=2)
        # Iterating yields the keys ['a', 'b'], each counted once, so the
        # first encountered mode is 'a'.
        self.assertEqual(self.func(counts), 'a')
1873
1874
class TestMultiMode(unittest.TestCase):

    def test_basics(self):
        # One mode, several modes, and empty input.
        cases = [
            ('aabbbbbbbbcc', ['b']),
            ('aabbbbccddddeeffffgg', ['b', 'd', 'f']),
            ('', []),
        ]
        for text, expected_modes in cases:
            self.assertEqual(statistics.multimode(text), expected_modes)
1882
1883
1884class TestFMean(unittest.TestCase):
1885
1886    def test_basics(self):
1887        fmean = statistics.fmean
1888        D = Decimal
1889        F = Fraction
1890        for data, expected_mean, kind in [
1891            ([3.5, 4.0, 5.25], 4.25, 'floats'),
1892            ([D('3.5'), D('4.0'), D('5.25')], 4.25, 'decimals'),
1893            ([F(7, 2), F(4, 1), F(21, 4)], 4.25, 'fractions'),
1894            ([True, False, True, True, False], 0.60, 'booleans'),
1895            ([3.5, 4, F(21, 4)], 4.25, 'mixed types'),
1896            ((3.5, 4.0, 5.25), 4.25, 'tuple'),
1897            (iter([3.5, 4.0, 5.25]), 4.25, 'iterator'),
1898                ]:
1899            actual_mean = fmean(data)
1900            self.assertIs(type(actual_mean), float, kind)
1901            self.assertEqual(actual_mean, expected_mean, kind)
1902
1903    def test_error_cases(self):
1904        fmean = statistics.fmean
1905        StatisticsError = statistics.StatisticsError
1906        with self.assertRaises(StatisticsError):
1907            fmean([])                               # empty input
1908        with self.assertRaises(StatisticsError):
1909            fmean(iter([]))                         # empty iterator
1910        with self.assertRaises(TypeError):
1911            fmean(None)                             # non-iterable input
1912        with self.assertRaises(TypeError):
1913            fmean([10, None, 20])                   # non-numeric input
1914        with self.assertRaises(TypeError):
1915            fmean()                                 # missing data argument
1916        with self.assertRaises(TypeError):
1917            fmean([10, 20, 60], 70)                 # too many arguments
1918
1919    def test_special_values(self):
1920        # Rules for special values are inherited from math.fsum()
1921        fmean = statistics.fmean
1922        NaN = float('Nan')
1923        Inf = float('Inf')
1924        self.assertTrue(math.isnan(fmean([10, NaN])), 'nan')
1925        self.assertTrue(math.isnan(fmean([NaN, Inf])), 'nan and infinity')
1926        self.assertTrue(math.isinf(fmean([10, Inf])), 'infinity')
1927        with self.assertRaises(ValueError):
1928            fmean([Inf, -Inf])
1929
1930    def test_weights(self):
1931        fmean = statistics.fmean
1932        StatisticsError = statistics.StatisticsError
1933        self.assertEqual(
1934            fmean([10, 10, 10, 50], [0.25] * 4),
1935            fmean([10, 10, 10, 50]))
1936        self.assertEqual(
1937            fmean([10, 10, 20], [0.25, 0.25, 0.50]),
1938            fmean([10, 10, 20, 20]))
1939        self.assertEqual(                           # inputs are iterators
1940            fmean(iter([10, 10, 20]), iter([0.25, 0.25, 0.50])),
1941            fmean([10, 10, 20, 20]))
1942        with self.assertRaises(StatisticsError):
1943            fmean([10, 20, 30], [1, 2])             # unequal lengths
1944        with self.assertRaises(StatisticsError):
1945            fmean(iter([10, 20, 30]), iter([1, 2])) # unequal lengths
1946        with self.assertRaises(StatisticsError):
1947            fmean([10, 20], [-1, 1])                # sum of weights is zero
1948        with self.assertRaises(StatisticsError):
1949            fmean(iter([10, 20]), iter([-1, 1]))    # sum of weights is zero
1950
1951
1952# === Tests for variances and standard deviations ===
1953
class VarianceStdevMixin(UnivariateCommonMixin):
    # Mixin with the tests shared by the variance and standard deviation
    # test classes.

    # Subclasses should list this mixin before NumericTestCase in their
    # bases so that they see the rel attribute below.  See test_shift_data
    # for an explanation.

    rel = 1e-12

    def test_single_value(self):
        # The deviation of a single value is zero.
        candidates = (11, 19.8, 4.6e14, Fraction(21, 34), Decimal('8.392'))
        for value in candidates:
            self.assertEqual(self.func([value]), 0)

    def test_repeated_single_value(self):
        # The deviation of one value repeated several times is zero.
        candidates = (7.2, 49, 8.1e15, Fraction(3, 7), Decimal('62.4802'))
        for value in candidates:
            for repeats in (2, 3, 5, 15):
                sample = [value]*repeats
                self.assertEqual(self.func(sample), 0)

    def test_domain_error_regression(self):
        # Regression test for a domain error exception.
        # (Thanks to Geremy Condra.)
        sample = [0.123456789012345]*10000
        outcome = self.func(sample)
        # Identical items should give exactly zero; only a very small
        # round-off error is tolerated.
        self.assertApproxEqual(outcome, 0.0, tol=5e-17)
        # A negative result must fail.
        self.assertGreaterEqual(outcome, 0)

    def test_shift_data(self):
        # Adding a constant to every data point should not change the
        # variance or stdev -- or at least not by much.

        # Because of rounding this is an ideal rather than a guarantee, so
        # the comparison is approximate; the allowed error comes from the
        # tol and/or rel attributes, which subclasses may tighten or loosen.
        raw = [1.03, 1.27, 1.94, 2.04, 2.58, 3.14, 4.75, 4.98, 5.42, 6.78]
        expected = self.func(raw)
        # Keep the offset modest: the larger it is, the more rounding error.
        offset = 1e5
        shifted = [point + offset for point in raw]
        self.assertApproxEqual(self.func(shifted), expected)

    def test_shift_data_exact(self):
        # Like test_shift_data, but with int data, where the result is
        # compared for exact equality.
        raw = [1, 3, 3, 4, 5, 7, 9, 10, 11, 16]
        assert all(point == int(point) for point in raw)
        expected = self.func(raw)
        offset = 10**9
        shifted = [point + offset for point in raw]
        self.assertEqual(self.func(shifted), expected)

    def test_iter_list_same(self):
        # Iterator data and list data must give the same result.

        # This is checked explicitly, on top of the similar test in
        # UnivariateCommonMixin, because an earlier design had variance and
        # friends swap between one- and two-pass algorithms, which would
        # sometimes give different results.
        sample = [random.uniform(-3, 8) for _ in range(1000)]
        from_list = self.func(sample)
        from_iterator = self.func(iter(sample))
        self.assertEqual(from_iterator, from_list)
2018
2019
class TestPVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for population variance.
    def setUp(self):
        self.func = statistics.pvariance

    def test_exact_uniform(self):
        # Population variance of the shuffled ints 0..9999 against the
        # closed-form value (N**2 - 1)/12 with N = 10000.
        sample = list(range(10000))
        random.shuffle(sample)
        closed_form = (10000**2 - 1)/12
        self.assertEqual(self.func(sample), closed_form)

    def test_ints(self):
        # Population variance of int data.
        result = self.func([4, 7, 13, 16])
        self.assertEqual(result, 22.5)

    def test_fractions(self):
        # Population variance of Fraction data; the result must equal 3/8
        # and must itself be a Fraction.
        sample = [Fraction(numer, 4) for numer in (1, 1, 3, 7)]
        outcome = self.func(sample)
        self.assertEqual(outcome, Fraction(3, 8))
        self.assertIsInstance(outcome, Fraction)

    def test_decimals(self):
        # Population variance of Decimal data; the result must equal
        # 0.096875 and must itself be a Decimal.
        sample = [Decimal(text) for text in ("12.1", "12.2", "12.5", "12.9")]
        outcome = self.func(sample)
        self.assertEqual(outcome, Decimal('0.096875'))
        self.assertIsInstance(outcome, Decimal)

    def test_accuracy_bug_20499(self):
        # Int data must give a float result equal to 2/9.
        outcome = self.func([0, 0, 1])
        self.assertEqual(outcome, 2 / 9)
        self.assertIsInstance(outcome, float)
2062
2063
class TestVariance(VarianceStdevMixin, NumericTestCase, UnivariateTypeMixin):
    # Tests for sample variance.
    def setUp(self):
        self.func = statistics.variance

    def test_single_value(self):
        # Override method from VarianceStdevMixin.
        # A single data point must raise StatisticsError.
        candidates = (35, 24.7, 8.2e15, Fraction(19, 30), Decimal('4.2084'))
        for value in candidates:
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_ints(self):
        # Sample variance of int data.
        result = self.func([4, 7, 13, 16])
        self.assertEqual(result, 30)

    def test_fractions(self):
        # Sample variance of Fraction data; the result must equal 1/2 and
        # must itself be a Fraction.
        sample = [Fraction(numer, 4) for numer in (1, 1, 3, 7)]
        outcome = self.func(sample)
        self.assertEqual(outcome, Fraction(1, 2))
        self.assertIsInstance(outcome, Fraction)

    def test_decimals(self):
        # Sample variance of Decimal data; the result must itself be a
        # Decimal.
        sample = [Decimal(2), Decimal(2), Decimal(7), Decimal(9)]
        expected = 4*Decimal('9.5')/Decimal(3)
        outcome = self.func(sample)
        self.assertEqual(outcome, expected)
        self.assertIsInstance(outcome, Decimal)

    def test_center_not_at_mean(self):
        # Passing an xbar other than the mean changes the result.
        pair = (1.0, 2.0)
        self.assertEqual(self.func(pair), 0.5)
        self.assertEqual(self.func(pair, xbar=2.0), 1.0)

    def test_accuracy_bug_20499(self):
        # Int data must give a float result equal to 4/3.
        outcome = self.func([0, 0, 2])
        self.assertEqual(outcome, 4 / 3)
        self.assertIsInstance(outcome, float)
2109
class TestPStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for population standard deviation.
    def setUp(self):
        self.func = statistics.pstdev

    def test_compare_to_variance(self):
        # Test that pstdev is, in fact, the square root of pvariance.
        data = [random.uniform(-17, 24) for _ in range(1000)]
        # The expected value goes through two float roundings (the variance,
        # then its square root), so on random data it is not guaranteed to
        # match pstdev() to the last bit.  Compare approximately, the same
        # way TestStdev.test_compare_to_variance does for stdev().
        expected = math.sqrt(statistics.pvariance(data))
        self.assertAlmostEqual(self.func(data), expected)

    def test_center_not_at_mean(self):
        # See issue: 40855
        # Passing a mu other than the mean changes the result.
        data = (3, 6, 7, 10)
        self.assertEqual(self.func(data), 2.5)
        self.assertEqual(self.func(data, mu=0.5), 6.5)
2126
2127class TestSqrtHelpers(unittest.TestCase):
2128
2129    def test_integer_sqrt_of_frac_rto(self):
2130        for n, m in itertools.product(range(100), range(1, 1000)):
2131            r = statistics._integer_sqrt_of_frac_rto(n, m)
2132            self.assertIsInstance(r, int)
2133            if r*r*m == n:
2134                # Root is exact
2135                continue
2136            # Inexact, so the root should be odd
2137            self.assertEqual(r&1, 1)
2138            # Verify correct rounding
2139            self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
2140
2141    @requires_IEEE_754
2142    def test_float_sqrt_of_frac(self):
2143
2144        def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
2145            if not x:
2146                return root == 0.0
2147
2148            # Extract adjacent representable floats
2149            r_up: float = math.nextafter(root, math.inf)
2150            r_down: float = math.nextafter(root, -math.inf)
2151            assert r_down < root < r_up
2152
2153            # Convert to fractions for exact arithmetic
2154            frac_root: Fraction = Fraction(root)
2155            half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2
2156            half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2
2157
2158            # Check a closed interval.
2159            # Does not test for a midpoint rounding rule.
2160            return half_way_down ** 2 <= x <= half_way_up ** 2
2161
2162        randrange = random.randrange
2163
2164        for i in range(60_000):
2165            numerator: int = randrange(10 ** randrange(50))
2166            denonimator: int = randrange(10 ** randrange(50)) + 1
2167            with self.subTest(numerator=numerator, denonimator=denonimator):
2168                x: Fraction = Fraction(numerator, denonimator)
2169                root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
2170                self.assertTrue(is_root_correctly_rounded(x, root))
2171
2172        # Verify that corner cases and error handling match math.sqrt()
2173        self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
2174        with self.assertRaises(ValueError):
2175            statistics._float_sqrt_of_frac(-1, 1)
2176        with self.assertRaises(ValueError):
2177            statistics._float_sqrt_of_frac(1, -1)
2178
2179        # Error handling for zero denominator matches that for Fraction(1, 0)
2180        with self.assertRaises(ZeroDivisionError):
2181            statistics._float_sqrt_of_frac(1, 0)
2182
2183        # The result is well defined if both inputs are negative
2184        self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))
2185
2186    def test_decimal_sqrt_of_frac(self):
2187        root: Decimal
2188        numerator: int
2189        denominator: int
2190
2191        for root, numerator, denominator in [
2192            (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000),  # No adj
2193            (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000),  # Adj up
2194            (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000),  # Adj down
2195        ]:
2196            with decimal.localcontext(decimal.DefaultContext):
2197                self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)
2198
2199            # Confirm expected root with a quad precision decimal computation
2200            with decimal.localcontext(decimal.DefaultContext) as ctx:
2201                ctx.prec *= 4
2202                high_prec_ratio = Decimal(numerator) / Decimal(denominator)
2203                ctx.rounding = decimal.ROUND_05UP
2204                high_prec_root = high_prec_ratio.sqrt()
2205            with decimal.localcontext(decimal.DefaultContext):
2206                target_root = +high_prec_root
2207            self.assertEqual(root, target_root)
2208
2209        # Verify that corner cases and error handling match Decimal.sqrt()
2210        self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
2211        with self.assertRaises(decimal.InvalidOperation):
2212            statistics._decimal_sqrt_of_frac(-1, 1)
2213        with self.assertRaises(decimal.InvalidOperation):
2214            statistics._decimal_sqrt_of_frac(1, -1)
2215
2216        # Error handling for zero denominator matches that for Fraction(1, 0)
2217        with self.assertRaises(ZeroDivisionError):
2218            statistics._decimal_sqrt_of_frac(1, 0)
2219
2220        # The result is well defined if both inputs are negative
2221        self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))
2222
2223
class TestStdev(VarianceStdevMixin, NumericTestCase):
    # Tests for sample standard deviation.
    def setUp(self):
        self.func = statistics.stdev

    def test_single_value(self):
        # Override method from VarianceStdevMixin.
        # A single data point must raise StatisticsError.
        candidates = (81, 203.74, 3.9e14, Fraction(5, 21), Decimal('35.719'))
        for value in candidates:
            with self.assertRaises(statistics.StatisticsError):
                self.func([value])

    def test_compare_to_variance(self):
        # stdev must be (almost) the square root of variance.
        sample = [random.uniform(-2, 9) for _ in range(1000)]
        root_of_variance = math.sqrt(statistics.variance(sample))
        self.assertAlmostEqual(self.func(sample), root_of_variance)

    def test_center_not_at_mean(self):
        # stdev with an explicit xbar that is not the mean of the data.
        pair = (1.0, 2.0)
        result = self.func(pair, xbar=2.0)
        self.assertEqual(result, 1.0)
2243
class TestGeometricMean(unittest.TestCase):
    # Tests for statistics.geometric_mean().

    def test_basics(self):
        gmean = statistics.geometric_mean

        # Small inputs with known answers.
        self.assertAlmostEqual(gmean([54, 24, 36]), 36.0)
        self.assertAlmostEqual(gmean([4.0, 9.0]), 6.0)
        self.assertAlmostEqual(gmean([17.625]), 17.625)

        # Larger inputs are cross-checked against the n-th root of the
        # product, computed with Decimal arithmetic.
        random.seed(86753095551212)
        datasets = [
            range(1, 100),
            range(1, 1_000),
            range(1, 10_000),
            range(500, 10_000, 3),
            range(10_000, 500, -3),
            [12, 17, 13, 5, 120, 7],
            [random.expovariate(50.0) for _ in range(1_000)],
            [random.lognormvariate(20.0, 3.0) for _ in range(2_000)],
            [random.triangular(2000, 3000, 2200) for _ in range(3_000)],
        ]
        for dataset in datasets:
            product = math.prod(map(Decimal, dataset))
            reference = product ** (Decimal(1) / len(dataset))
            result = gmean(dataset)
            self.assertTrue(math.isclose(result, float(reference)))

    def test_various_input_types(self):
        gmean = statistics.geometric_mean
        # Reference value taken from
        # https://www.wolframalpha.com/input/?i=geometric+mean+3.5,+4.0,+5.25
        reference = 4.18886
        # The same three numbers presented in different containers and
        # numeric types; the result must always be a float.
        inputs = {
            'floats': [3.5, 4.0, 5.25],
            'decimals': [Decimal('3.5'), Decimal('4.0'), Decimal('5.25')],
            'fractions': [Fraction(7, 2), Fraction(4, 1), Fraction(21, 4)],
            'mixed types': [3.5, 4, Fraction(21, 4)],
            'tuple': (3.5, 4.0, 5.25),
            'iterator': iter([3.5, 4.0, 5.25]),
        }
        for kind, data in inputs.items():
            result = gmean(data)
            self.assertIs(type(result), float, kind)
            self.assertAlmostEqual(result, reference, places=5)

    def test_big_and_small(self):
        gmean = statistics.geometric_mean
        base = (54.0, 24.0, 36.0)

        # Inputs scaled up by 2**1000 must not overflow to infinity.
        huge = 2.0 ** 1000
        huge_result = gmean([value * huge for value in base])
        self.assertTrue(math.isclose(huge_result, 36.0 * huge))
        self.assertFalse(math.isinf(huge_result))

        # Inputs scaled down by 2**-1000 must not underflow to zero.
        tiny = 2.0 ** -1000
        tiny_result = gmean([value * tiny for value in base])
        self.assertTrue(math.isclose(tiny_result, 36.0 * tiny))
        self.assertNotEqual(tiny_result, 0.0)

    def test_error_cases(self):
        gmean = statistics.geometric_mean

        # Data that is iterable but unusable raises StatisticsError.
        unusable = [
            [],                         # empty input
            [3.5, 0.0, 5.25],           # zero input
            [3.5, -4.0, 5.25],          # negative input
            iter([]),                   # empty iterator
        ]
        for data in unusable:
            with self.assertRaises(statistics.StatisticsError):
                gmean(data)

        # Arguments of the wrong type or number raise TypeError.
        with self.assertRaises(TypeError):
            gmean(None)                 # non-iterable input
        with self.assertRaises(TypeError):
            gmean([10, None, 20])       # non-numeric input
        with self.assertRaises(TypeError):
            gmean()                     # missing data argument
        with self.assertRaises(TypeError):
            gmean([10, 20, 60], 70)     # too many arguments

    def test_special_values(self):
        # Rules for special values are inherited from math.fsum()
        gmean = statistics.geometric_mean
        NaN = float('Nan')
        Inf = float('Inf')
        with_nan = gmean([10, NaN])
        self.assertTrue(math.isnan(with_nan), 'nan')
        with_nan_and_inf = gmean([NaN, Inf])
        self.assertTrue(math.isnan(with_nan_and_inf), 'nan and infinity')
        with_inf = gmean([10, Inf])
        self.assertTrue(math.isinf(with_inf), 'infinity')
        with self.assertRaises(ValueError):
            gmean([Inf, -Inf])

    def test_mixed_int_and_float(self):
        # Regression test for b.p.o. issue #28327
        gmean = statistics.geometric_mean
        reference = 3.80675409583932
        # The same four numbers with progressively more of them as floats.
        mixtures = [
            [2, 3, 5, 7],
            [2, 3, 5, 7.0],
            [2, 3, 5.0, 7.0],
            [2, 3.0, 5.0, 7.0],
            [2.0, 3.0, 5.0, 7.0],
        ]
        for v in mixtures:
            with self.subTest(v=v):
                self.assertAlmostEqual(gmean(v), reference, places=5)
2347
2348
class TestQuantiles(unittest.TestCase):
    # Tests for statistics.quantiles() using both the default method and
    # method='inclusive'.

    def test_specific_cases(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.EXC function in MS Excel.
        quantiles = statistics.quantiles
        data = [120, 200, 250, 320, 350]
        random.shuffle(data)

        def f(x):
            # Affine map used by the translation/scaling invariance check.
            return 3.5 * x - 1234.675

        for n, expected in [
            (1, []),
            (2, [250.0]),
            (3, [200.0, 320.0]),
            (4, [160.0, 250.0, 335.0]),
            (5, [136.0, 220.0, 292.0, 344.0]),
            (6, [120.0, 200.0, 250.0, 320.0, 350.0]),
            (8, [100.0, 160.0, 212.5, 250.0, 302.5, 335.0, 357.5]),
            (10, [88.0, 136.0, 184.0, 220.0, 250.0, 292.0, 326.0, 344.0, 362.0]),
            (12, [80.0, 120.0, 160.0, 200.0, 225.0, 250.0, 285.0, 320.0, 335.0,
                  350.0, 365.0]),
            (15, [72.0, 104.0, 136.0, 168.0, 200.0, 220.0, 240.0, 264.0, 292.0,
                  320.0, 332.0, 344.0, 356.0, 368.0]),
                ]:
            self.assertEqual(expected, quantiles(data, n=n))
            self.assertEqual(len(quantiles(data, n=n)), n - 1)
            # Preserve datatype when possible
            for datatype in (float, Decimal, Fraction):
                result = quantiles(map(datatype, data), n=n)
                # The generator expression has to sit inside all(): a bare
                # generator handed to assertTrue() is always truthy, which
                # would make this check vacuous.
                self.assertTrue(all(type(x) == datatype for x in result))
                self.assertEqual(result, list(map(datatype, expected)))
            # Quantiles should be idempotent
            if len(expected) >= 2:
                self.assertEqual(quantiles(expected, n=n), expected)
            # Cross-check against method='inclusive' which should give
            # the same result after adding in minimum and maximum values
            # extrapolated from the two lowest and two highest points.
            sdata = sorted(data)
            lo = 2 * sdata[0] - sdata[1]
            hi = 2 * sdata[-1] - sdata[-2]
            padded_data = data + [lo, hi]
            self.assertEqual(
                quantiles(data, n=n),
                quantiles(padded_data, n=n, method='inclusive'),
                (n, data),
            )
            # Invariant under translation and scaling
            exp = list(map(f, expected))
            act = quantiles(map(f, data), n=n)
            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
        # Q2 agrees with median()
        for k in range(2, 60):
            data = random.choices(range(100), k=k)
            q1, q2, q3 = quantiles(data)
            self.assertEqual(q2, statistics.median(data))

    def test_specific_cases_inclusive(self):
        # Match results computed by hand and cross-checked
        # against the PERCENTILE.INC function in MS Excel
        # and against the quantile() function in SciPy.
        quantiles = statistics.quantiles
        data = [100, 200, 400, 800]
        random.shuffle(data)

        def f(x):
            # Affine map used by the translation/scaling invariance check.
            return 3.5 * x - 1234.675

        for n, expected in [
            (1, []),
            (2, [300.0]),
            (3, [200.0, 400.0]),
            (4, [175.0, 300.0, 500.0]),
            (5, [160.0, 240.0, 360.0, 560.0]),
            (6, [150.0, 200.0, 300.0, 400.0, 600.0]),
            (8, [137.5, 175.0, 225.0, 300.0, 375.0, 500.0, 650.0]),
            (10, [130.0, 160.0, 190.0, 240.0, 300.0, 360.0, 440.0, 560.0, 680.0]),
            (12, [125.0, 150.0, 175.0, 200.0, 250.0, 300.0, 350.0, 400.0,
                  500.0, 600.0, 700.0]),
            (15, [120.0, 140.0, 160.0, 180.0, 200.0, 240.0, 280.0, 320.0, 360.0,
                  400.0, 480.0, 560.0, 640.0, 720.0]),
                ]:
            self.assertEqual(expected, quantiles(data, n=n, method="inclusive"))
            self.assertEqual(len(quantiles(data, n=n, method="inclusive")), n - 1)
            # Preserve datatype when possible
            for datatype in (float, Decimal, Fraction):
                result = quantiles(map(datatype, data), n=n, method="inclusive")
                # As above: keep the generator inside all() so the type
                # check is actually evaluated.
                self.assertTrue(all(type(x) == datatype for x in result))
                self.assertEqual(result, list(map(datatype, expected)))
            # Invariant under translation and scaling
            exp = list(map(f, expected))
            act = quantiles(map(f, data), n=n, method="inclusive")
            self.assertTrue(all(math.isclose(e, a) for e, a in zip(exp, act)))
        # Natural deciles
        self.assertEqual(quantiles([0, 100], n=10, method='inclusive'),
                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
        self.assertEqual(quantiles(range(0, 101), n=10, method='inclusive'),
                         [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0])
        # Whenever n is smaller than the number of data points, running
        # method='inclusive' should give the same result as method='exclusive'
        # after the two included extreme points are removed.
        data = [random.randrange(10_000) for i in range(501)]
        actual = quantiles(data, n=32, method='inclusive')
        data.remove(min(data))
        data.remove(max(data))
        expected = quantiles(data, n=32)
        self.assertEqual(expected, actual)
        # Q2 agrees with median()
        for k in range(2, 60):
            data = random.choices(range(100), k=k)
            q1, q2, q3 = quantiles(data, method='inclusive')
            self.assertEqual(q2, statistics.median(data))

    def test_equal_inputs(self):
        # When every data point is identical, so is every quartile.
        quantiles = statistics.quantiles
        for n in range(2, 10):
            data = [10.0] * n
            self.assertEqual(quantiles(data), [10.0, 10.0, 10.0])
            self.assertEqual(quantiles(data, method='inclusive'),
                            [10.0, 10.0, 10.0])

    def test_equal_sized_groups(self):
        quantiles = statistics.quantiles
        total = 10_000
        # Gather exactly ``total`` distinct values.  Topping up a list
        # after a collision would keep the duplicate and leave the list
        # longer than ``total``, which would throw off the positions
        # checked below.
        distinct = set()
        while len(distinct) < total:
            distinct.add(random.expovariate(0.2))
        data = sorted(distinct)

        # Cases where the group size exactly divides the total
        for n in (1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000):
            group_size = total // n
            self.assertEqual(
                [bisect.bisect(data, q) for q in quantiles(data, n=n)],
                list(range(group_size, total, group_size)))

        # When the group sizes can't be exactly equal, they should
        # differ by no more than one
        for n in (13, 19, 59, 109, 211, 571, 1019, 1907, 5261, 9769):
            group_sizes = {total // n, total // n + 1}
            pos = [bisect.bisect(data, q) for q in quantiles(data, n=n)]
            sizes = {q - p for p, q in zip(pos, pos[1:])}
            self.assertTrue(sizes <= group_sizes)

    def test_error_cases(self):
        quantiles = statistics.quantiles
        StatisticsError = statistics.StatisticsError
        with self.assertRaises(TypeError):
            quantiles()                         # Missing arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 13, n=4)    # Too many arguments
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], 4)          # n is a positional argument
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=0)        # n is zero
        with self.assertRaises(StatisticsError):
            quantiles([10, 20, 30], n=-1)       # n is negative
        with self.assertRaises(TypeError):
            quantiles([10, 20, 30], n=1.5)      # n is not an integer
        with self.assertRaises(ValueError):
            quantiles([10, 20, 30], method='X') # method is unknown
        with self.assertRaises(StatisticsError):
            quantiles([10], n=4)                # not enough data points
        with self.assertRaises(TypeError):
            quantiles([10, None, 30], n=4)      # data is non-numeric
2511
2512
2513class TestBivariateStatistics(unittest.TestCase):
2514
2515    def test_unequal_size_error(self):
2516        for x, y in [
2517            ([1, 2, 3], [1, 2]),
2518            ([1, 2], [1, 2, 3]),
2519        ]:
2520            with self.assertRaises(statistics.StatisticsError):
2521                statistics.covariance(x, y)
2522            with self.assertRaises(statistics.StatisticsError):
2523                statistics.correlation(x, y)
2524            with self.assertRaises(statistics.StatisticsError):
2525                statistics.linear_regression(x, y)
2526
2527    def test_small_sample_error(self):
2528        for x, y in [
2529            ([], []),
2530            ([], [1, 2,]),
2531            ([1, 2,], []),
2532            ([1,], [1,]),
2533            ([1,], [1, 2,]),
2534            ([1, 2,], [1,]),
2535        ]:
2536            with self.assertRaises(statistics.StatisticsError):
2537                statistics.covariance(x, y)
2538            with self.assertRaises(statistics.StatisticsError):
2539                statistics.correlation(x, y)
2540            with self.assertRaises(statistics.StatisticsError):
2541                statistics.linear_regression(x, y)
2542
2543
2544class TestCorrelationAndCovariance(unittest.TestCase):
2545
2546    def test_results(self):
2547        for x, y, result in [
2548            ([1, 2, 3], [1, 2, 3], 1),
2549            ([1, 2, 3], [-1, -2, -3], -1),
2550            ([1, 2, 3], [3, 2, 1], -1),
2551            ([1, 2, 3], [1, 2, 1], 0),
2552            ([1, 2, 3], [1, 3, 2], 0.5),
2553        ]:
2554            self.assertAlmostEqual(statistics.correlation(x, y), result)
2555            self.assertAlmostEqual(statistics.covariance(x, y), result)
2556
2557    def test_different_scales(self):
2558        x = [1, 2, 3]
2559        y = [10, 30, 20]
2560        self.assertAlmostEqual(statistics.correlation(x, y), 0.5)
2561        self.assertAlmostEqual(statistics.covariance(x, y), 5)
2562
2563        y = [.1, .2, .3]
2564        self.assertAlmostEqual(statistics.correlation(x, y), 1)
2565        self.assertAlmostEqual(statistics.covariance(x, y), 0.1)
2566
2567
2568class TestLinearRegression(unittest.TestCase):
2569
2570    def test_constant_input_error(self):
2571        x = [1, 1, 1,]
2572        y = [1, 2, 3,]
2573        with self.assertRaises(statistics.StatisticsError):
2574            statistics.linear_regression(x, y)
2575
2576    def test_results(self):
2577        for x, y, true_intercept, true_slope in [
2578            ([1, 2, 3], [0, 0, 0], 0, 0),
2579            ([1, 2, 3], [1, 2, 3], 0, 1),
2580            ([1, 2, 3], [100, 100, 100], 100, 0),
2581            ([1, 2, 3], [12, 14, 16], 10, 2),
2582            ([1, 2, 3], [-1, -2, -3], 0, -1),
2583            ([1, 2, 3], [21, 22, 23], 20, 1),
2584            ([1, 2, 3], [5.1, 5.2, 5.3], 5, 0.1),
2585        ]:
2586            slope, intercept = statistics.linear_regression(x, y)
2587            self.assertAlmostEqual(intercept, true_intercept)
2588            self.assertAlmostEqual(slope, true_slope)
2589
2590    def test_proportional(self):
2591        x = [10, 20, 30, 40]
2592        y = [180, 398, 610, 799]
2593        slope, intercept = statistics.linear_regression(x, y, proportional=True)
2594        self.assertAlmostEqual(slope, 20 + 1/150)
2595        self.assertEqual(intercept, 0.0)
2596
2597class TestNormalDist:
2598
2599    # General note on precision: The pdf(), cdf(), and overlap() methods
2600    # depend on functions in the math libraries that do not make
2601    # explicit accuracy guarantees.  Accordingly, some of the accuracy
2602    # tests below may fail if the underlying math functions are
2603    # inaccurate.  There isn't much we can do about this short of
2604    # implementing our own implementations from scratch.
2605
2606    def test_slots(self):
2607        nd = self.module.NormalDist(300, 23)
2608        with self.assertRaises(TypeError):
2609            vars(nd)
2610        self.assertEqual(tuple(nd.__slots__), ('_mu', '_sigma'))
2611
2612    def test_instantiation_and_attributes(self):
2613        nd = self.module.NormalDist(500, 17)
2614        self.assertEqual(nd.mean, 500)
2615        self.assertEqual(nd.stdev, 17)
2616        self.assertEqual(nd.variance, 17**2)
2617
2618        # default arguments
2619        nd = self.module.NormalDist()
2620        self.assertEqual(nd.mean, 0)
2621        self.assertEqual(nd.stdev, 1)
2622        self.assertEqual(nd.variance, 1**2)
2623
2624        # error case: negative sigma
2625        with self.assertRaises(self.module.StatisticsError):
2626            self.module.NormalDist(500, -10)
2627
2628        # verify that subclass type is honored
2629        class NewNormalDist(self.module.NormalDist):
2630            pass
2631        nnd = NewNormalDist(200, 5)
2632        self.assertEqual(type(nnd), NewNormalDist)
2633
2634    def test_alternative_constructor(self):
2635        NormalDist = self.module.NormalDist
2636        data = [96, 107, 90, 92, 110]
2637        # list input
2638        self.assertEqual(NormalDist.from_samples(data), NormalDist(99, 9))
2639        # tuple input
2640        self.assertEqual(NormalDist.from_samples(tuple(data)), NormalDist(99, 9))
2641        # iterator input
2642        self.assertEqual(NormalDist.from_samples(iter(data)), NormalDist(99, 9))
2643        # error cases
2644        with self.assertRaises(self.module.StatisticsError):
2645            NormalDist.from_samples([])                      # empty input
2646        with self.assertRaises(self.module.StatisticsError):
2647            NormalDist.from_samples([10])                    # only one input
2648
2649        # verify that subclass type is honored
2650        class NewNormalDist(NormalDist):
2651            pass
2652        nnd = NewNormalDist.from_samples(data)
2653        self.assertEqual(type(nnd), NewNormalDist)
2654
2655    def test_sample_generation(self):
2656        NormalDist = self.module.NormalDist
2657        mu, sigma = 10_000, 3.0
2658        X = NormalDist(mu, sigma)
2659        n = 1_000
2660        data = X.samples(n)
2661        self.assertEqual(len(data), n)
2662        self.assertEqual(set(map(type, data)), {float})
2663        # mean(data) expected to fall within 8 standard deviations
2664        xbar = self.module.mean(data)
2665        self.assertTrue(mu - sigma*8 <= xbar <= mu + sigma*8)
2666
2667        # verify that seeding makes reproducible sequences
2668        n = 100
2669        data1 = X.samples(n, seed='happiness and joy')
2670        data2 = X.samples(n, seed='trouble and despair')
2671        data3 = X.samples(n, seed='happiness and joy')
2672        data4 = X.samples(n, seed='trouble and despair')
2673        self.assertEqual(data1, data3)
2674        self.assertEqual(data2, data4)
2675        self.assertNotEqual(data1, data2)
2676
2677    def test_pdf(self):
2678        NormalDist = self.module.NormalDist
2679        X = NormalDist(100, 15)
2680        # Verify peak around center
2681        self.assertLess(X.pdf(99), X.pdf(100))
2682        self.assertLess(X.pdf(101), X.pdf(100))
2683        # Test symmetry
2684        for i in range(50):
2685            self.assertAlmostEqual(X.pdf(100 - i), X.pdf(100 + i))
2686        # Test vs CDF
2687        dx = 2.0 ** -10
2688        for x in range(90, 111):
2689            est_pdf = (X.cdf(x + dx) - X.cdf(x)) / dx
2690            self.assertAlmostEqual(X.pdf(x), est_pdf, places=4)
2691        # Test vs table of known values -- CRC 26th Edition
2692        Z = NormalDist()
2693        for x, px in enumerate([
2694            0.3989, 0.3989, 0.3989, 0.3988, 0.3986,
2695            0.3984, 0.3982, 0.3980, 0.3977, 0.3973,
2696            0.3970, 0.3965, 0.3961, 0.3956, 0.3951,
2697            0.3945, 0.3939, 0.3932, 0.3925, 0.3918,
2698            0.3910, 0.3902, 0.3894, 0.3885, 0.3876,
2699            0.3867, 0.3857, 0.3847, 0.3836, 0.3825,
2700            0.3814, 0.3802, 0.3790, 0.3778, 0.3765,
2701            0.3752, 0.3739, 0.3725, 0.3712, 0.3697,
2702            0.3683, 0.3668, 0.3653, 0.3637, 0.3621,
2703            0.3605, 0.3589, 0.3572, 0.3555, 0.3538,
2704        ]):
2705            self.assertAlmostEqual(Z.pdf(x / 100.0), px, places=4)
2706            self.assertAlmostEqual(Z.pdf(-x / 100.0), px, places=4)
2707        # Error case: variance is zero
2708        Y = NormalDist(100, 0)
2709        with self.assertRaises(self.module.StatisticsError):
2710            Y.pdf(90)
2711        # Special values
2712        self.assertEqual(X.pdf(float('-Inf')), 0.0)
2713        self.assertEqual(X.pdf(float('Inf')), 0.0)
2714        self.assertTrue(math.isnan(X.pdf(float('NaN'))))
2715
2716    def test_cdf(self):
2717        NormalDist = self.module.NormalDist
2718        X = NormalDist(100, 15)
2719        cdfs = [X.cdf(x) for x in range(1, 200)]
2720        self.assertEqual(set(map(type, cdfs)), {float})
2721        # Verify montonic
2722        self.assertEqual(cdfs, sorted(cdfs))
2723        # Verify center (should be exact)
2724        self.assertEqual(X.cdf(100), 0.50)
2725        # Check against a table of known values
2726        # https://en.wikipedia.org/wiki/Standard_normal_table#Cumulative
2727        Z = NormalDist()
2728        for z, cum_prob in [
2729            (0.00, 0.50000), (0.01, 0.50399), (0.02, 0.50798),
2730            (0.14, 0.55567), (0.29, 0.61409), (0.33, 0.62930),
2731            (0.54, 0.70540), (0.60, 0.72575), (1.17, 0.87900),
2732            (1.60, 0.94520), (2.05, 0.97982), (2.89, 0.99807),
2733            (3.52, 0.99978), (3.98, 0.99997), (4.07, 0.99998),
2734            ]:
2735            self.assertAlmostEqual(Z.cdf(z), cum_prob, places=5)
2736            self.assertAlmostEqual(Z.cdf(-z), 1.0 - cum_prob, places=5)
2737        # Error case: variance is zero
2738        Y = NormalDist(100, 0)
2739        with self.assertRaises(self.module.StatisticsError):
2740            Y.cdf(90)
2741        # Special values
2742        self.assertEqual(X.cdf(float('-Inf')), 0.0)
2743        self.assertEqual(X.cdf(float('Inf')), 1.0)
2744        self.assertTrue(math.isnan(X.cdf(float('NaN'))))
2745
2746    @support.skip_if_pgo_task
2747    def test_inv_cdf(self):
2748        NormalDist = self.module.NormalDist
2749
2750        # Center case should be exact.
2751        iq = NormalDist(100, 15)
2752        self.assertEqual(iq.inv_cdf(0.50), iq.mean)
2753
2754        # Test versus a published table of known percentage points.
2755        # See the second table at the bottom of the page here:
2756        # http://people.bath.ac.uk/masss/tables/normaltable.pdf
2757        Z = NormalDist()
2758        pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891,
2759                    4.417, 4.892, 5.327, 5.731, 6.109),
2760              2.5: (0.674, 1.960, 2.807, 3.481, 4.056,
2761                    4.565, 5.026, 5.451, 5.847, 6.219),
2762              1.0: (1.282, 2.326, 3.090, 3.719, 4.265,
2763                    4.753, 5.199, 5.612, 5.998, 6.361)}
2764        for base, row in pp.items():
2765            for exp, x in enumerate(row, start=1):
2766                p = base * 10.0 ** (-exp)
2767                self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3)
2768                p = 1.0 - p
2769                self.assertAlmostEqual(Z.inv_cdf(p), x, places=3)
2770
2771        # Match published example for MS Excel
2772        # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13
2773        self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002)
2774
2775        # One million equally spaced probabilities
2776        n = 2**20
2777        for p in range(1, n):
2778            p /= n
2779            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2780
2781        # One hundred ever smaller probabilities to test tails out to
2782        # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50
2783        for e in range(1, 51):
2784            p = 2.0 ** (-e)
2785            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2786            p = 1.0 - p
2787            self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
2788
2789        # Now apply cdf() first.  Near the tails, the round-trip loses
2790        # precision and is ill-conditioned (small changes in the inputs
2791        # give large changes in the output), so only check to 5 places.
2792        for x in range(200):
2793            self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=5)
2794
2795        # Error cases:
2796        with self.assertRaises(self.module.StatisticsError):
2797            iq.inv_cdf(0.0)                         # p is zero
2798        with self.assertRaises(self.module.StatisticsError):
2799            iq.inv_cdf(-0.1)                        # p under zero
2800        with self.assertRaises(self.module.StatisticsError):
2801            iq.inv_cdf(1.0)                         # p is one
2802        with self.assertRaises(self.module.StatisticsError):
2803            iq.inv_cdf(1.1)                         # p over one
2804        with self.assertRaises(self.module.StatisticsError):
2805            iq = NormalDist(100, 0)                 # sigma is zero
2806            iq.inv_cdf(0.5)
2807
2808        # Special values
2809        self.assertTrue(math.isnan(Z.inv_cdf(float('NaN'))))
2810
2811    def test_quantiles(self):
2812        # Quartiles of a standard normal distribution
2813        Z = self.module.NormalDist()
2814        for n, expected in [
2815            (1, []),
2816            (2, [0.0]),
2817            (3, [-0.4307, 0.4307]),
2818            (4 ,[-0.6745, 0.0, 0.6745]),
2819                ]:
2820            actual = Z.quantiles(n=n)
2821            self.assertTrue(all(math.isclose(e, a, abs_tol=0.0001)
2822                            for e, a in zip(expected, actual)))
2823
2824    def test_overlap(self):
2825        NormalDist = self.module.NormalDist
2826
2827        # Match examples from Imman and Bradley
2828        for X1, X2, published_result in [
2829                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0), 0.80258),
2830                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0), 0.60993),
2831            ]:
2832            self.assertAlmostEqual(X1.overlap(X2), published_result, places=4)
2833            self.assertAlmostEqual(X2.overlap(X1), published_result, places=4)
2834
2835        # Check against integration of the PDF
2836        def overlap_numeric(X, Y, *, steps=8_192, z=5):
2837            'Numerical integration cross-check for overlap() '
2838            fsum = math.fsum
2839            center = (X.mean + Y.mean) / 2.0
2840            width = z * max(X.stdev, Y.stdev)
2841            start = center - width
2842            dx = 2.0 * width / steps
2843            x_arr = [start + i*dx for i in range(steps)]
2844            xp = list(map(X.pdf, x_arr))
2845            yp = list(map(Y.pdf, x_arr))
2846            total = max(fsum(xp), fsum(yp))
2847            return fsum(map(min, xp, yp)) / total
2848
2849        for X1, X2 in [
2850                # Examples from Imman and Bradley
2851                (NormalDist(0.0, 2.0), NormalDist(1.0, 2.0)),
2852                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2853                # Example from https://www.rasch.org/rmt/rmt101r.htm
2854                (NormalDist(0.0, 1.0), NormalDist(1.0, 2.0)),
2855                # Gender heights from http://www.usablestats.com/lessons/normal
2856                (NormalDist(70, 4), NormalDist(65, 3.5)),
2857                # Misc cases with equal standard deviations
2858                (NormalDist(100, 15), NormalDist(110, 15)),
2859                (NormalDist(-100, 15), NormalDist(110, 15)),
2860                (NormalDist(-100, 15), NormalDist(-110, 15)),
2861                # Misc cases with unequal standard deviations
2862                (NormalDist(100, 12), NormalDist(100, 15)),
2863                (NormalDist(100, 12), NormalDist(110, 15)),
2864                (NormalDist(100, 12), NormalDist(150, 15)),
2865                (NormalDist(100, 12), NormalDist(150, 35)),
2866                # Misc cases with small values
2867                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.003)),
2868                (NormalDist(1.000, 0.002), NormalDist(1.006, 0.0003)),
2869                (NormalDist(1.000, 0.002), NormalDist(1.001, 0.099)),
2870            ]:
2871            self.assertAlmostEqual(X1.overlap(X2), overlap_numeric(X1, X2), places=5)
2872            self.assertAlmostEqual(X2.overlap(X1), overlap_numeric(X1, X2), places=5)
2873
2874        # Error cases
2875        X = NormalDist()
2876        with self.assertRaises(TypeError):
2877            X.overlap()                             # too few arguments
2878        with self.assertRaises(TypeError):
2879            X.overlap(X, X)                         # too may arguments
2880        with self.assertRaises(TypeError):
2881            X.overlap(None)                         # right operand not a NormalDist
2882        with self.assertRaises(self.module.StatisticsError):
2883            X.overlap(NormalDist(1, 0))             # right operand sigma is zero
2884        with self.assertRaises(self.module.StatisticsError):
2885            NormalDist(1, 0).overlap(X)             # left operand sigma is zero
2886
2887    def test_zscore(self):
2888        NormalDist = self.module.NormalDist
2889        X = NormalDist(100, 15)
2890        self.assertEqual(X.zscore(142), 2.8)
2891        self.assertEqual(X.zscore(58), -2.8)
2892        self.assertEqual(X.zscore(100), 0.0)
2893        with self.assertRaises(TypeError):
2894            X.zscore()                              # too few arguments
2895        with self.assertRaises(TypeError):
2896            X.zscore(1, 1)                          # too may arguments
2897        with self.assertRaises(TypeError):
2898            X.zscore(None)                          # non-numeric type
2899        with self.assertRaises(self.module.StatisticsError):
2900            NormalDist(1, 0).zscore(100)            # sigma is zero
2901
2902    def test_properties(self):
2903        X = self.module.NormalDist(100, 15)
2904        self.assertEqual(X.mean, 100)
2905        self.assertEqual(X.median, 100)
2906        self.assertEqual(X.mode, 100)
2907        self.assertEqual(X.stdev, 15)
2908        self.assertEqual(X.variance, 225)
2909
2910    def test_same_type_addition_and_subtraction(self):
2911        NormalDist = self.module.NormalDist
2912        X = NormalDist(100, 12)
2913        Y = NormalDist(40, 5)
2914        self.assertEqual(X + Y, NormalDist(140, 13))        # __add__
2915        self.assertEqual(X - Y, NormalDist(60, 13))         # __sub__
2916
2917    def test_translation_and_scaling(self):
2918        NormalDist = self.module.NormalDist
2919        X = NormalDist(100, 15)
2920        y = 10
2921        self.assertEqual(+X, NormalDist(100, 15))           # __pos__
2922        self.assertEqual(-X, NormalDist(-100, 15))          # __neg__
2923        self.assertEqual(X + y, NormalDist(110, 15))        # __add__
2924        self.assertEqual(y + X, NormalDist(110, 15))        # __radd__
2925        self.assertEqual(X - y, NormalDist(90, 15))         # __sub__
2926        self.assertEqual(y - X, NormalDist(-90, 15))        # __rsub__
2927        self.assertEqual(X * y, NormalDist(1000, 150))      # __mul__
2928        self.assertEqual(y * X, NormalDist(1000, 150))      # __rmul__
2929        self.assertEqual(X / y, NormalDist(10, 1.5))        # __truediv__
2930        with self.assertRaises(TypeError):                  # __rtruediv__
2931            y / X
2932
2933    def test_unary_operations(self):
2934        NormalDist = self.module.NormalDist
2935        X = NormalDist(100, 12)
2936        Y = +X
2937        self.assertIsNot(X, Y)
2938        self.assertEqual(X.mean, Y.mean)
2939        self.assertEqual(X.stdev, Y.stdev)
2940        Y = -X
2941        self.assertIsNot(X, Y)
2942        self.assertEqual(X.mean, -Y.mean)
2943        self.assertEqual(X.stdev, Y.stdev)
2944
2945    def test_equality(self):
2946        NormalDist = self.module.NormalDist
2947        nd1 = NormalDist()
2948        nd2 = NormalDist(2, 4)
2949        nd3 = NormalDist()
2950        nd4 = NormalDist(2, 4)
2951        nd5 = NormalDist(2, 8)
2952        nd6 = NormalDist(8, 4)
2953        self.assertNotEqual(nd1, nd2)
2954        self.assertEqual(nd1, nd3)
2955        self.assertEqual(nd2, nd4)
2956        self.assertNotEqual(nd2, nd5)
2957        self.assertNotEqual(nd2, nd6)
2958
2959        # Test NotImplemented when types are different
2960        class A:
2961            def __eq__(self, other):
2962                return 10
2963        a = A()
2964        self.assertEqual(nd1.__eq__(a), NotImplemented)
2965        self.assertEqual(nd1 == a, 10)
2966        self.assertEqual(a == nd1, 10)
2967
2968        # All subclasses to compare equal giving the same behavior
2969        # as list, tuple, int, float, complex, str, dict, set, etc.
2970        class SizedNormalDist(NormalDist):
2971            def __init__(self, mu, sigma, n):
2972                super().__init__(mu, sigma)
2973                self.n = n
2974        s = SizedNormalDist(100, 15, 57)
2975        nd4 = NormalDist(100, 15)
2976        self.assertEqual(s, nd4)
2977
2978        # Don't allow duck type equality because we wouldn't
2979        # want a lognormal distribution to compare equal
2980        # to a normal distribution with the same parameters
2981        class LognormalDist:
2982            def __init__(self, mu, sigma):
2983                self.mu = mu
2984                self.sigma = sigma
2985        lnd = LognormalDist(100, 15)
2986        nd = NormalDist(100, 15)
2987        self.assertNotEqual(nd, lnd)
2988
2989    def test_copy(self):
2990        nd = self.module.NormalDist(37.5, 5.625)
2991        nd1 = copy.copy(nd)
2992        self.assertEqual(nd, nd1)
2993        nd2 = copy.deepcopy(nd)
2994        self.assertEqual(nd, nd2)
2995
2996    def test_pickle(self):
2997        nd = self.module.NormalDist(37.5, 5.625)
2998        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2999            with self.subTest(proto=proto):
3000                pickled = pickle.loads(pickle.dumps(nd, protocol=proto))
3001                self.assertEqual(nd, pickled)
3002
3003    def test_hashability(self):
3004        ND = self.module.NormalDist
3005        s = {ND(100, 15), ND(100.0, 15.0), ND(100, 10), ND(95, 15), ND(100, 15)}
3006        self.assertEqual(len(s), 3)
3007
3008    def test_repr(self):
3009        nd = self.module.NormalDist(37.5, 5.625)
3010        self.assertEqual(repr(nd), 'NormalDist(mu=37.5, sigma=5.625)')
3011
# Swapping sys.modules['statistics'] in setUp()/tearDown() avoids this error:
#   _pickle.PicklingError:
#   Can't pickle <class 'statistics.NormalDist'>:
#   it's not the same object as statistics.NormalDist
class TestNormalDistPython(unittest.TestCase, TestNormalDist):
    """Run the TestNormalDist checks with ``py_statistics`` as the module.

    For the duration of each test the module under test is registered as
    sys.modules['statistics'], so that pickling its NormalDist class does
    not fail with "it's not the same object as statistics.NormalDist".
    """

    module = py_statistics

    def tearDown(self):
        # Put the normally imported statistics module back.
        sys.modules['statistics'] = statistics

    def setUp(self):
        # Register the module under test under the name 'statistics'.
        sys.modules['statistics'] = self.module
3023
3024
@unittest.skipUnless(c_statistics, 'requires _statistics')
class TestNormalDistC(unittest.TestCase, TestNormalDist):
    """Run the TestNormalDist checks with ``c_statistics`` as the module.

    The whole class is skipped when ``c_statistics`` is falsy.  For the
    duration of each test the module under test is registered as
    sys.modules['statistics'], so that pickling its NormalDist class does
    not fail with "it's not the same object as statistics.NormalDist".
    """

    module = c_statistics

    def tearDown(self):
        # Put the normally imported statistics module back.
        sys.modules['statistics'] = statistics

    def setUp(self):
        # Register the module under test under the name 'statistics'.
        sys.modules['statistics'] = self.module
3033
3034
3035# === Run tests ===
3036
def load_tests(loader, tests, ignore):
    """Add doctests to the suite unittest has already collected.

    This is unittest's ``load_tests`` hook.  *loader* and *ignore* are
    accepted to satisfy the hook signature but are not used; *tests* is
    extended in place and returned.
    """
    # DocTestSuite() is called without a module argument on purpose, so
    # doctest looks up the doctests of the module making this call.
    doctest_suite = doctest.DocTestSuite()
    tests.addTests(doctest_suite)
    return tests
3041
3042
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
3045