diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2022-05-05 08:01:07 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-05-05 08:01:07 (GMT) |
commit | 5212cbc2618bd4390c4b768f1c65c28fa6b595a0 (patch) | |
tree | c044284a9d45dd062611afb83f1636e61250d4fd | |
parent | b885b8f4be9c74ef1ce7923dbf055c31e7f47735 (diff) | |
download | cpython-5212cbc2618bd4390c4b768f1c65c28fa6b595a0.zip cpython-5212cbc2618bd4390c4b768f1c65c28fa6b595a0.tar.gz cpython-5212cbc2618bd4390c4b768f1c65c28fa6b595a0.tar.bz2 |
Clean-up and simplify median_grouped(). Vastly improve its docstring. (#92324)
-rw-r--r-- | Lib/statistics.py | 106 | ||||
-rw-r--r-- | Lib/test/test_statistics.py | 44 |
2 files changed, 54 insertions, 96 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py index c022088..54f4e13 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -348,22 +348,6 @@ def _convert(value, T): raise -def _find_lteq(a, x): - 'Locate the leftmost value exactly equal to x' - i = bisect_left(a, x) - if i != len(a) and a[i] == x: - return i - raise ValueError - - -def _find_rteq(a, l, x): - 'Locate the rightmost value exactly equal to x' - i = bisect_right(a, x, lo=l) - if i != (len(a) + 1) and a[i - 1] == x: - return i - 1 - raise ValueError - - def _fail_neg(values, errmsg='negative value'): """Iterate over values, failing if any are less than zero.""" for x in values: @@ -628,30 +612,44 @@ def median_high(data): def median_grouped(data, interval=1): - """Return the 50th percentile (median) of grouped continuous data. - - >>> median_grouped([1, 2, 2, 3, 4, 4, 4, 4, 4, 5]) - 3.7 - >>> median_grouped([52, 52, 53, 54]) - 52.5 - - This calculates the median as the 50th percentile, and should be - used when your data is continuous and grouped. In the above example, - the values 1, 2, 3, etc. actually represent the midpoint of classes - 0.5-1.5, 1.5-2.5, 2.5-3.5, etc. The middle value falls somewhere in - class 3.5-4.5, and interpolation is used to estimate it. - - Optional argument ``interval`` represents the class interval, and - defaults to 1. Changing the class interval naturally will change the - interpolated 50th percentile value: - - >>> median_grouped([1, 3, 3, 5, 7], interval=1) - 3.25 - >>> median_grouped([1, 3, 3, 5, 7], interval=2) - 3.5 - - This function does not check whether the data points are at least - ``interval`` apart. + """Estimates the median for numeric data binned around the midpoints + of consecutive, fixed-width intervals. + + The *data* can be any iterable of numeric data with each value being + exactly the midpoint of a bin. At least one value must be present. + + The *interval* is width of each bin. + + For example, demographic information may have been summarized into + consecutive ten-year age groups with each group being represented + by the 5-year midpoints of the intervals: + + >>> demographics = Counter({ + ... 25: 172, # 20 to 30 years old + ... 35: 484, # 30 to 40 years old + ... 45: 387, # 40 to 50 years old + ... 55: 22, # 50 to 60 years old + ... 65: 6, # 60 to 70 years old + ... }) + + The 50th percentile (median) is the 536th person out of the 1071 + member cohort. That person is in the 30 to 40 year old age group. + + The regular median() function would assume that everyone in the + tricenarian age group was exactly 35 years old. A more tenable + assumption is that the 484 members of that age group are evenly + distributed between 30 and 40. For that, we use median_grouped(). + + >>> data = list(demographics.elements()) + >>> median(data) + 35 + >>> round(median_grouped(data, interval=10), 1) + 37.5 + + The caller is responsible for making sure the data points are separated + by exact multiples of *interval*. This is essential for getting a + correct result. The function does not check this precondition. + """ data = sorted(data) n = len(data) @@ -659,26 +657,30 @@ def median_grouped(data, interval=1): raise StatisticsError("no median for empty data") elif n == 1: return data[0] + # Find the value at the midpoint. Remember this corresponds to the - # centre of the class interval. + # midpoint of the class interval. x = data[n // 2] + + # Generate a clear error message for non-numeric data for obj in (x, interval): if isinstance(obj, (str, bytes)): - raise TypeError('expected number but got %r' % obj) + raise TypeError(f'expected a number but got {obj!r}') + + # Using O(log n) bisection, find where all the x values occur in the data. + # All x will lie within data[i:j]. + i = bisect_left(data, x) + j = bisect_right(data, x, lo=i) + + # Interpolate the median using the formula found at: + # https://www.cuemath.com/data/median-of-grouped-data/ try: L = x - interval / 2 # The lower limit of the median interval. except TypeError: - # Mixed type. For now we just coerce to float. + # Coerce mixed types to float. L = float(x) - float(interval) / 2 - - # Uses bisection search to search for x in data with log(n) time complexity - # Find the position of leftmost occurrence of x in data - l1 = _find_lteq(data, x) - # Find the position of rightmost occurrence of x in data[l1...len(data)] - # Assuming always l1 <= l2 - l2 = _find_rteq(data, l1, x) - cf = l1 - f = l2 - l1 + 1 + cf = i # Cumulative frequency of the preceding interval + f = j - i # Number of elements in the median internal return L + interval * (n / 2 - cf) / f diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index bacb76a..ed6021d 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -1040,50 +1040,6 @@ class FailNegTest(unittest.TestCase): self.assertEqual(errmsg, msg) -class FindLteqTest(unittest.TestCase): - # Test _find_lteq private function. - - def test_invalid_input_values(self): - for a, x in [ - ([], 1), - ([1, 2], 3), - ([1, 3], 2) - ]: - with self.subTest(a=a, x=x): - with self.assertRaises(ValueError): - statistics._find_lteq(a, x) - - def test_locate_successfully(self): - for a, x, expected_i in [ - ([1, 1, 1, 2, 3], 1, 0), - ([0, 1, 1, 1, 2, 3], 1, 1), - ([1, 2, 3, 3, 3], 3, 2) - ]: - with self.subTest(a=a, x=x): - self.assertEqual(expected_i, statistics._find_lteq(a, x)) - - -class FindRteqTest(unittest.TestCase): - # Test _find_rteq private function. - - def test_invalid_input_values(self): - for a, l, x in [ - ([1], 2, 1), - ([1, 3], 0, 2) - ]: - with self.assertRaises(ValueError): - statistics._find_rteq(a, l, x) - - def test_locate_successfully(self): - for a, l, x, expected_i in [ - ([1, 1, 1, 2, 3], 0, 1, 2), - ([0, 1, 1, 1, 2, 3], 0, 1, 3), - ([1, 2, 3, 3, 3], 0, 3, 4) - ]: - with self.subTest(a=a, l=l, x=x): - self.assertEqual(expected_i, statistics._find_rteq(a, l, x)) - - # === Tests for public functions === class UnivariateCommonMixin: |