summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2019-03-19 03:17:14 (GMT)
committerGitHub <noreply@github.com>2019-03-19 03:17:14 (GMT)
commit714c60d7aca6d0f6d73ad2d7c876d2d683a7fce3 (patch)
tree7f807f36c486ef3bd660e829693c524757567c77 /Lib
parentfaddaedd05ca81a9fed3f315e7bc8dcf455824a2 (diff)
downloadcpython-714c60d7aca6d0f6d73ad2d7c876d2d683a7fce3.zip
cpython-714c60d7aca6d0f6d73ad2d7c876d2d683a7fce3.tar.gz
cpython-714c60d7aca6d0f6d73ad2d7c876d2d683a7fce3.tar.bz2
bpo-36324: Add inv_cdf() to statistics.NormalDist() (GH-12377)
Diffstat (limited to 'Lib')
-rw-r--r--Lib/statistics.py95
-rw-r--r--Lib/test/test_statistics.py63
2 files changed, 158 insertions, 0 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py
index 8d79eed..fe68e58 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -745,6 +745,101 @@ class NormalDist:
raise StatisticsError('cdf() not defined when sigma is zero')
return 0.5 * (1.0 + erf((x - self.mu) / (self.sigma * sqrt(2.0))))
+ def inv_cdf(self, p):
+ ''' Inverse cumulative distribution function: x : P(X <= x) = p
+
+ Finds the value of the random variable such that the probability of the
+ variable being less than or equal to that value equals the given probability.
+
+ This function is also called the percent-point function or quantile function.
+
+ '''
+ if (p <= 0.0 or p >= 1.0):
+ raise StatisticsError('p must be in the range 0.0 < p < 1.0')
+ if self.sigma <= 0.0:
+ raise StatisticsError('cdf() not defined when sigma at or below zero')
+
+ # There is no closed-form solution to the inverse CDF for the normal
+ # distribution, so we use a rational approximation instead:
+ # Wichura, M.J. (1988). "Algorithm AS241: The Percentage Points of the
+ # Normal Distribution". Applied Statistics. Blackwell Publishing. 37
+ # (3): 477–484. doi:10.2307/2347330. JSTOR 2347330.
+
+ q = p - 0.5
+ if fabs(q) <= 0.425:
+ a0 = 3.38713_28727_96366_6080e+0
+ a1 = 1.33141_66789_17843_7745e+2
+ a2 = 1.97159_09503_06551_4427e+3
+ a3 = 1.37316_93765_50946_1125e+4
+ a4 = 4.59219_53931_54987_1457e+4
+ a5 = 6.72657_70927_00870_0853e+4
+ a6 = 3.34305_75583_58812_8105e+4
+ a7 = 2.50908_09287_30122_6727e+3
+ b1 = 4.23133_30701_60091_1252e+1
+ b2 = 6.87187_00749_20579_0830e+2
+ b3 = 5.39419_60214_24751_1077e+3
+ b4 = 2.12137_94301_58659_5867e+4
+ b5 = 3.93078_95800_09271_0610e+4
+ b6 = 2.87290_85735_72194_2674e+4
+ b7 = 5.22649_52788_52854_5610e+3
+ r = 0.180625 - q * q
+ num = (q * (((((((a7 * r + a6) * r + a5) * r + a4) * r + a3)
+ * r + a2) * r + a1) * r + a0))
+ den = ((((((((b7 * r + b6) * r + b5) * r + b4) * r + b3)
+ * r + b2) * r + b1) * r + 1.0))
+ x = num / den
+ return self.mu + (x * self.sigma)
+
+ r = p if q <= 0.0 else 1.0 - p
+ r = sqrt(-log(r))
+ if r <= 5.0:
+ c0 = 1.42343_71107_49683_57734e+0
+ c1 = 4.63033_78461_56545_29590e+0
+ c2 = 5.76949_72214_60691_40550e+0
+ c3 = 3.64784_83247_63204_60504e+0
+ c4 = 1.27045_82524_52368_38258e+0
+ c5 = 2.41780_72517_74506_11770e-1
+ c6 = 2.27238_44989_26918_45833e-2
+ c7 = 7.74545_01427_83414_07640e-4
+ d1 = 2.05319_16266_37758_82187e+0
+ d2 = 1.67638_48301_83803_84940e+0
+ d3 = 6.89767_33498_51000_04550e-1
+ d4 = 1.48103_97642_74800_74590e-1
+ d5 = 1.51986_66563_61645_71966e-2
+ d6 = 5.47593_80849_95344_94600e-4
+ d7 = 1.05075_00716_44416_84324e-9
+ r = r - 1.6
+ num = ((((((((c7 * r + c6) * r + c5) * r + c4) * r + c3)
+ * r + c2) * r + c1) * r + c0))
+ den = ((((((((d7 * r + d6) * r + d5) * r + d4) * r + d3)
+ * r + d2) * r + d1) * r + 1.0))
+ else:
+ e0 = 6.65790_46435_01103_77720e+0
+ e1 = 5.46378_49111_64114_36990e+0
+ e2 = 1.78482_65399_17291_33580e+0
+ e3 = 2.96560_57182_85048_91230e-1
+ e4 = 2.65321_89526_57612_30930e-2
+ e5 = 1.24266_09473_88078_43860e-3
+ e6 = 2.71155_55687_43487_57815e-5
+ e7 = 2.01033_43992_92288_13265e-7
+ f1 = 5.99832_20655_58879_37690e-1
+ f2 = 1.36929_88092_27358_05310e-1
+ f3 = 1.48753_61290_85061_48525e-2
+ f4 = 7.86869_13114_56132_59100e-4
+ f5 = 1.84631_83175_10054_68180e-5
+ f6 = 1.42151_17583_16445_88870e-7
+ f7 = 2.04426_31033_89939_78564e-15
+ r = r - 5.0
+ num = ((((((((e7 * r + e6) * r + e5) * r + e4) * r + e3)
+ * r + e2) * r + e1) * r + e0))
+ den = ((((((((f7 * r + f6) * r + f5) * r + f4) * r + f3)
+ * r + f2) * r + f1) * r + 1.0))
+
+ x = num / den
+ if q < 0.0:
+ x = -x
+ return self.mu + (x * self.sigma)
+
def overlap(self, other):
'''Compute the overlapping coefficient (OVL) between two normal distributions.
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index 26b22a1..02cbebd 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -2174,6 +2174,69 @@ class TestNormalDist(unittest.TestCase):
self.assertEqual(X.cdf(float('Inf')), 1.0)
self.assertTrue(math.isnan(X.cdf(float('NaN'))))
+ def test_inv_cdf(self):
+ NormalDist = statistics.NormalDist
+
+ # Center case should be exact.
+ iq = NormalDist(100, 15)
+ self.assertEqual(iq.inv_cdf(0.50), iq.mean)
+
+ # Test versus a published table of known percentage points.
+ # See the second table at the bottom of the page here:
+ # http://people.bath.ac.uk/masss/tables/normaltable.pdf
+ Z = NormalDist()
+ pp = {5.0: (0.000, 1.645, 2.576, 3.291, 3.891,
+ 4.417, 4.892, 5.327, 5.731, 6.109),
+ 2.5: (0.674, 1.960, 2.807, 3.481, 4.056,
+ 4.565, 5.026, 5.451, 5.847, 6.219),
+ 1.0: (1.282, 2.326, 3.090, 3.719, 4.265,
+ 4.753, 5.199, 5.612, 5.998, 6.361)}
+ for base, row in pp.items():
+ for exp, x in enumerate(row, start=1):
+ p = base * 10.0 ** (-exp)
+ self.assertAlmostEqual(-Z.inv_cdf(p), x, places=3)
+ p = 1.0 - p
+ self.assertAlmostEqual(Z.inv_cdf(p), x, places=3)
+
+ # Match published example for MS Excel
+ # https://support.office.com/en-us/article/norm-inv-function-54b30935-fee7-493c-bedb-2278a9db7e13
+ self.assertAlmostEqual(NormalDist(40, 1.5).inv_cdf(0.908789), 42.000002)
+
+ # One million equally spaced probabilities
+ n = 2**20
+ for p in range(1, n):
+ p /= n
+ self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
+
+ # One hundred ever smaller probabilities to test tails out to
+ # extreme probabilities: 1 / 2**50 and (2**50-1) / 2 ** 50
+ for e in range(1, 51):
+ p = 2.0 ** (-e)
+ self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
+ p = 1.0 - p
+ self.assertAlmostEqual(iq.cdf(iq.inv_cdf(p)), p)
+
+ # Now apply cdf() first. At six sigmas, the round-trip
+ # loses a lot of precision, so only check to 6 places.
+ for x in range(10, 190):
+ self.assertAlmostEqual(iq.inv_cdf(iq.cdf(x)), x, places=6)
+
+ # Error cases:
+ with self.assertRaises(statistics.StatisticsError):
+ iq.inv_cdf(0.0) # p is zero
+ with self.assertRaises(statistics.StatisticsError):
+ iq.inv_cdf(-0.1) # p under zero
+ with self.assertRaises(statistics.StatisticsError):
+ iq.inv_cdf(1.0) # p is one
+ with self.assertRaises(statistics.StatisticsError):
+ iq.inv_cdf(1.1) # p over one
+ with self.assertRaises(statistics.StatisticsError):
+ iq.sigma = 0.0 # sigma is zero
+ iq.inv_cdf(0.5)
+ with self.assertRaises(statistics.StatisticsError):
+ iq.sigma = -0.1 # sigma under zero
+ iq.inv_cdf(0.5)
+
def test_overlap(self):
NormalDist = statistics.NormalDist