summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMiss Islington (bot) <31488909+miss-islington@users.noreply.github.com>2023-01-08 20:04:49 (GMT)
committerGitHub <noreply@github.com>2023-01-08 20:04:49 (GMT)
commit6184b800ee403d8c8c93b94fed1dbc3e0bade336 (patch)
tree12c34bd7c48b347eb38143c3cb39bb2de60e1e1d
parent6c7e32f6a82fac22d52b4b27330286bd8982cc75 (diff)
downloadcpython-6184b800ee403d8c8c93b94fed1dbc3e0bade336.zip
cpython-6184b800ee403d8c8c93b94fed1dbc3e0bade336.tar.gz
cpython-6184b800ee403d8c8c93b94fed1dbc3e0bade336.tar.bz2
GH-100805: Support numpy.array() in random.choice(). (GH-100830)
(cherry picked from commit 9a68ff12c3e647a4f8dd935919ae296593770a6b) Co-authored-by: Raymond Hettinger <rhettinger@users.noreply.github.com>
-rw-r--r--Lib/random.py5
-rw-r--r--Lib/test/test_random.py15
-rw-r--r--Misc/NEWS.d/next/Library/2023-01-07-15-13-47.gh-issue-100805.05rBz9.rst2
3 files changed, 21 insertions, 1 deletions
diff --git a/Lib/random.py b/Lib/random.py
index f94616e..22dcb4d 100644
--- a/Lib/random.py
+++ b/Lib/random.py
@@ -366,7 +366,10 @@ class Random(_random.Random):
def choice(self, seq):
"""Choose a random element from a non-empty sequence."""
- if not seq:
+
+ # As an accommodation for NumPy, we don't use "if not seq"
+ # because bool(numpy.array()) raises a ValueError.
+ if not len(seq):
raise IndexError('Cannot choose from an empty sequence')
return seq[self._randbelow(len(seq))]
diff --git a/Lib/test/test_random.py b/Lib/test/test_random.py
index 32e7868..f32d592 100644
--- a/Lib/test/test_random.py
+++ b/Lib/test/test_random.py
@@ -111,6 +111,21 @@ class TestBasicOps:
self.assertEqual(choice([50]), 50)
self.assertIn(choice([25, 75]), [25, 75])
+ def test_choice_with_numpy(self):
+ # Accommodation for NumPy arrays which have disabled __bool__().
+ # See: https://github.com/python/cpython/issues/100805
+ choice = self.gen.choice
+
+ class NA(list):
+ "Simulate numpy.array() behavior"
+ def __bool__(self):
+ raise RuntimeError
+
+ with self.assertRaises(IndexError):
+ choice(NA([]))
+ self.assertEqual(choice(NA([50])), 50)
+ self.assertIn(choice(NA([25, 75])), [25, 75])
+
def test_sample(self):
# For the entire allowable range of 0 <= k <= N, validate that
# the sample is of the correct length and contains only unique items
diff --git a/Misc/NEWS.d/next/Library/2023-01-07-15-13-47.gh-issue-100805.05rBz9.rst b/Misc/NEWS.d/next/Library/2023-01-07-15-13-47.gh-issue-100805.05rBz9.rst
new file mode 100644
index 0000000..4424d7c
--- /dev/null
+++ b/Misc/NEWS.d/next/Library/2023-01-07-15-13-47.gh-issue-100805.05rBz9.rst
@@ -0,0 +1,2 @@
+Modify :func:`random.choice` implementation to once again work with NumPy
+arrays.