summaryrefslogtreecommitdiffstats
path: root/Lib/unittest
diff options
context:
space:
mode:
authorSerhiy Storchaka <storchaka@gmail.com>2025-01-14 08:02:38 (GMT)
committerGitHub <noreply@github.com>2025-01-14 08:02:38 (GMT)
commit06cad77a5b345adde88609be9c3c470c5cd9f417 (patch)
tree3c66069a2a22b4bffe221c3db5da638faad45ee9 /Lib/unittest
parent41f73501eca2ff8b42fa4811d918a81c052a758b (diff)
downloadcpython-06cad77a5b345adde88609be9c3c470c5cd9f417.zip
cpython-06cad77a5b345adde88609be9c3c470c5cd9f417.tar.gz
cpython-06cad77a5b345adde88609be9c3c470c5cd9f417.tar.bz2
gh-71339: Add additional assertion methods for unittest (GH-128707)
Add the following methods: * assertHasAttr() and assertNotHasAttr() * assertIsSubclass() and assertNotIsSubclass() * assertStartsWith() and assertNotStartsWith() * assertEndsWith() and assertNotEndsWith() Also improve error messages for assertIsInstance() and assertNotIsInstance().
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/case.py132
1 files changed, 130 insertions, 2 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py
index 55c79d3..e9ef551 100644
--- a/Lib/unittest/case.py
+++ b/Lib/unittest/case.py
@@ -1321,13 +1321,67 @@ class TestCase(object):
"""Same as self.assertTrue(isinstance(obj, cls)), with a nicer
default message."""
if not isinstance(obj, cls):
- standardMsg = '%s is not an instance of %r' % (safe_repr(obj), cls)
+ if isinstance(cls, tuple):
+ standardMsg = f'{safe_repr(obj)} is not an instance of any of {cls!r}'
+ else:
+ standardMsg = f'{safe_repr(obj)} is not an instance of {cls!r}'
self.fail(self._formatMessage(msg, standardMsg))
def assertNotIsInstance(self, obj, cls, msg=None):
"""Included for symmetry with assertIsInstance."""
if isinstance(obj, cls):
- standardMsg = '%s is an instance of %r' % (safe_repr(obj), cls)
+ if isinstance(cls, tuple):
+ for x in cls:
+ if isinstance(obj, x):
+ cls = x
+ break
+ standardMsg = f'{safe_repr(obj)} is an instance of {cls!r}'
+ self.fail(self._formatMessage(msg, standardMsg))
+
+ def assertIsSubclass(self, cls, superclass, msg=None):
+ try:
+ if issubclass(cls, superclass):
+ return
+ except TypeError:
+ if not isinstance(cls, type):
+ self.fail(self._formatMessage(msg, f'{cls!r} is not a class'))
+ raise
+ if isinstance(superclass, tuple):
+ standardMsg = f'{cls!r} is not a subclass of any of {superclass!r}'
+ else:
+ standardMsg = f'{cls!r} is not a subclass of {superclass!r}'
+ self.fail(self._formatMessage(msg, standardMsg))
+
+ def assertNotIsSubclass(self, cls, superclass, msg=None):
+ try:
+ if not issubclass(cls, superclass):
+ return
+ except TypeError:
+ if not isinstance(cls, type):
+ self.fail(self._formatMessage(msg, f'{cls!r} is not a class'))
+ raise
+ if isinstance(superclass, tuple):
+ for x in superclass:
+ if issubclass(cls, x):
+ superclass = x
+ break
+ standardMsg = f'{cls!r} is a subclass of {superclass!r}'
+ self.fail(self._formatMessage(msg, standardMsg))
+
+ def assertHasAttr(self, obj, name, msg=None):
+ if not hasattr(obj, name):
+ if isinstance(obj, types.ModuleType):
+ standardMsg = f'module {obj.__name__!r} has no attribute {name!r}'
+ else:
+ standardMsg = f'{type(obj).__name__} instance has no attribute {name!r}'
+ self.fail(self._formatMessage(msg, standardMsg))
+
+ def assertNotHasAttr(self, obj, name, msg=None):
+ if hasattr(obj, name):
+ if isinstance(obj, types.ModuleType):
+ standardMsg = f'module {obj.__name__!r} has unexpected attribute {name!r}'
+ else:
+ standardMsg = f'{type(obj).__name__} instance has unexpected attribute {name!r}'
self.fail(self._formatMessage(msg, standardMsg))
def assertRaisesRegex(self, expected_exception, expected_regex,
@@ -1391,6 +1445,80 @@ class TestCase(object):
msg = self._formatMessage(msg, standardMsg)
raise self.failureException(msg)
+ def _tail_type_check(self, s, tails, msg):
+ if not isinstance(tails, tuple):
+ tails = (tails,)
+ for tail in tails:
+ if isinstance(tail, str):
+ if not isinstance(s, str):
+ self.fail(self._formatMessage(msg,
+ f'Expected str, not {type(s).__name__}'))
+ elif isinstance(tail, (bytes, bytearray)):
+ if not isinstance(s, (bytes, bytearray)):
+ self.fail(self._formatMessage(msg,
+ f'Expected bytes, not {type(s).__name__}'))
+
+ def assertStartsWith(self, s, prefix, msg=None):
+ try:
+ if s.startswith(prefix):
+ return
+ except (AttributeError, TypeError):
+ self._tail_type_check(s, prefix, msg)
+ raise
+ a = safe_repr(s, short=True)
+ b = safe_repr(prefix)
+ if isinstance(prefix, tuple):
+ standardMsg = f"{a} doesn't start with any of {b}"
+ else:
+ standardMsg = f"{a} doesn't start with {b}"
+ self.fail(self._formatMessage(msg, standardMsg))
+
+ def assertNotStartsWith(self, s, prefix, msg=None):
+ try:
+ if not s.startswith(prefix):
+ return
+ except (AttributeError, TypeError):
+ self._tail_type_check(s, prefix, msg)
+ raise
+ if isinstance(prefix, tuple):
+ for x in prefix:
+ if s.startswith(x):
+ prefix = x
+ break
+ a = safe_repr(s, short=True)
+ b = safe_repr(prefix)
+ self.fail(self._formatMessage(msg, f"{a} starts with {b}"))
+
+ def assertEndsWith(self, s, suffix, msg=None):
+ try:
+ if s.endswith(suffix):
+ return
+ except (AttributeError, TypeError):
+ self._tail_type_check(s, suffix, msg)
+ raise
+ a = safe_repr(s, short=True)
+ b = safe_repr(suffix)
+ if isinstance(suffix, tuple):
+ standardMsg = f"{a} doesn't end with any of {b}"
+ else:
+ standardMsg = f"{a} doesn't end with {b}"
+ self.fail(self._formatMessage(msg, standardMsg))
+
+ def assertNotEndsWith(self, s, suffix, msg=None):
+ try:
+ if not s.endswith(suffix):
+ return
+ except (AttributeError, TypeError):
+ self._tail_type_check(s, suffix, msg)
+ raise
+ if isinstance(suffix, tuple):
+ for x in suffix:
+ if s.endswith(x):
+ suffix = x
+ break
+ a = safe_repr(s, short=True)
+ b = safe_repr(suffix)
+ self.fail(self._formatMessage(msg, f"{a} ends with {b}"))
class FunctionTestCase(TestCase):