summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_index.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_index.py')
-rw-r--r--Lib/test/test_index.py66
1 files changed, 43 insertions, 23 deletions
diff --git a/Lib/test/test_index.py b/Lib/test/test_index.py
index 7a94af1..8785711 100644
--- a/Lib/test/test_index.py
+++ b/Lib/test/test_index.py
@@ -9,7 +9,7 @@ class newstyle:
class TrapInt(int):
def __index__(self):
- return self
+ return int(self)
class BaseTestCase(unittest.TestCase):
def setUp(self):
@@ -55,8 +55,41 @@ class BaseTestCase(unittest.TestCase):
self.assertRaises(TypeError, slice(self.o).indices, 0)
self.assertRaises(TypeError, slice(self.n).indices, 0)
-
-class SeqTestCase(unittest.TestCase):
+ def test_int_subclass_with_index(self):
+ # __index__ should be used when computing indices, even for int
+ # subclasses. See issue #17576.
+ class MyInt(int):
+ def __index__(self):
+ return int(self) + 1
+
+ my_int = MyInt(7)
+ direct_index = my_int.__index__()
+ operator_index = operator.index(my_int)
+ self.assertEqual(direct_index, 8)
+ self.assertEqual(operator_index, 7)
+ # Both results should be of exact type int.
+ self.assertIs(type(direct_index), int)
+ #self.assertIs(type(operator_index), int)
+
+ def test_index_returns_int_subclass(self):
+ class BadInt:
+ def __index__(self):
+ return True
+
+ class BadInt2(int):
+ def __index__(self):
+ return True
+
+ bad_int = BadInt()
+ n = operator.index(bad_int)
+ self.assertEqual(n, 1)
+
+ bad_int = BadInt2()
+ n = operator.index(bad_int)
+ self.assertEqual(n, 0)
+
+
+class SeqTestCase:
# This test case isn't run directly. It just defines common tests
# to the different sequence types below
def setUp(self):
@@ -126,7 +159,7 @@ class SeqTestCase(unittest.TestCase):
self.assertRaises(TypeError, sliceobj, self.n, self)
-class ListTestCase(SeqTestCase):
+class ListTestCase(SeqTestCase, unittest.TestCase):
seq = [0,10,20,30,40,50]
def test_setdelitem(self):
@@ -182,19 +215,19 @@ class NewSeq:
return self._list[index]
-class TupleTestCase(SeqTestCase):
+class TupleTestCase(SeqTestCase, unittest.TestCase):
seq = (0,10,20,30,40,50)
-class ByteArrayTestCase(SeqTestCase):
+class ByteArrayTestCase(SeqTestCase, unittest.TestCase):
seq = bytearray(b"this is a test")
-class BytesTestCase(SeqTestCase):
+class BytesTestCase(SeqTestCase, unittest.TestCase):
seq = b"this is a test"
-class StringTestCase(SeqTestCase):
+class StringTestCase(SeqTestCase, unittest.TestCase):
seq = "this is a test"
-class NewSeqTestCase(SeqTestCase):
+class NewSeqTestCase(SeqTestCase, unittest.TestCase):
seq = NewSeq((0,10,20,30,40,50))
@@ -237,18 +270,5 @@ class OverflowTestCase(unittest.TestCase):
self.assertRaises(OverflowError, lambda: "a" * self.neg)
-def test_main():
- support.run_unittest(
- BaseTestCase,
- ListTestCase,
- TupleTestCase,
- BytesTestCase,
- ByteArrayTestCase,
- StringTestCase,
- NewSeqTestCase,
- RangeTestCase,
- OverflowTestCase,
- )
-
if __name__ == "__main__":
- test_main()
+ unittest.main()