summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/Bastion.py2
-rw-r--r--Lib/CGIHTTPServer.py5
-rw-r--r--Lib/ConfigParser.py13
-rw-r--r--Lib/DocXMLRPCServer.py6
-rw-r--r--Lib/HTMLParser.py30
-rw-r--r--Lib/SimpleXMLRPCServer.py4
-rw-r--r--Lib/SocketServer.py10
-rw-r--r--Lib/_strptime.py9
-rw-r--r--Lib/bisect.py8
-rw-r--r--Lib/cmd.py7
-rw-r--r--Lib/collections.py62
-rw-r--r--Lib/commands.py2
-rw-r--r--Lib/copy_reg.py3
-rw-r--r--Lib/csv.py7
-rw-r--r--Lib/ctypes/__init__.py7
-rw-r--r--Lib/ctypes/test/test_functions.py14
-rw-r--r--Lib/ctypes/test/test_memfunctions.py2
-rw-r--r--Lib/ctypes/test/test_numbers.py31
-rw-r--r--Lib/ctypes/test/test_repr.py2
-rw-r--r--Lib/decimal.py333
-rw-r--r--Lib/difflib.py9
-rw-r--r--Lib/distutils/__init__.py2
-rw-r--r--Lib/distutils/command/build_ext.py7
-rw-r--r--Lib/distutils/msvccompiler.py16
-rw-r--r--Lib/doctest.py19
-rw-r--r--Lib/email/_parseaddr.py5
-rw-r--r--Lib/email/header.py3
-rw-r--r--Lib/email/message.py6
-rw-r--r--Lib/email/test/test_email.py27
-rw-r--r--Lib/email/test/test_email_renamed.py20
-rw-r--r--Lib/ftplib.py51
-rw-r--r--Lib/genericpath.py29
-rw-r--r--Lib/glob.py12
-rw-r--r--Lib/heapq.py2
-rw-r--r--Lib/httplib.py28
-rw-r--r--Lib/idlelib/MultiCall.py2
-rw-r--r--Lib/imaplib.py12
-rw-r--r--Lib/logging/handlers.py4
-rw-r--r--Lib/macpath.py14
-rw-r--r--Lib/ntpath.py99
-rw-r--r--Lib/os.py10
-rw-r--r--Lib/pdb.doc6
-rwxr-xr-xLib/pdb.py72
-rw-r--r--Lib/popen2.py42
-rw-r--r--Lib/poplib.py18
-rw-r--r--Lib/posixpath.py28
-rwxr-xr-xLib/pydoc.py49
-rw-r--r--Lib/rexec.py4
-rw-r--r--Lib/robotparser.py15
-rw-r--r--Lib/sched.py2
-rw-r--r--Lib/site.py2
-rwxr-xr-xLib/smtplib.py84
-rw-r--r--Lib/socket.py33
-rw-r--r--Lib/sre.py3
-rw-r--r--Lib/subprocess.py4
-rw-r--r--Lib/tarfile.py980
-rw-r--r--Lib/telnetlib.py26
-rw-r--r--Lib/tempfile.py131
-rw-r--r--Lib/test/README47
-rw-r--r--Lib/test/crashers/modify_dict_attr.py9
-rw-r--r--Lib/test/infinite_reload.py7
-rw-r--r--Lib/test/output/test_operations19
-rw-r--r--Lib/test/output/test_popen29
-rw-r--r--Lib/test/output/test_pty3
-rw-r--r--Lib/test/output/test_pyexpat110
-rw-r--r--Lib/test/output/test_threadedtempfile5
-rw-r--r--Lib/test/output/xmltests364
-rw-r--r--Lib/test/outstanding_bugs.py37
-rw-r--r--Lib/test/pickletester.py32
-rwxr-xr-xLib/test/regrtest.py2
-rw-r--r--Lib/test/ssl_cert.pem14
-rw-r--r--Lib/test/ssl_key.pem9
-rw-r--r--Lib/test/string_tests.py3
-rw-r--r--Lib/test/test___all__.py12
-rwxr-xr-xLib/test/test_array.py16
-rw-r--r--Lib/test/test_atexit.py32
-rw-r--r--Lib/test/test_base64.py12
-rwxr-xr-xLib/test/test_binascii.py9
-rw-r--r--Lib/test/test_bool.py2
-rw-r--r--Lib/test/test_bsddb3.py8
-rw-r--r--Lib/test/test_builtin.py18
-rw-r--r--Lib/test/test_cfgparser.py12
-rwxr-xr-xLib/test/test_cmath.py246
-rw-r--r--Lib/test/test_cmd_line.py2
-rw-r--r--Lib/test/test_codecencodings_cn.py6
-rw-r--r--Lib/test/test_codecencodings_hk.py4
-rw-r--r--Lib/test/test_codecencodings_jp.py8
-rw-r--r--Lib/test/test_codecencodings_kr.py6
-rw-r--r--Lib/test/test_codecencodings_tw.py4
-rw-r--r--Lib/test/test_codecmaps_cn.py5
-rw-r--r--Lib/test/test_codecmaps_hk.py4
-rw-r--r--Lib/test/test_codecmaps_jp.py8
-rw-r--r--Lib/test/test_codecmaps_kr.py6
-rw-r--r--Lib/test/test_codecmaps_tw.py5
-rw-r--r--Lib/test/test_collections.py57
-rw-r--r--Lib/test/test_commands.py4
-rw-r--r--Lib/test/test_compile.py17
-rw-r--r--Lib/test/test_compiler.py4
-rw-r--r--Lib/test/test_complex.py12
-rw-r--r--Lib/test/test_contextlib.py6
-rwxr-xr-xLib/test/test_crypt.py2
-rw-r--r--Lib/test/test_csv.py4
-rw-r--r--Lib/test/test_ctypes.py4
-rw-r--r--Lib/test/test_curses.py15
-rw-r--r--Lib/test/test_datetime.py69
-rw-r--r--Lib/test/test_defaultdict.py9
-rw-r--r--Lib/test/test_deque.py2
-rw-r--r--Lib/test/test_descr.py131
-rw-r--r--Lib/test/test_descrtut.py2
-rw-r--r--Lib/test/test_dict.py76
-rw-r--r--Lib/test/test_dis.py8
-rw-r--r--Lib/test/test_doctest.py10
-rw-r--r--Lib/test/test_email.py4
-rw-r--r--Lib/test/test_email_codecs.py2
-rw-r--r--Lib/test/test_email_renamed.py4
-rw-r--r--Lib/test/test_exceptions.py5
-rw-r--r--Lib/test/test_fileinput.py397
-rw-r--r--Lib/test/test_fileio.py4
-rw-r--r--Lib/test/test_ftplib.py93
-rw-r--r--Lib/test/test_functools.py6
-rw-r--r--Lib/test/test_gc.py1132
-rw-r--r--Lib/test/test_getopt.py347
-rw-r--r--Lib/test/test_gettext.py15
-rw-r--r--Lib/test/test_glob.py10
-rw-r--r--Lib/test/test_grammar.py2
-rwxr-xr-xLib/test/test_htmlparser.py5
-rw-r--r--Lib/test/test_httplib.py47
-rw-r--r--Lib/test/test_import.py10
-rw-r--r--Lib/test/test_itertools.py18
-rw-r--r--Lib/test/test_keywordonlyarg.py2
-rw-r--r--Lib/test/test_locale.py2
-rw-r--r--Lib/test/test_logging.py2
-rw-r--r--Lib/test/test_long_future.py100
-rw-r--r--Lib/test/test_macpath.py2
-rw-r--r--Lib/test/test_mailbox.py5
-rw-r--r--Lib/test/test_metaclass.py2
-rw-r--r--Lib/test/test_minidom.py2594
-rw-r--r--Lib/test/test_module.py103
-rw-r--r--Lib/test/test_multibytecodec.py8
-rw-r--r--Lib/test/test_normalization.py118
-rw-r--r--Lib/test/test_ntpath.py20
-rw-r--r--Lib/test/test_operations.py77
-rw-r--r--Lib/test/test_operator.py6
-rw-r--r--Lib/test/test_optparse.py14
-rw-r--r--Lib/test/test_os.py78
-rw-r--r--Lib/test/test_ossaudiodev.py247
-rw-r--r--Lib/test/test_peepholer.py11
-rw-r--r--Lib/test/test_popen2.py150
-rw-r--r--Lib/test/test_poplib.py71
-rw-r--r--Lib/test/test_posixpath.py208
-rw-r--r--Lib/test/test_pty.py256
-rw-r--r--Lib/test/test_pyexpat.py677
-rw-r--r--Lib/test/test_re.py15
-rw-r--r--Lib/test/test_robotparser.py4
-rw-r--r--Lib/test/test_sax.py1193
-rw-r--r--Lib/test/test_scope.py35
-rw-r--r--Lib/test/test_set.py11
-rw-r--r--Lib/test/test_slice.py19
-rw-r--r--Lib/test/test_smtplib.py71
-rw-r--r--Lib/test/test_socket.py106
-rw-r--r--Lib/test/test_socket_ssl.py299
-rw-r--r--Lib/test/test_socketserver.py11
-rw-r--r--Lib/test/test_stringprep.py130
-rw-r--r--Lib/test/test_strptime.py29
-rw-r--r--Lib/test/test_struct.py24
-rw-r--r--Lib/test/test_structmembers.py39
-rw-r--r--Lib/test/test_support.py125
-rw-r--r--Lib/test/test_syntax.py55
-rw-r--r--Lib/test/test_tarfile.py1342
-rw-r--r--Lib/test/test_telnetlib.py74
-rw-r--r--Lib/test/test_tempfile.py125
-rw-r--r--Lib/test/test_textwrap.py8
-rw-r--r--Lib/test/test_threadedtempfile.py79
-rw-r--r--Lib/test/test_threading_local.py2
-rw-r--r--Lib/test/test_unicode.py2
-rw-r--r--Lib/test/test_unicode_file.py6
-rw-r--r--Lib/test/test_unittest.py2303
-rw-r--r--Lib/test/test_unpack.py2
-rw-r--r--Lib/test/test_urllib.py9
-rw-r--r--Lib/test/test_urllib2.py24
-rw-r--r--Lib/test/test_urllib2net.py3
-rw-r--r--Lib/test/test_userdict.py6
-rw-r--r--Lib/test/test_warnings.py146
-rw-r--r--Lib/test/test_weakref.py4
-rwxr-xr-xLib/test/test_wsgiref.py9
-rw-r--r--Lib/test/test_zipfile.py433
-rw-r--r--Lib/test/test_zlib.py34
-rw-r--r--Lib/test/testtar.tarbin133120 -> 256000 bytes
-rw-r--r--Lib/test/warning_tests.py9
-rw-r--r--Lib/textwrap.py12
-rw-r--r--Lib/timeit.py64
-rw-r--r--Lib/unittest.py97
-rw-r--r--Lib/urllib.py12
-rw-r--r--Lib/urllib2.py71
-rw-r--r--Lib/wave.py14
-rw-r--r--Lib/webbrowser.py51
-rw-r--r--Lib/zipfile.py274
197 files changed, 11502 insertions, 6693 deletions
diff --git a/Lib/Bastion.py b/Lib/Bastion.py
index d83ea3e..5331ba9 100644
--- a/Lib/Bastion.py
+++ b/Lib/Bastion.py
@@ -97,7 +97,7 @@ def Bastion(object, filter = lambda name: name[:1] != '_',
"""
- raise RuntimeError, "This code is not secure in Python 2.2 and 2.3"
+ raise RuntimeError, "This code is not secure in Python 2.2 and later"
# Note: we define *two* ad-hoc functions here, get1 and get2.
# Both are intended to be called in the same way: get(name).
diff --git a/Lib/CGIHTTPServer.py b/Lib/CGIHTTPServer.py
index f2d10e9..5017eec 100644
--- a/Lib/CGIHTTPServer.py
+++ b/Lib/CGIHTTPServer.py
@@ -197,6 +197,9 @@ class CGIHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
length = self.headers.getheader('content-length')
if length:
env['CONTENT_LENGTH'] = length
+ referer = self.headers.getheader('referer')
+ if referer:
+ env['HTTP_REFERER'] = referer
accept = []
for line in self.headers.getallmatchingheaders('accept'):
if line[:1] in "\t\n\r ":
@@ -214,7 +217,7 @@ class CGIHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
# Since we're setting the env in the parent, provide empty
# values to override previously set values
for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
- 'HTTP_USER_AGENT', 'HTTP_COOKIE'):
+ 'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
env.setdefault(k, "")
os.environ.update(env)
diff --git a/Lib/ConfigParser.py b/Lib/ConfigParser.py
index 8e644e1..6457e0f 100644
--- a/Lib/ConfigParser.py
+++ b/Lib/ConfigParser.py
@@ -594,7 +594,8 @@ class SafeConfigParser(ConfigParser):
self._interpolate_some(option, L, rawval, section, vars, 1)
return ''.join(L)
- _interpvar_match = re.compile(r"%\(([^)]+)\)s").match
+ _interpvar_re = re.compile(r"%\(([^)]+)\)s")
+ _badpercent_re = re.compile(r"%[^%]|%$")
def _interpolate_some(self, option, accum, rest, section, map, depth):
if depth > MAX_INTERPOLATION_DEPTH:
@@ -613,7 +614,7 @@ class SafeConfigParser(ConfigParser):
accum.append("%")
rest = rest[2:]
elif c == "(":
- m = self._interpvar_match(rest)
+ m = self._interpvar_re.match(rest)
if m is None:
raise InterpolationSyntaxError(option, section,
"bad interpolation variable reference %r" % rest)
@@ -638,4 +639,12 @@ class SafeConfigParser(ConfigParser):
"""Set an option. Extend ConfigParser.set: check for string values."""
if not isinstance(value, basestring):
raise TypeError("option values must be strings")
+ # check for bad percent signs:
+ # first, replace all "good" interpolations
+ tmp_value = self._interpvar_re.sub('', value)
+ # then, check if there's a lone percent sign left
+ m = self._badpercent_re.search(tmp_value)
+ if m:
+ raise ValueError("invalid interpolation syntax in %r at "
+ "position %d" % (value, m.start()))
ConfigParser.set(self, section, option, value)
diff --git a/Lib/DocXMLRPCServer.py b/Lib/DocXMLRPCServer.py
index fd3b2c9..111e5f6 100644
--- a/Lib/DocXMLRPCServer.py
+++ b/Lib/DocXMLRPCServer.py
@@ -252,8 +252,10 @@ class DocXMLRPCServer( SimpleXMLRPCServer,
"""
def __init__(self, addr, requestHandler=DocXMLRPCRequestHandler,
- logRequests=1):
- SimpleXMLRPCServer.__init__(self, addr, requestHandler, logRequests)
+ logRequests=1, allow_none=False, encoding=None,
+ bind_and_activate=True):
+ SimpleXMLRPCServer.__init__(self, addr, requestHandler, logRequests,
+ allow_none, encoding, bind_and_activate)
XMLRPCDocGenerator.__init__(self)
class DocCGIXMLRPCRequestHandler( CGIXMLRPCRequestHandler,
diff --git a/Lib/HTMLParser.py b/Lib/HTMLParser.py
index 8380466..52f8c57 100644
--- a/Lib/HTMLParser.py
+++ b/Lib/HTMLParser.py
@@ -358,12 +358,30 @@ class HTMLParser(markupbase.ParserBase):
self.error("unknown declaration: %r" % (data,))
# Internal -- helper to remove special character quoting
+ entitydefs = None
def unescape(self, s):
if '&' not in s:
return s
- s = s.replace("&lt;", "<")
- s = s.replace("&gt;", ">")
- s = s.replace("&apos;", "'")
- s = s.replace("&quot;", '"')
- s = s.replace("&amp;", "&") # Must be last
- return s
+ def replaceEntities(s):
+ s = s.groups()[0]
+ if s[0] == "#":
+ s = s[1:]
+ if s[0] in ['x','X']:
+ c = int(s[1:], 16)
+ else:
+ c = int(s)
+ return unichr(c)
+ else:
+ # Cannot use name2codepoint directly, because HTMLParser supports apos,
+ # which is not part of HTML 4
+ import htmlentitydefs
+ if HTMLParser.entitydefs is None:
+ entitydefs = HTMLParser.entitydefs = {'apos':u"'"}
+ for k, v in htmlentitydefs.name2codepoint.items():
+ entitydefs[k] = unichr(v)
+ try:
+ return self.entitydefs[s]
+ except KeyError:
+ return '&'+s+';'
+
+ return re.sub(r"&(#?[xX]?(?:[0-9a-fA-F]+|\w{1,8}));", replaceEntities, s)
diff --git a/Lib/SimpleXMLRPCServer.py b/Lib/SimpleXMLRPCServer.py
index 7065cc0..4aadffa 100644
--- a/Lib/SimpleXMLRPCServer.py
+++ b/Lib/SimpleXMLRPCServer.py
@@ -517,11 +517,11 @@ class SimpleXMLRPCServer(SocketServer.TCPServer,
allow_reuse_address = True
def __init__(self, addr, requestHandler=SimpleXMLRPCRequestHandler,
- logRequests=True, allow_none=False, encoding=None):
+ logRequests=True, allow_none=False, encoding=None, bind_and_activate=True):
self.logRequests = logRequests
SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding)
- SocketServer.TCPServer.__init__(self, addr, requestHandler)
+ SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
# [Bug #1222790] If possible, set close-on-exec flag; if a
# method spawns a subprocess, the subprocess shouldn't have
diff --git a/Lib/SocketServer.py b/Lib/SocketServer.py
index eedb251..84bbcf6 100644
--- a/Lib/SocketServer.py
+++ b/Lib/SocketServer.py
@@ -279,7 +279,7 @@ class TCPServer(BaseServer):
Methods for the caller:
- - __init__(server_address, RequestHandlerClass)
+ - __init__(server_address, RequestHandlerClass, bind_and_activate=True)
- serve_forever()
- handle_request() # if you don't use serve_forever()
- fileno() -> int # for select()
@@ -322,13 +322,14 @@ class TCPServer(BaseServer):
allow_reuse_address = False
- def __init__(self, server_address, RequestHandlerClass):
+ def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
"""Constructor. May be extended, do not override."""
BaseServer.__init__(self, server_address, RequestHandlerClass)
self.socket = socket.socket(self.address_family,
self.socket_type)
- self.server_bind()
- self.server_activate()
+ if bind_and_activate:
+ self.server_bind()
+ self.server_activate()
def server_bind(self):
"""Called by constructor to bind the socket.
@@ -339,6 +340,7 @@ class TCPServer(BaseServer):
if self.allow_reuse_address:
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.server_address)
+ self.server_address = self.socket.getsockname()
def server_activate(self):
"""Called by constructor to activate the server.
diff --git a/Lib/_strptime.py b/Lib/_strptime.py
index 10b0083..e5d2721 100644
--- a/Lib/_strptime.py
+++ b/Lib/_strptime.py
@@ -295,17 +295,16 @@ def strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
"""Return a time struct based on the input string and the format string."""
global _TimeRE_cache, _regex_cache
with _cache_lock:
- time_re = _TimeRE_cache
- locale_time = time_re.locale_time
- if _getlang() != locale_time.lang:
+ if _getlang() != _TimeRE_cache.locale_time.lang:
_TimeRE_cache = TimeRE()
- _regex_cache = {}
+ _regex_cache.clear()
if len(_regex_cache) > _CACHE_MAX_SIZE:
_regex_cache.clear()
+ locale_time = _TimeRE_cache.locale_time
format_regex = _regex_cache.get(format)
if not format_regex:
try:
- format_regex = time_re.compile(format)
+ format_regex = _TimeRE_cache.compile(format)
# KeyError raised when a bad format is found; can be specified as
# \\, in which case it was a stray % but with a space after it
except KeyError as err:
diff --git a/Lib/bisect.py b/Lib/bisect.py
index 152f6c7..e4a2133 100644
--- a/Lib/bisect.py
+++ b/Lib/bisect.py
@@ -23,8 +23,8 @@ def bisect_right(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e <= x, and all e in
- a[i:] have e > x. So if x already appears in the list, i points just
- beyond the rightmost x already there.
+ a[i:] have e > x. So if x already appears in the list, a.insert(x) will
+ insert just after the rightmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
@@ -62,8 +62,8 @@ def bisect_left(a, x, lo=0, hi=None):
"""Return the index where to insert item x in list a, assuming a is sorted.
The return value i is such that all e in a[:i] have e < x, and all e in
- a[i:] have e >= x. So if x already appears in the list, i points just
- before the leftmost x already there.
+ a[i:] have e >= x. So if x already appears in the list, a.insert(x) will
+ insert just before the leftmost x already there.
Optional args lo (default 0) and hi (default len(a)) bound the
slice of a to be searched.
diff --git a/Lib/cmd.py b/Lib/cmd.py
index 23dc5b2..e44f55a 100644
--- a/Lib/cmd.py
+++ b/Lib/cmd.py
@@ -49,11 +49,6 @@ __all__ = ["Cmd"]
PROMPT = '(Cmd) '
IDENTCHARS = string.ascii_letters + string.digits + '_'
-def raw_input(prompt):
- sys.stdout.write(prompt)
- sys.stdout.flush()
- return sys.stdin.readline()
-
class Cmd:
"""A simple framework for writing line-oriented command interpreters.
@@ -129,7 +124,7 @@ class Cmd:
else:
if self.use_rawinput:
try:
- line = raw_input(self.prompt)
+ line = input(self.prompt)
except EOFError:
line = 'EOF'
else:
diff --git a/Lib/collections.py b/Lib/collections.py
new file mode 100644
index 0000000..dba9c7d
--- /dev/null
+++ b/Lib/collections.py
@@ -0,0 +1,62 @@
+__all__ = ['deque', 'defaultdict', 'NamedTuple']
+
+from _collections import deque, defaultdict
+from operator import itemgetter as _itemgetter
+import sys as _sys
+
+def NamedTuple(typename, s):
+ """Returns a new subclass of tuple with named fields.
+
+ >>> Point = NamedTuple('Point', 'x y')
+ >>> Point.__doc__ # docstring for the new class
+ 'Point(x, y)'
+ >>> p = Point(11, y=22) # instantiate with positional args or keywords
+ >>> p[0] + p[1] # works just like the tuple (11, 22)
+ 33
+ >>> x, y = p # unpacks just like a tuple
+ >>> x, y
+ (11, 22)
+ >>> p.x + p.y # fields also accessable by name
+ 33
+ >>> p # readable __repr__ with name=value style
+ Point(x=11, y=22)
+
+ """
+
+ field_names = s.split()
+ nargs = len(field_names)
+
+ def __new__(cls, *args, **kwds):
+ if kwds:
+ try:
+ args += tuple(kwds[name] for name in field_names[len(args):])
+ except KeyError as name:
+ raise TypeError('%s missing required argument: %s' % (typename, name))
+ if len(args) != nargs:
+ raise TypeError('%s takes exactly %d arguments (%d given)' % (typename, nargs, len(args)))
+ return tuple.__new__(cls, args)
+
+ repr_template = '%s(%s)' % (typename, ', '.join('%s=%%r' % name for name in field_names))
+
+ m = dict(vars(tuple)) # pre-lookup superclass methods (for faster lookup)
+ m.update(__doc__= '%s(%s)' % (typename, ', '.join(field_names)),
+ __slots__ = (), # no per-instance dict (so instances are same size as tuples)
+ __new__ = __new__,
+ __repr__ = lambda self, _format=repr_template.__mod__: _format(self),
+ __module__ = _sys._getframe(1).f_globals['__name__'],
+ )
+ m.update((name, property(_itemgetter(index))) for index, name in enumerate(field_names))
+
+ return type(typename, (tuple,), m)
+
+
+if __name__ == '__main__':
+ # verify that instances are pickable
+ from cPickle import loads, dumps
+ Point = NamedTuple('Point', 'x y')
+ p = Point(x=10, y=20)
+ assert p == loads(dumps(p))
+
+ import doctest
+ TestResults = NamedTuple('TestResults', 'failed attempted')
+ print(TestResults(*doctest.testmod()))
diff --git a/Lib/commands.py b/Lib/commands.py
index cfbb541..d19aa1a 100644
--- a/Lib/commands.py
+++ b/Lib/commands.py
@@ -32,6 +32,8 @@ __all__ = ["getstatusoutput","getoutput","getstatus"]
#
def getstatus(file):
"""Return output of "ls -ld <file>" in a string."""
+ import warnings
+ warnings.warn("commands.getstatus() is deprecated", DeprecationWarning)
return getoutput('ls -ld' + mkarg(file))
diff --git a/Lib/copy_reg.py b/Lib/copy_reg.py
index f4661ed..58d462b 100644
--- a/Lib/copy_reg.py
+++ b/Lib/copy_reg.py
@@ -43,7 +43,8 @@ def _reconstructor(cls, base, state):
obj = object.__new__(cls)
else:
obj = base.__new__(cls, state)
- base.__init__(obj, state)
+ if base.__init__ != object.__init__:
+ base.__init__(obj, state)
return obj
_HEAPTYPE = 1<<9
diff --git a/Lib/csv.py b/Lib/csv.py
index 45570f7..6ee12c8 100644
--- a/Lib/csv.py
+++ b/Lib/csv.py
@@ -115,9 +115,10 @@ class DictWriter:
def _dict_to_list(self, rowdict):
if self.extrasaction == "raise":
- for k in rowdict.keys():
- if k not in self.fieldnames:
- raise ValueError, "dict contains fields not in fieldnames"
+ wrong_fields = [k for k in rowdict if k not in self.fieldnames]
+ if wrong_fields:
+ raise ValueError("dict contains fields not in fieldnames: " +
+ ", ".join(wrong_fields))
return [rowdict.get(key, self.restval) for key in self.fieldnames]
def writerow(self, rowdict):
diff --git a/Lib/ctypes/__init__.py b/Lib/ctypes/__init__.py
index e2ea426..bd9b66e 100644
--- a/Lib/ctypes/__init__.py
+++ b/Lib/ctypes/__init__.py
@@ -233,6 +233,9 @@ class c_void_p(_SimpleCData):
c_voidp = c_void_p # backwards compatibility (to a bug)
_check_size(c_void_p)
+class c_bool(_SimpleCData):
+ _type_ = "t"
+
# This cache maps types to pointers to them.
_pointer_type_cache = {}
@@ -480,7 +483,7 @@ def cast(obj, typ):
return _cast(obj, obj, typ)
_string_at = CFUNCTYPE(py_object, c_void_p, c_int)(_string_at_addr)
-def string_at(ptr, size=0):
+def string_at(ptr, size=-1):
"""string_at(addr[, size]) -> string
Return the string at addr."""
@@ -492,7 +495,7 @@ except ImportError:
pass
else:
_wstring_at = CFUNCTYPE(py_object, c_void_p, c_int)(_wstring_at_addr)
- def wstring_at(ptr, size=0):
+ def wstring_at(ptr, size=-1):
"""wstring_at(addr[, size]) -> string
Return the string at addr."""
diff --git a/Lib/ctypes/test/test_functions.py b/Lib/ctypes/test/test_functions.py
index e907e21..40892b9 100644
--- a/Lib/ctypes/test/test_functions.py
+++ b/Lib/ctypes/test/test_functions.py
@@ -21,7 +21,9 @@ if sys.platform == "win32":
class POINT(Structure):
_fields_ = [("x", c_int), ("y", c_int)]
-
+class RECT(Structure):
+ _fields_ = [("left", c_int), ("top", c_int),
+ ("right", c_int), ("bottom", c_int)]
class FunctionTestCase(unittest.TestCase):
def test_mro(self):
@@ -379,5 +381,15 @@ class FunctionTestCase(unittest.TestCase):
self.failUnlessEqual((s8i.a, s8i.b, s8i.c, s8i.d, s8i.e, s8i.f, s8i.g, s8i.h),
(9*2, 8*3, 7*4, 6*5, 5*6, 4*7, 3*8, 2*9))
+ def test_sf1651235(self):
+ # see http://www.python.org/sf/1651235
+
+ proto = CFUNCTYPE(c_int, RECT, POINT)
+ def callback(*args):
+ return 0
+
+ callback = proto(callback)
+ self.failUnlessRaises(ArgumentError, lambda: callback((1, 2, 3, 4), POINT()))
+
if __name__ == '__main__':
unittest.main()
diff --git a/Lib/ctypes/test/test_memfunctions.py b/Lib/ctypes/test/test_memfunctions.py
index fbae2ce..aef7a73 100644
--- a/Lib/ctypes/test/test_memfunctions.py
+++ b/Lib/ctypes/test/test_memfunctions.py
@@ -14,6 +14,7 @@ class MemFunctionsTest(unittest.TestCase):
self.failUnlessEqual(string_at(result), "Hello, World")
self.failUnlessEqual(string_at(result, 5), "Hello")
self.failUnlessEqual(string_at(result, 16), "Hello, World\0\0\0\0")
+ self.failUnlessEqual(string_at(result, 0), "")
def test_memset(self):
a = create_string_buffer(1000000)
@@ -54,6 +55,7 @@ class MemFunctionsTest(unittest.TestCase):
self.failUnlessEqual(wstring_at(a), "Hello, World")
self.failUnlessEqual(wstring_at(a, 5), "Hello")
self.failUnlessEqual(wstring_at(a, 16), "Hello, World\0\0\0\0")
+ self.failUnlessEqual(wstring_at(a, 0), "")
if __name__ == "__main__":
unittest.main()
diff --git a/Lib/ctypes/test/test_numbers.py b/Lib/ctypes/test/test_numbers.py
index 2c5a990..eaabc7c 100644
--- a/Lib/ctypes/test/test_numbers.py
+++ b/Lib/ctypes/test/test_numbers.py
@@ -24,6 +24,8 @@ ArgType = type(byref(c_int(0)))
unsigned_types = [c_ubyte, c_ushort, c_uint, c_ulong]
signed_types = [c_byte, c_short, c_int, c_long, c_longlong]
+bool_types = []
+
float_types = [c_double, c_float]
try:
@@ -35,8 +37,16 @@ else:
unsigned_types.append(c_ulonglong)
signed_types.append(c_longlong)
+try:
+ c_bool
+except NameError:
+ pass
+else:
+ bool_types.append(c_bool)
+
unsigned_ranges = valid_ranges(*unsigned_types)
signed_ranges = valid_ranges(*signed_types)
+bool_values = [True, False, 0, 1, -1, 5000, 'test', [], [1]]
################################################################
@@ -60,6 +70,11 @@ class NumberTestCase(unittest.TestCase):
self.failUnlessEqual(t(l).value, l)
self.failUnlessEqual(t(h).value, h)
+ def test_bool_values(self):
+ from operator import truth
+ for t, v in zip(bool_types, bool_values):
+ self.failUnlessEqual(t(v).value, truth(v))
+
def test_typeerror(self):
# Only numbers are allowed in the contructor,
# otherwise TypeError is raised
@@ -82,7 +97,7 @@ class NumberTestCase(unittest.TestCase):
def test_byref(self):
# calling byref returns also a PyCArgObject instance
- for t in signed_types + unsigned_types + float_types:
+ for t in signed_types + unsigned_types + float_types + bool_types:
parm = byref(t())
self.failUnlessEqual(ArgType, type(parm))
@@ -101,7 +116,7 @@ class NumberTestCase(unittest.TestCase):
self.assertRaises(TypeError, t, 3.14)
def test_sizes(self):
- for t in signed_types + unsigned_types + float_types:
+ for t in signed_types + unsigned_types + float_types + bool_types:
size = struct.calcsize(t._type_)
# sizeof of the type...
self.failUnlessEqual(sizeof(t), size)
@@ -164,6 +179,18 @@ class NumberTestCase(unittest.TestCase):
a[0] = '?'
self.failUnlessEqual(v.value, a[0])
+ # array does not support c_bool / 't'
+ # def test_bool_from_address(self):
+ # from ctypes import c_bool
+ # from array import array
+ # a = array(c_bool._type_, [True])
+ # v = t.from_address(a.buffer_info()[0])
+ # self.failUnlessEqual(v.value, a[0])
+ # self.failUnlessEqual(type(v) is t)
+ # a[0] = False
+ # self.failUnlessEqual(v.value, a[0])
+ # self.failUnlessEqual(type(v) is t)
+
def test_init(self):
# c_int() can be initialized from Python's int, and c_int.
# Not from c_long or so, which seems strange, abd should
diff --git a/Lib/ctypes/test/test_repr.py b/Lib/ctypes/test/test_repr.py
index 1044f67..f6f9366 100644
--- a/Lib/ctypes/test/test_repr.py
+++ b/Lib/ctypes/test/test_repr.py
@@ -4,7 +4,7 @@ import unittest
subclasses = []
for base in [c_byte, c_short, c_int, c_long, c_longlong,
c_ubyte, c_ushort, c_uint, c_ulong, c_ulonglong,
- c_float, c_double]:
+ c_float, c_double, c_bool]:
class X(base):
pass
subclasses.append(X)
diff --git a/Lib/decimal.py b/Lib/decimal.py
index 148b626..a7238e1 100644
--- a/Lib/decimal.py
+++ b/Lib/decimal.py
@@ -29,8 +29,8 @@ and IEEE standard 854-1987:
Decimal floating point has finite precision with arbitrarily large bounds.
-The purpose of the module is to support arithmetic using familiar
-"schoolhouse" rules and to avoid the some of tricky representation
+The purpose of this module is to support arithmetic using familiar
+"schoolhouse" rules and to avoid some of the tricky representation
issues associated with binary floating point. The package is especially
useful for financial applications or for contexts where users have
expectations that are at odds with binary floating point (for instance,
@@ -136,7 +136,7 @@ __all__ = [
import copy as _copy
-#Rounding
+# Rounding
ROUND_DOWN = 'ROUND_DOWN'
ROUND_HALF_UP = 'ROUND_HALF_UP'
ROUND_HALF_EVEN = 'ROUND_HALF_EVEN'
@@ -145,11 +145,11 @@ ROUND_FLOOR = 'ROUND_FLOOR'
ROUND_UP = 'ROUND_UP'
ROUND_HALF_DOWN = 'ROUND_HALF_DOWN'
-#Rounding decision (not part of the public API)
+# Rounding decision (not part of the public API)
NEVER_ROUND = 'NEVER_ROUND' # Round in division (non-divmod), sqrt ONLY
ALWAYS_ROUND = 'ALWAYS_ROUND' # Every operation rounds at end.
-#Errors
+# Errors
class DecimalException(ArithmeticError):
"""Base exception class.
@@ -179,9 +179,9 @@ class Clamped(DecimalException):
This occurs and signals clamped if the exponent of a result has been
altered in order to fit the constraints of a specific concrete
- representation. This may occur when the exponent of a zero result would
- be outside the bounds of a representation, or when a large normal
- number would have an encoded exponent that cannot be represented. In
+ representation. This may occur when the exponent of a zero result would
+ be outside the bounds of a representation, or when a large normal
+ number would have an encoded exponent that cannot be represented. In
this latter case, the exponent is reduced to fit and the corresponding
number of zero digits are appended to the coefficient ("fold-down").
"""
@@ -194,8 +194,8 @@ class InvalidOperation(DecimalException):
Something creates a signaling NaN
-INF + INF
- 0 * (+-)INF
- (+-)INF / (+-)INF
+ 0 * (+-)INF
+ (+-)INF / (+-)INF
x % 0
(+-)INF % x
x._rescale( non-integer )
@@ -207,7 +207,7 @@ class InvalidOperation(DecimalException):
"""
def handle(self, context, *args):
if args:
- if args[0] == 1: #sNaN, must drop 's' but keep diagnostics
+ if args[0] == 1: # sNaN, must drop 's' but keep diagnostics
return Decimal( (args[1]._sign, args[1]._int, 'n') )
return NaN
@@ -216,11 +216,11 @@ class ConversionSyntax(InvalidOperation):
This occurs and signals invalid-operation if an string is being
converted to a number and it does not conform to the numeric string
- syntax. The result is [0,qNaN].
+ syntax. The result is [0,qNaN].
"""
def handle(self, context, *args):
- return (0, (0,), 'n') #Passed to something which uses a tuple.
+ return (0, (0,), 'n') # Passed to something which uses a tuple.
class DivisionByZero(DecimalException, ZeroDivisionError):
"""Division by 0.
@@ -245,7 +245,7 @@ class DivisionImpossible(InvalidOperation):
This occurs and signals invalid-operation if the integer result of a
divide-integer or remainder operation had too many digits (would be
- longer than precision). The result is [0,qNaN].
+ longer than precision). The result is [0,qNaN].
"""
def handle(self, context, *args):
@@ -256,12 +256,12 @@ class DivisionUndefined(InvalidOperation, ZeroDivisionError):
This occurs and signals invalid-operation if division by zero was
attempted (during a divide-integer, divide, or remainder operation), and
- the dividend is also zero. The result is [0,qNaN].
+ the dividend is also zero. The result is [0,qNaN].
"""
def handle(self, context, tup=None, *args):
if tup is not None:
- return (NaN, NaN) #for 0 %0, 0 // 0
+ return (NaN, NaN) # for 0 %0, 0 // 0
return NaN
class Inexact(DecimalException):
@@ -269,7 +269,7 @@ class Inexact(DecimalException):
This occurs and signals inexact whenever the result of an operation is
not exact (that is, it needed to be rounded and any discarded digits
- were non-zero), or if an overflow or underflow condition occurs. The
+ were non-zero), or if an overflow or underflow condition occurs. The
result in all cases is unchanged.
The inexact signal may be tested (or trapped) to determine if a given
@@ -281,11 +281,11 @@ class InvalidContext(InvalidOperation):
"""Invalid context. Unknown rounding, for example.
This occurs and signals invalid-operation if an invalid context was
- detected during an operation. This can occur if contexts are not checked
+ detected during an operation. This can occur if contexts are not checked
on creation and either the precision exceeds the capability of the
underlying concrete representation or an unknown or unsupported rounding
- was specified. These aspects of the context need only be checked when
- the values are required to be used. The result is [0,qNaN].
+ was specified. These aspects of the context need only be checked when
+ the values are required to be used. The result is [0,qNaN].
"""
def handle(self, context, *args):
@@ -296,7 +296,7 @@ class Rounded(DecimalException):
This occurs and signals rounded whenever the result of an operation is
rounded (that is, some zero or non-zero digits were discarded from the
- coefficient), or if an overflow or underflow condition occurs. The
+ coefficient), or if an overflow or underflow condition occurs. The
result in all cases is unchanged.
The rounded signal may be tested (or trapped) to determine if a given
@@ -309,7 +309,7 @@ class Subnormal(DecimalException):
This occurs and signals subnormal whenever the result of a conversion or
operation is subnormal (that is, its adjusted exponent is less than
- Emin, before any rounding). The result in all cases is unchanged.
+ Emin, before any rounding). The result in all cases is unchanged.
The subnormal signal may be tested (or trapped) to determine if a given
or operation (or sequence of operations) yielded a subnormal result.
@@ -328,13 +328,13 @@ class Overflow(Inexact, Rounded):
For round-half-up and round-half-even (and for round-half-down and
round-up, if implemented), the result of the operation is [sign,inf],
- where sign is the sign of the intermediate result. For round-down, the
+ where sign is the sign of the intermediate result. For round-down, the
result is the largest finite number that can be represented in the
- current precision, with the sign of the intermediate result. For
+ current precision, with the sign of the intermediate result. For
round-ceiling, the result is the same as for round-down if the sign of
- the intermediate result is 1, or is [0,inf] otherwise. For round-floor,
+ the intermediate result is 1, or is [0,inf] otherwise. For round-floor,
the result is the same as for round-down if the sign of the intermediate
- result is 0, or is [1,inf] otherwise. In all cases, Inexact and Rounded
+ result is 0, or is [1,inf] otherwise. In all cases, Inexact and Rounded
will also be raised.
"""
@@ -360,10 +360,10 @@ class Underflow(Inexact, Rounded, Subnormal):
This occurs and signals underflow if a result is inexact and the
adjusted exponent of the result would be smaller (more negative) than
the smallest value that can be handled by the implementation (the value
- Emin). That is, the result is both inexact and subnormal.
+ Emin). That is, the result is both inexact and subnormal.
The result after an underflow will be a subnormal number rounded, if
- necessary, so that its exponent is not less than Etiny. This may result
+ necessary, so that its exponent is not less than Etiny. This may result
in 0 with the sign of the intermediate result and an exponent of Etiny.
In all cases, Inexact, Rounded, and Subnormal will also be raised.
@@ -379,7 +379,7 @@ _condition_map = {ConversionSyntax:InvalidOperation,
DivisionUndefined:InvalidOperation,
InvalidContext:InvalidOperation}
-##### Context Functions #######################################
+##### Context Functions ##################################################
# The getcontext() and setcontext() function manage access to a thread-local
# current context. Py2.4 offers direct support for thread locals. If that
@@ -392,7 +392,7 @@ try:
except ImportError:
# Python was compiled without threads; create a mock object instead
import sys
- class MockThreading:
+ class MockThreading(object):
def local(self, sys=sys):
return sys.modules[__name__]
threading = MockThreading()
@@ -403,8 +403,8 @@ try:
except AttributeError:
- #To fix reloading, force it to create a new context
- #Old contexts have different exceptions in their dicts, making problems.
+ # To fix reloading, force it to create a new context
+ # Old contexts have different exceptions in their dicts, making problems.
if hasattr(threading.currentThread(), '__decimal_context__'):
del threading.currentThread().__decimal_context__
@@ -469,14 +469,14 @@ def localcontext(ctx=None):
ctx.prec += 2
# Rest of sin calculation algorithm
# uses a precision 2 greater than normal
- return +s # Convert result to normal precision
+ return +s # Convert result to normal precision
def sin(x):
with localcontext(ExtendedContext):
# Rest of sin calculation algorithm
# uses the Extended Context from the
# General Decimal Arithmetic Specification
- return +s # Convert result to normal context
+ return +s # Convert result to normal context
"""
# The string below can't be included in the docstring until Python 2.6
@@ -489,11 +489,11 @@ def localcontext(ctx=None):
... ctx = getcontext()
... ctx.prec += 2
... print(ctx.prec)
- ...
+ ...
30
>>> with localcontext(ExtendedContext):
... print(getcontext().prec)
- ...
+ ...
9
>>> print(getcontext().prec)
28
@@ -502,7 +502,7 @@ def localcontext(ctx=None):
return _ContextManager(ctx)
-##### Decimal class ###########################################
+##### Decimal class #######################################################
class Decimal(object):
"""Floating point class for decimal arithmetic."""
@@ -518,7 +518,7 @@ class Decimal(object):
>>> Decimal('3.14') # string input
Decimal("3.14")
- >>> Decimal((0, (3, 1, 4), -2)) # tuple input (sign, digit_tuple, exponent)
+ >>> Decimal((0, (3, 1, 4), -2)) # tuple (sign, digit_tuple, exponent)
Decimal("3.14")
>>> Decimal(314) # int or long
Decimal("314")
@@ -557,13 +557,13 @@ class Decimal(object):
# tuple/list conversion (possibly from as_tuple())
if isinstance(value, (list,tuple)):
if len(value) != 3:
- raise ValueError, 'Invalid arguments'
+ raise ValueError('Invalid arguments')
if value[0] not in (0,1):
- raise ValueError, 'Invalid sign'
+ raise ValueError('Invalid sign')
for digit in value[1]:
if not isinstance(digit, (int,int)) or digit < 0:
- raise ValueError, "The second value in the tuple must be composed of non negative integer elements."
-
+ raise ValueError("The second value in the tuple must be"
+ "composed of non negative integer elements.")
self._sign = value[0]
self._int = tuple(value[1])
if value[2] in ('F','n','N'):
@@ -596,22 +596,23 @@ class Decimal(object):
if _isnan(value):
sig, sign, diag = _isnan(value)
self._is_special = True
- if len(diag) > context.prec: #Diagnostic info too long
+ if len(diag) > context.prec: # Diagnostic info too long
self._sign, self._int, self._exp = \
context._raise_error(ConversionSyntax)
return self
if sig == 1:
- self._exp = 'n' #qNaN
- else: #sig == 2
- self._exp = 'N' #sNaN
+ self._exp = 'n' # qNaN
+ else: # sig == 2
+ self._exp = 'N' # sNaN
self._sign = sign
- self._int = tuple(map(int, diag)) #Diagnostic info
+ self._int = tuple(map(int, diag)) # Diagnostic info
return self
try:
self._sign, self._int, self._exp = _string2exact(value)
except ValueError:
self._is_special = True
- self._sign, self._int, self._exp = context._raise_error(ConversionSyntax)
+ self._sign, self._int, self._exp = \
+ context._raise_error(ConversionSyntax)
return self
raise TypeError("Cannot convert %r to Decimal" % value)
@@ -694,15 +695,15 @@ class Decimal(object):
if self._is_special or other._is_special:
ans = self._check_nans(other, context)
if ans:
- return 1 # Comparison involving NaN's always reports self > other
+ return 1 # Comparison involving NaN's always reports self > other
# INF = INF
return cmp(self._isinfinity(), other._isinfinity())
if not self and not other:
- return 0 #If both 0, sign comparison isn't certain.
+ return 0 # If both 0, sign comparison isn't certain.
- #If different signs, neg one is less
+ # If different signs, neg one is less
if other._sign < self._sign:
return -1
if self._sign < other._sign:
@@ -713,7 +714,7 @@ class Decimal(object):
if self_adjusted == other_adjusted and \
self._int + (0,)*(self._exp - other._exp) == \
other._int + (0,)*(other._exp - self._exp):
- return 0 #equal, except in precision. ([0]*(-x) = [])
+ return 0 # equal, except in precision. ([0]*(-x) = [])
elif self_adjusted > other_adjusted and self._int[0] != 0:
return (-1)**self._sign
elif self_adjusted < other_adjusted and other._int[0] != 0:
@@ -724,7 +725,7 @@ class Decimal(object):
context = getcontext()
context = context._shallow_copy()
- rounding = context._set_rounding(ROUND_UP) #round away from 0
+ rounding = context._set_rounding(ROUND_UP) # round away from 0
flags = context._ignore_all_flags()
res = self.__sub__(other, context=context)
@@ -782,7 +783,7 @@ class Decimal(object):
if other is NotImplemented:
return other
- #compare(NaN, NaN) = NaN
+ # Compare(NaN, NaN) = NaN
if (self._is_special or other and other._is_special):
ans = self._check_nans(other, context)
if ans:
@@ -843,11 +844,11 @@ class Decimal(object):
tmp = map(str, self._int)
numdigits = len(self._int)
leftdigits = self._exp + numdigits
- if eng and not self: #self = 0eX wants 0[.0[0]]eY, not [[0]0]0eY
- if self._exp < 0 and self._exp >= -6: #short, no need for e/E
+ if eng and not self: # self = 0eX wants 0[.0[0]]eY, not [[0]0]0eY
+ if self._exp < 0 and self._exp >= -6: # short, no need for e/E
s = '-'*self._sign + '0.' + '0'*(abs(self._exp))
return s
- #exp is closest mult. of 3 >= self._exp
+ # exp is closest mult. of 3 >= self._exp
exp = ((self._exp - 1)// 3 + 1) * 3
if exp != self._exp:
s = '0.'+'0'*(exp - self._exp)
@@ -859,7 +860,7 @@ class Decimal(object):
else:
s += 'e'
if exp > 0:
- s += '+' #0.0e+3, not 0.0e3
+ s += '+' # 0.0e+3, not 0.0e3
s += str(exp)
s = '-'*self._sign + s
return s
@@ -999,19 +1000,19 @@ class Decimal(object):
return ans
if self._isinfinity():
- #If both INF, same sign => same as both, opposite => error.
+ # If both INF, same sign => same as both, opposite => error.
if self._sign != other._sign and other._isinfinity():
return context._raise_error(InvalidOperation, '-INF + INF')
return Decimal(self)
if other._isinfinity():
- return Decimal(other) #Can't both be infinity here
+ return Decimal(other) # Can't both be infinity here
shouldround = context._rounding_decision == ALWAYS_ROUND
exp = min(self._exp, other._exp)
negativezero = 0
if context.rounding == ROUND_FLOOR and self._sign != other._sign:
- #If the answer is 0, the sign should be negative, in this case.
+ # If the answer is 0, the sign should be negative, in this case.
negativezero = 1
if not self and not other:
@@ -1046,19 +1047,19 @@ class Decimal(object):
return Decimal((negativezero, (0,), exp))
if op1.int < op2.int:
op1, op2 = op2, op1
- #OK, now abs(op1) > abs(op2)
+ # OK, now abs(op1) > abs(op2)
if op1.sign == 1:
result.sign = 1
op1.sign, op2.sign = op2.sign, op1.sign
else:
result.sign = 0
- #So we know the sign, and op1 > 0.
+ # So we know the sign, and op1 > 0.
elif op1.sign == 1:
result.sign = 1
op1.sign, op2.sign = (0, 0)
else:
result.sign = 0
- #Now, op1 > abs(op2) > 0
+ # Now, op1 > abs(op2) > 0
if op2.sign == 0:
result.int = op1.int + op2.int
@@ -1116,7 +1117,8 @@ class Decimal(object):
if ans:
return ans
- return Decimal(self) # Must be infinite, and incrementing makes no difference
+ # Must be infinite, and incrementing makes no difference
+ return Decimal(self)
L = list(self._int)
L[-1] += 1
@@ -1172,7 +1174,7 @@ class Decimal(object):
if not self or not other:
ans = Decimal((resultsign, (0,), resultexp))
if shouldround:
- #Fixing in case the exponent is out of bounds
+ # Fixing in case the exponent is out of bounds
ans = ans._fix(context)
return ans
@@ -1191,7 +1193,7 @@ class Decimal(object):
op1 = _WorkRep(self)
op2 = _WorkRep(other)
- ans = Decimal( (resultsign, map(int, str(op1.int * op2.int)), resultexp))
+ ans = Decimal((resultsign, map(int, str(op1.int * op2.int)), resultexp))
if shouldround:
ans = ans._fix(context)
@@ -1283,12 +1285,11 @@ class Decimal(object):
sign, 1)
return context._raise_error(DivisionByZero, 'x / 0', sign)
- #OK, so neither = 0, INF or NaN
-
+ # OK, so neither = 0, INF or NaN
shouldround = context._rounding_decision == ALWAYS_ROUND
- #If we're dividing into ints, and self < other, stop.
- #self.__abs__(0) does not round.
+ # If we're dividing into ints, and self < other, stop.
+ # self.__abs__(0) does not round.
if divmod and (self.__abs__(0, context) < other.__abs__(0, context)):
if divmod == 1 or divmod == 3:
@@ -1300,7 +1301,7 @@ class Decimal(object):
ans2)
elif divmod == 2:
- #Don't round the mod part, if we don't need it.
+ # Don't round the mod part, if we don't need it.
return (Decimal( (sign, (0,), 0) ), Decimal(self))
op1 = _WorkRep(self)
@@ -1349,7 +1350,7 @@ class Decimal(object):
op1.exp -= 1
if res.exp == 0 and divmod and op2.int > op1.int:
- #Solves an error in precision. Same as a previous block.
+ # Solves an error in precision. Same as a previous block.
if res.int >= prec_limit and shouldround:
return context._raise_error(DivisionImpossible)
@@ -1434,7 +1435,7 @@ class Decimal(object):
# ignored in the calling function.
context = context._shallow_copy()
flags = context._ignore_flags(Rounded, Inexact)
- #keep DivisionImpossible flags
+ # Keep DivisionImpossible flags
(side, r) = self.__divmod__(other, context=context)
if r._isnan():
@@ -1457,7 +1458,7 @@ class Decimal(object):
if r < comparison:
r._sign, comparison._sign = s1, s2
- #Get flags now
+ # Get flags now
self.__divmod__(other, context=context)
return r._fix(context)
r._sign, comparison._sign = s1, s2
@@ -1479,7 +1480,8 @@ class Decimal(object):
if r > comparison or decrease and r == comparison:
r._sign, comparison._sign = s1, s2
context.prec += 1
- if len(side.__add__(Decimal(1), context=context)._int) >= context.prec:
+ numbsquant = len(side.__add__(Decimal(1), context=context)._int)
+ if numbsquant >= context.prec:
context.prec -= 1
return context._raise_error(DivisionImpossible)[1]
context.prec -= 1
@@ -1514,7 +1516,7 @@ class Decimal(object):
context = getcontext()
return context._raise_error(InvalidContext)
elif self._isinfinity():
- raise OverflowError, "Cannot convert infinity to long"
+ raise OverflowError("Cannot convert infinity to long")
if self._exp >= 0:
s = ''.join(map(str, self._int)) + '0'*self._exp
else:
@@ -1568,13 +1570,13 @@ class Decimal(object):
context._raise_error(Clamped)
return ans
ans = ans._rescale(Etiny, context=context)
- #It isn't zero, and exp < Emin => subnormal
+ # It isn't zero, and exp < Emin => subnormal
context._raise_error(Subnormal)
if context.flags[Inexact]:
context._raise_error(Underflow)
else:
if ans:
- #Only raise subnormal if non-zero.
+ # Only raise subnormal if non-zero.
context._raise_error(Subnormal)
else:
Etop = context.Etop()
@@ -1591,7 +1593,8 @@ class Decimal(object):
return ans
context._raise_error(Inexact)
context._raise_error(Rounded)
- return context._raise_error(Overflow, 'above Emax', ans._sign)
+ c = context._raise_error(Overflow, 'above Emax', ans._sign)
+ return c
return ans
def _round(self, prec=None, rounding=None, context=None):
@@ -1651,18 +1654,18 @@ class Decimal(object):
ans = Decimal( (temp._sign, tmp, temp._exp - expdiff))
return ans
- #OK, but maybe all the lost digits are 0.
+ # OK, but maybe all the lost digits are 0.
lostdigits = self._int[expdiff:]
if lostdigits == (0,) * len(lostdigits):
ans = Decimal( (temp._sign, temp._int[:prec], temp._exp - expdiff))
- #Rounded, but not Inexact
+ # Rounded, but not Inexact
context._raise_error(Rounded)
return ans
# Okay, let's round and lose data
this_function = getattr(temp, self._pick_rounding_function[rounding])
- #Now we've got the rounding function
+ # Now we've got the rounding function
if prec != context.prec:
context = context._shallow_copy()
@@ -1758,7 +1761,7 @@ class Decimal(object):
context = getcontext()
if self._is_special or n._is_special or n.adjusted() > 8:
- #Because the spot << doesn't work with really big exponents
+ # Because the spot << doesn't work with really big exponents
if n._isinfinity() or n.adjusted() > 8:
return context._raise_error(InvalidOperation, 'x ** INF')
@@ -1788,9 +1791,10 @@ class Decimal(object):
return Infsign[sign]
return Decimal( (sign, (0,), 0) )
- #with ludicrously large exponent, just raise an overflow and return inf.
- if not modulo and n > 0 and (self._exp + len(self._int) - 1) * n > context.Emax \
- and self:
+ # With ludicrously large exponent, just raise an overflow
+ # and return inf.
+ if not modulo and n > 0 and \
+ (self._exp + len(self._int) - 1) * n > context.Emax and self:
tmp = Decimal('inf')
tmp._sign = sign
@@ -1810,7 +1814,7 @@ class Decimal(object):
context = context._shallow_copy()
context.prec = firstprec + elength + 1
if n < 0:
- #n is a long now, not Decimal instance
+ # n is a long now, not Decimal instance
n = -n
mul = Decimal(1).__truediv__(mul, context=context)
@@ -1819,7 +1823,7 @@ class Decimal(object):
spot <<= 1
spot >>= 1
- #Spot is the highest power of 2 less than n
+ # spot is the highest power of 2 less than n
while spot:
val = val.__mul__(val, context=context)
if val._isinfinity():
@@ -1877,7 +1881,7 @@ class Decimal(object):
if exp._isinfinity() or self._isinfinity():
if exp._isinfinity() and self._isinfinity():
- return self #if both are inf, it is OK
+ return self # if both are inf, it is OK
if context is None:
context = getcontext()
return context._raise_error(InvalidOperation,
@@ -1982,13 +1986,13 @@ class Decimal(object):
return Decimal(self)
if not self:
- #exponent = self._exp / 2, using round_down.
- #if self._exp < 0:
+ # exponent = self._exp / 2, using round_down.
+ # if self._exp < 0:
# exp = (self._exp+1) // 2
- #else:
+ # else:
exp = (self._exp) // 2
if self._sign == 1:
- #sqrt(-0) = -0
+ # sqrt(-0) = -0
return Decimal( (1, (0,), exp))
else:
return Decimal( (0, (0,), exp))
@@ -2023,8 +2027,7 @@ class Decimal(object):
context=context), context=context)
ans._exp -= 1 + tmp.adjusted() // 2
- #ans is now a linear approximation.
-
+ # ans is now a linear approximation.
Emax, Emin = context.Emax, context.Emin
context.Emax, context.Emin = DefaultContext.Emax, DefaultContext.Emin
@@ -2039,12 +2042,12 @@ class Decimal(object):
if context.prec == maxp:
break
- #round to the answer's precision-- the only error can be 1 ulp.
+ # Round to the answer's precision-- the only error can be 1 ulp.
context.prec = firstprec
prevexp = ans.adjusted()
ans = ans._round(context=context)
- #Now, check if the other last digits are better.
+ # Now, check if the other last digits are better.
context.prec = firstprec + 1
# In case we rounded up another digit and we should actually go lower.
if prevexp != ans.adjusted():
@@ -2076,10 +2079,10 @@ class Decimal(object):
context._raise_error(Rounded)
context._raise_error(Inexact)
else:
- #Exact answer, so let's set the exponent right.
- #if self._exp < 0:
+ # Exact answer, so let's set the exponent right.
+ # if self._exp < 0:
# exp = (self._exp +1)// 2
- #else:
+ # else:
exp = self._exp // 2
context.prec += ans._exp - exp
ans = ans._rescale(exp, context=context)
@@ -2100,7 +2103,7 @@ class Decimal(object):
return other
if self._is_special or other._is_special:
- # if one operand is a quiet NaN and the other is number, then the
+ # If one operand is a quiet NaN and the other is number, then the
# number is always returned
sn = self._isnan()
on = other._isnan()
@@ -2114,13 +2117,13 @@ class Decimal(object):
ans = self
c = self.__cmp__(other)
if c == 0:
- # if both operands are finite and equal in numerical value
+ # If both operands are finite and equal in numerical value
# then an ordering is applied:
#
- # if the signs differ then max returns the operand with the
+ # If the signs differ then max returns the operand with the
# positive sign and min returns the operand with the negative sign
#
- # if the signs are the same then the exponent is used to select
+ # If the signs are the same then the exponent is used to select
# the result.
if self._sign != other._sign:
if self._sign:
@@ -2141,7 +2144,7 @@ class Decimal(object):
def min(self, other, context=None):
"""Returns the smaller value.
- like min(self, other) except if one is not a number, returns
+ Like min(self, other) except if one is not a number, returns
NaN (and signals if one is sNaN). Also rounds.
"""
other = _convert_other(other)
@@ -2149,7 +2152,7 @@ class Decimal(object):
return other
if self._is_special or other._is_special:
- # if one operand is a quiet NaN and the other is number, then the
+ # If one operand is a quiet NaN and the other is number, then the
# number is always returned
sn = self._isnan()
on = other._isnan()
@@ -2163,13 +2166,13 @@ class Decimal(object):
ans = self
c = self.__cmp__(other)
if c == 0:
- # if both operands are finite and equal in numerical value
+ # If both operands are finite and equal in numerical value
# then an ordering is applied:
#
- # if the signs differ then max returns the operand with the
+ # If the signs differ then max returns the operand with the
# positive sign and min returns the operand with the negative sign
#
- # if the signs are the same then the exponent is used to select
+ # If the signs are the same then the exponent is used to select
# the result.
if self._sign != other._sign:
if other._sign:
@@ -2204,11 +2207,11 @@ class Decimal(object):
"""Return the adjusted exponent of self"""
try:
return self._exp + len(self._int) - 1
- #If NaN or Infinity, self._exp is string
+ # If NaN or Infinity, self._exp is string
except TypeError:
return 0
- # support for pickling, copy, and deepcopy
+ # Support for pickling, copy, and deepcopy
def __reduce__(self):
return (self.__class__, (str(self),))
@@ -2222,13 +2225,14 @@ class Decimal(object):
return self # My components are also immutable
return self.__class__(str(self))
-##### Context class ###########################################
+##### Context class #######################################################
# get rounding method function:
-rounding_functions = [name for name in Decimal.__dict__.keys() if name.startswith('_round_')]
+rounding_functions = [name for name in Decimal.__dict__.keys()
+ if name.startswith('_round_')]
for name in rounding_functions:
- #name is like _round_half_even, goes to the global ROUND_HALF_EVEN value.
+ # name is like _round_half_even, goes to the global ROUND_HALF_EVEN value.
globalname = name[1:].upper()
val = globals()[globalname]
Decimal._pick_rounding_function[val] = name
@@ -2255,7 +2259,7 @@ class Context(object):
Contains:
prec - precision (for use in rounding, division, square roots..)
- rounding - rounding type. (how you round)
+ rounding - rounding type (how you round)
_rounding_decision - ALWAYS_ROUND, NEVER_ROUND -- do you round?
traps - If traps[exception] = 1, then the exception is
raised when it is caused. Otherwise, a value is
@@ -2294,9 +2298,13 @@ class Context(object):
def __repr__(self):
"""Show the current context."""
s = []
- s.append('Context(prec=%(prec)d, rounding=%(rounding)s, Emin=%(Emin)d, Emax=%(Emax)d, capitals=%(capitals)d' % vars(self))
- s.append('flags=[' + ', '.join([f.__name__ for f, v in self.flags.items() if v]) + ']')
- s.append('traps=[' + ', '.join([t.__name__ for t, v in self.traps.items() if v]) + ']')
+ s.append('Context(prec=%(prec)d, rounding=%(rounding)s, '
+ 'Emin=%(Emin)d, Emax=%(Emax)d, capitals=%(capitals)d'
+ % vars(self))
+ names = [f.__name__ for f, v in self.flags.items() if v]
+ s.append('flags=[' + ', '.join(names) + ']')
+ names = [t.__name__ for t, v in self.traps.items() if v]
+ s.append('traps=[' + ', '.join(names) + ']')
return ', '.join(s) + ')'
def clear_flags(self):
@@ -2313,9 +2321,9 @@ class Context(object):
def copy(self):
"""Returns a deep copy from self."""
- nc = Context(self.prec, self.rounding, self.traps.copy(), self.flags.copy(),
- self._rounding_decision, self.Emin, self.Emax,
- self.capitals, self._clamp, self._ignored_flags)
+ nc = Context(self.prec, self.rounding, self.traps.copy(),
+ self.flags.copy(), self._rounding_decision, self.Emin,
+ self.Emax, self.capitals, self._clamp, self._ignored_flags)
return nc
__copy__ = copy
@@ -2329,16 +2337,16 @@ class Context(object):
"""
error = _condition_map.get(condition, condition)
if error in self._ignored_flags:
- #Don't touch the flag
+ # Don't touch the flag
return error().handle(self, *args)
self.flags[error] += 1
if not self.traps[error]:
- #The errors define how to handle themselves.
+ # The errors define how to handle themselves.
return condition().handle(self, *args)
# Errors should only be risked on copies of the context
- #self._ignored_flags = []
+ # self._ignored_flags = []
raise error, explanation
def _ignore_all_flags(self):
@@ -2362,7 +2370,7 @@ class Context(object):
def __hash__(self):
"""A Context cannot be hashed."""
# We inherit object.__hash__, so we must deny this explicitly
- raise TypeError, "Cannot hash a Context."
+ raise TypeError("Cannot hash a Context.")
def Etiny(self):
"""Returns Etiny (= Emin - prec + 1)"""
@@ -2417,12 +2425,12 @@ class Context(object):
d = Decimal(num, context=self)
return d._fix(self)
- #Methods
+ # Methods
def abs(self, a):
"""Returns the absolute value of the operand.
If the operand is negative, the result is the same as using the minus
- operation on the operand. Otherwise, the result is the same as using
+ operation on the operand. Otherwise, the result is the same as using
the plus operation on the operand.
>>> ExtendedContext.abs(Decimal('2.1'))
@@ -2524,8 +2532,8 @@ class Context(object):
If either operand is a NaN then the general rules apply.
Otherwise, the operands are compared as as though by the compare
- operation. If they are numerically equal then the left-hand operand
- is chosen as the result. Otherwise the maximum (closer to positive
+ operation. If they are numerically equal then the left-hand operand
+ is chosen as the result. Otherwise the maximum (closer to positive
infinity) of the two operands is chosen as the result.
>>> ExtendedContext.max(Decimal('3'), Decimal('2'))
@@ -2544,8 +2552,8 @@ class Context(object):
If either operand is a NaN then the general rules apply.
Otherwise, the operands are compared as as though by the compare
- operation. If they are numerically equal then the left-hand operand
- is chosen as the result. Otherwise the minimum (closer to negative
+ operation. If they are numerically equal then the left-hand operand
+ is chosen as the result. Otherwise the minimum (closer to negative
infinity) of the two operands is chosen as the result.
>>> ExtendedContext.min(Decimal('3'), Decimal('2'))
@@ -2634,14 +2642,14 @@ class Context(object):
The right-hand operand must be a whole number whose integer part (after
any exponent has been applied) has no more than 9 digits and whose
- fractional part (if any) is all zeros before any rounding. The operand
+ fractional part (if any) is all zeros before any rounding. The operand
may be positive, negative, or zero; if negative, the absolute value of
the power is used, and the left-hand operand is inverted (divided into
1) before use.
If the increased precision needed for the intermediate calculations
- exceeds the capabilities of the implementation then an Invalid operation
- condition is raised.
+ exceeds the capabilities of the implementation then an Invalid
+ operation condition is raised.
If, when raising to a negative power, an underflow occurs during the
division into 1, the operation is not halted at that point but
@@ -2679,18 +2687,18 @@ class Context(object):
return a.__pow__(b, modulo, context=self)
def quantize(self, a, b):
- """Returns a value equal to 'a' (rounded) and having the exponent of 'b'.
+ """Returns a value equal to 'a' (rounded), having the exponent of 'b'.
The coefficient of the result is derived from that of the left-hand
- operand. It may be rounded using the current rounding setting (if the
+ operand. It may be rounded using the current rounding setting (if the
exponent is being increased), multiplied by a positive power of ten (if
the exponent is being decreased), or is unchanged (if the exponent is
already equal to that of the right-hand operand).
Unlike other operations, if the length of the coefficient after the
quantize operation would be greater than precision then an Invalid
- operation condition is raised. This guarantees that, unless there is an
- error condition, the exponent of the result of a quantize is always
+ operation condition is raised. This guarantees that, unless there is
+ an error condition, the exponent of the result of a quantize is always
equal to that of the right-hand operand.
Also unlike other operations, quantize will never raise Underflow, even
@@ -2733,9 +2741,9 @@ class Context(object):
"""Returns the remainder from integer division.
The result is the residue of the dividend after the operation of
- calculating integer division as described for divide-integer, rounded to
- precision digits if necessary. The sign of the result, if non-zero, is
- the same as that of the original dividend.
+ calculating integer division as described for divide-integer, rounded
+ to precision digits if necessary. The sign of the result, if
+ non-zero, is the same as that of the original dividend.
This operation will fail under the same conditions as integer division
(that is, if integer division on the same two operands would fail, the
@@ -2759,7 +2767,7 @@ class Context(object):
def remainder_near(self, a, b):
"""Returns to be "a - b * n", where n is the integer nearest the exact
value of "x / b" (if two integers are equally near then the even one
- is chosen). If the result is equal to 0 then its sign will be the
+ is chosen). If the result is equal to 0 then its sign will be the
sign of a.
This operation will fail under the same conditions as integer division
@@ -2801,7 +2809,7 @@ class Context(object):
return a.same_quantum(b)
def sqrt(self, a):
- """Returns the square root of a non-negative number to context precision.
+ """Square root of a non-negative number to context precision.
If the result must be inexact, it is rounded using the round-half-even
algorithm.
@@ -2862,7 +2870,7 @@ class Context(object):
as using the quantize() operation using the given operand as the
left-hand-operand, 1E+0 as the right-hand-operand, and the precision
of the operand as the precision setting, except that no flags will
- be set. The rounding mode is taken from the context.
+ be set. The rounding mode is taken from the context.
>>> ExtendedContext.to_integral(Decimal('2.1'))
Decimal("2")
@@ -2937,8 +2945,9 @@ def _normalize(op1, op2, shouldround = 0, prec = 0):
other_len = len(str(other.int))
if numdigits > (other_len + prec + 1 - tmp_len):
# If the difference in adjusted exps is > prec+1, we know
- # other is insignificant, so might as well put a 1 after the precision.
- # (since this is only for addition.) Also stops use of massive longs.
+ # other is insignificant, so might as well put a 1 after the
+ # precision (since this is only for addition). Also stops
+ # use of massive longs.
extend = prec + 2 - tmp_len
if extend <= 0:
@@ -2961,13 +2970,13 @@ def _adjust_coefficients(op1, op2):
Used on _WorkRep instances during division.
"""
adjust = 0
- #If op1 is smaller, make it larger
+ # If op1 is smaller, make it larger
while op2.int > op1.int:
op1.int *= 10
op1.exp -= 1
adjust += 1
- #If op2 is too small, make it larger
+ # If op2 is too small, make it larger
while op1.int >= (10 * op2.int):
op2.int *= 10
op2.exp -= 1
@@ -2975,7 +2984,7 @@ def _adjust_coefficients(op1, op2):
return op1, op2, adjust
-##### Helper Functions ########################################
+##### Helper Functions ####################################################
def _convert_other(other):
"""Convert other to Decimal.
@@ -3016,16 +3025,16 @@ def _isnan(num):
if not num:
return 0
- #get the sign, get rid of trailing [+-]
+ # Get the sign, get rid of trailing [+-]
sign = 0
if num[0] == '+':
num = num[1:]
- elif num[0] == '-': #elif avoids '+-nan'
+ elif num[0] == '-': # elif avoids '+-nan'
num = num[1:]
sign = 1
if num.startswith('nan'):
- if len(num) > 3 and not num[3:].isdigit(): #diagnostic info
+ if len(num) > 3 and not num[3:].isdigit(): # diagnostic info
return 0
return (1, sign, num[3:].lstrip('0'))
if num.startswith('snan'):
@@ -3035,7 +3044,7 @@ def _isnan(num):
return 0
-##### Setup Specific Contexts ################################
+##### Setup Specific Contexts ############################################
# The default context prototype used by Context()
# Is mutable, so that new contexts can have different default values
@@ -3068,19 +3077,19 @@ ExtendedContext = Context(
)
-##### Useful Constants (internal use only) ####################
+##### Useful Constants (internal use only) ################################
-#Reusable defaults
+# Reusable defaults
Inf = Decimal('Inf')
negInf = Decimal('-Inf')
-#Infsign[sign] is infinity w/ that sign
+# Infsign[sign] is infinity w/ that sign
Infsign = (Inf, negInf)
NaN = Decimal('NaN')
-##### crud for parsing strings #################################
+##### crud for parsing strings #############################################
import re
# There's an optional sign at the start, and an optional exponent
@@ -3100,13 +3109,15 @@ _parser = re.compile(r"""
([eE](?P<exp>[-+]? \d+))?
# \s*
$
-""", re.VERBOSE).match #Uncomment the \s* to allow leading or trailing spaces.
+""", re.VERBOSE).match # Uncomment the \s* to allow leading or trailing spaces.
del re
-# return sign, n, p s.t. float string value == -1**sign * n * 10**p exactly
-
def _string2exact(s):
+ """Return sign, n, p s.t.
+
+ Float string value == -1**sign * n * 10**p exactly
+ """
m = _parser(s)
if m is None:
raise ValueError("invalid literal for Decimal: %r" % s)
diff --git a/Lib/difflib.py b/Lib/difflib.py
index 2a057d9..5b42d07 100644
--- a/Lib/difflib.py
+++ b/Lib/difflib.py
@@ -1946,8 +1946,7 @@ class HtmlDiff(object):
fromlist,tolist,flaglist,next_href,next_id = self._convert_flags(
fromlist,tolist,flaglist,context,numlines)
- import cStringIO
- s = cStringIO.StringIO()
+ s = []
fmt = ' <tr><td class="diff_next"%s>%s</td>%s' + \
'<td class="diff_next">%s</td>%s</tr>\n'
for i in range(len(flaglist)):
@@ -1955,9 +1954,9 @@ class HtmlDiff(object):
# mdiff yields None on separator lines skip the bogus ones
# generated for the first line
if i > 0:
- s.write(' </tbody> \n <tbody>\n')
+ s.append(' </tbody> \n <tbody>\n')
else:
- s.write( fmt % (next_id[i],next_href[i],fromlist[i],
+ s.append( fmt % (next_id[i],next_href[i],fromlist[i],
next_href[i],tolist[i]))
if fromdesc or todesc:
header_row = '<thead><tr>%s%s%s%s</tr></thead>' % (
@@ -1969,7 +1968,7 @@ class HtmlDiff(object):
header_row = ''
table = self._table_template % dict(
- data_rows=s.getvalue(),
+ data_rows=''.join(s),
header_row=header_row,
prefix=self._prefix[1])
diff --git a/Lib/distutils/__init__.py b/Lib/distutils/__init__.py
index 21d34c7..86ad44f 100644
--- a/Lib/distutils/__init__.py
+++ b/Lib/distutils/__init__.py
@@ -20,4 +20,4 @@ __revision__ = "$Id$"
# In general, major and minor version should loosely follow the Python
# version number the distutils code was shipped with.
#
-__version__ = "2.5.0"
+__version__ = "2.5.1"
diff --git a/Lib/distutils/command/build_ext.py b/Lib/distutils/command/build_ext.py
index d0cd162..2832d57 100644
--- a/Lib/distutils/command/build_ext.py
+++ b/Lib/distutils/command/build_ext.py
@@ -186,7 +186,7 @@ class build_ext (Command):
# for extensions under Cygwin and AtheOS Python's library directory must be
# appended to library_dirs
if sys.platform[:6] == 'cygwin' or sys.platform[:6] == 'atheos':
- if sys.executable.find(sys.exec_prefix) != -1:
+ if sys.executable.startswith(os.path.join(sys.exec_prefix, "bin")):
# building third party extensions
self.library_dirs.append(os.path.join(sys.prefix, "lib",
"python" + get_python_version(),
@@ -199,7 +199,7 @@ class build_ext (Command):
# Python's library directory must be appended to library_dirs
if (sys.platform.startswith('linux') or sys.platform.startswith('gnu')) \
and sysconfig.get_config_var('Py_ENABLE_SHARED'):
- if sys.executable.find(sys.exec_prefix) != -1:
+ if sys.executable.startswith(os.path.join(sys.exec_prefix, "bin")):
# building third party extensions
self.library_dirs.append(sysconfig.get_config_var('LIBDIR'))
else:
@@ -533,7 +533,8 @@ class build_ext (Command):
if self.swig_cpp:
log.warn("--swig-cpp is deprecated - use --swig-opts=-c++")
- if self.swig_cpp or ('-c++' in self.swig_opts):
+ if self.swig_cpp or ('-c++' in self.swig_opts) or \
+ ('-c++' in extension.swig_opts):
target_ext = '.cpp'
else:
target_ext = '.c'
diff --git a/Lib/distutils/msvccompiler.py b/Lib/distutils/msvccompiler.py
index ca1feaa..07c76f1 100644
--- a/Lib/distutils/msvccompiler.py
+++ b/Lib/distutils/msvccompiler.py
@@ -187,6 +187,19 @@ def get_build_architecture():
j = sys.version.find(")", i)
return sys.version[i+len(prefix):j]
+def normalize_and_reduce_paths(paths):
+ """Return a list of normalized paths with duplicates removed.
+
+ The current order of paths is maintained.
+ """
+ # Paths are normalized so things like: /a and /a/ aren't both preserved.
+ reduced_paths = []
+ for p in paths:
+ np = os.path.normpath(p)
+ # XXX(nnorwitz): O(n**2), if reduced_paths gets long perhaps use a set.
+ if np not in reduced_paths:
+ reduced_paths.append(np)
+ return reduced_paths
class MSVCCompiler (CCompiler) :
@@ -270,7 +283,8 @@ class MSVCCompiler (CCompiler) :
self.__paths.append(p)
except KeyError:
pass
- os.environ['path'] = ';'.join(self.__paths)
+ self.__paths = normalize_and_reduce_paths(self.__paths)
+ os.environ['path'] = ";".join(self.__paths)
self.preprocess_options = None
if self.__arch == "Intel":
diff --git a/Lib/doctest.py b/Lib/doctest.py
index 5ee4d85..2671cc6 100644
--- a/Lib/doctest.py
+++ b/Lib/doctest.py
@@ -2625,8 +2625,23 @@ __test__ = {"_TestClass": _TestClass,
}
def _test():
- r = unittest.TextTestRunner()
- r.run(DocTestSuite())
+ testfiles = [arg for arg in sys.argv[1:] if arg and arg[0] != '-']
+ if testfiles:
+ for filename in testfiles:
+ if filename.endswith(".py"):
+ # It is a module -- insert its dir into sys.path and try to
+ # import it. If it is part of a package, that possibly won't work
+ # because of package imports.
+ dirname, filename = os.path.split(filename)
+ sys.path.insert(0, dirname)
+ m = __import__(filename[:-3])
+ del sys.path[0]
+ testmod(m)
+ else:
+ testfile(filename, module_relative=False)
+ else:
+ r = unittest.TextTestRunner()
+ r.run(DocTestSuite())
if __name__ == "__main__":
_test()
diff --git a/Lib/email/_parseaddr.py b/Lib/email/_parseaddr.py
index 8047df2..81913a3 100644
--- a/Lib/email/_parseaddr.py
+++ b/Lib/email/_parseaddr.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2002-2006 Python Software Foundation
+# Copyright (C) 2002-2007 Python Software Foundation
# Contact: email-sig@python.org
"""Email address parsing code.
@@ -172,6 +172,7 @@ class AddrlistClass:
self.pos = 0
self.LWS = ' \t'
self.CR = '\r\n'
+ self.FWS = self.LWS + self.CR
self.atomends = self.specials + self.LWS + self.CR
# Note that RFC 2822 now specifies `.' as obs-phrase, meaning that it
# is obsolete syntax. RFC 2822 requires that we recognize obsolete
@@ -418,7 +419,7 @@ class AddrlistClass:
plist = []
while self.pos < len(self.field):
- if self.field[self.pos] in self.LWS:
+ if self.field[self.pos] in self.FWS:
self.pos += 1
elif self.field[self.pos] == '"':
plist.append(self.getquote())
diff --git a/Lib/email/header.py b/Lib/email/header.py
index 3de44f9..ab0d3fc 100644
--- a/Lib/email/header.py
+++ b/Lib/email/header.py
@@ -39,7 +39,8 @@ ecre = re.compile(r'''
\? # literal ?
(?P<encoded>.*?) # non-greedy up to the next ?= is the encoded string
\?= # literal ?=
- ''', re.VERBOSE | re.IGNORECASE)
+ (?=[ \t]|$) # whitespace or the end of the string
+ ''', re.VERBOSE | re.IGNORECASE | re.MULTILINE)
# Field name regexp, including trailing colon, but not separating whitespace,
# according to RFC 2822. Character range is from tilde to exclamation mark.
diff --git a/Lib/email/message.py b/Lib/email/message.py
index 9d25cb0..6fc3af1 100644
--- a/Lib/email/message.py
+++ b/Lib/email/message.py
@@ -238,7 +238,7 @@ class Message:
self.del_param('charset')
self._charset = None
return
- if isinstance(charset, str):
+ if isinstance(charset, basestring):
charset = email.charset.Charset(charset)
if not isinstance(charset, email.charset.Charset):
raise TypeError(charset)
@@ -756,7 +756,9 @@ class Message:
charset = charset[2]
# charset character must be in us-ascii range
try:
- charset = unicode(charset, 'us-ascii').encode('us-ascii')
+ if isinstance(charset, str):
+ charset = unicode(charset, 'us-ascii')
+ charset = charset.encode('us-ascii')
except UnicodeError:
return failobj
# RFC 2046, $4.1.2 says charsets are not case sensitive
diff --git a/Lib/email/test/test_email.py b/Lib/email/test/test_email.py
index c3269d7..a2e09fa 100644
--- a/Lib/email/test/test_email.py
+++ b/Lib/email/test/test_email.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2001-2006 Python Software Foundation
+# Copyright (C) 2001-2007 Python Software Foundation
# Contact: email-sig@python.org
# email package unit tests
@@ -501,6 +501,13 @@ class TestMessageAPI(TestEmailBase):
msg.set_payload(x)
self.assertEqual(msg.get_payload(decode=True), x)
+ def test_get_content_charset(self):
+ msg = Message()
+ msg.set_charset('us-ascii')
+ self.assertEqual('us-ascii', msg.get_content_charset())
+ msg.set_charset(u'us-ascii')
+ self.assertEqual('us-ascii', msg.get_content_charset())
+
# Test the email.Encoders module
@@ -1519,6 +1526,18 @@ class TestRFC2047(unittest.TestCase):
hu = make_header(dh).__unicode__()
eq(hu, u'The quick brown fox jumped over the lazy dog')
+ def test_rfc2047_without_whitespace(self):
+ s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [(s, None)])
+
+ def test_rfc2047_with_whitespace(self):
+ s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [('Sm', None), ('\xf6', 'iso-8859-1'),
+ ('rg', None), ('\xe5', 'iso-8859-1'),
+ ('sbord', None)])
+
# Test the MIMEMessage class
@@ -2164,6 +2183,12 @@ class TestMiscellaneous(TestEmailBase):
# formataddr() quotes the name if there's a dot in it
self.assertEqual(Utils.formataddr((a, b)), y)
+ def test_multiline_from_comment(self):
+ x = """\
+Foo
+\tBar <foo@example.com>"""
+ self.assertEqual(Utils.parseaddr(x), ('Foo Bar', 'foo@example.com'))
+
def test_quote_dump(self):
self.assertEqual(
Utils.formataddr(('A Silly; Person', 'person@dom.ain')),
diff --git a/Lib/email/test/test_email_renamed.py b/Lib/email/test/test_email_renamed.py
index 21061b0..7f72270 100644
--- a/Lib/email/test/test_email_renamed.py
+++ b/Lib/email/test/test_email_renamed.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2001-2006 Python Software Foundation
+# Copyright (C) 2001-2007 Python Software Foundation
# Contact: email-sig@python.org
# email package unit tests
@@ -1524,6 +1524,18 @@ class TestRFC2047(unittest.TestCase):
hu = make_header(dh).__unicode__()
eq(hu, u'The quick brown fox jumped over the lazy dog')
+ def test_rfc2047_missing_whitespace(self):
+ s = 'Sm=?ISO-8859-1?B?9g==?=rg=?ISO-8859-1?B?5Q==?=sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [(s, None)])
+
+ def test_rfc2047_with_whitespace(self):
+ s = 'Sm =?ISO-8859-1?B?9g==?= rg =?ISO-8859-1?B?5Q==?= sbord'
+ dh = decode_header(s)
+ self.assertEqual(dh, [('Sm', None), ('\xf6', 'iso-8859-1'),
+ ('rg', None), ('\xe5', 'iso-8859-1'),
+ ('sbord', None)])
+
# Test the MIMEMessage class
@@ -2170,6 +2182,12 @@ class TestMiscellaneous(TestEmailBase):
# formataddr() quotes the name if there's a dot in it
self.assertEqual(utils.formataddr((a, b)), y)
+ def test_multiline_from_comment(self):
+ x = """\
+Foo
+\tBar <foo@example.com>"""
+ self.assertEqual(utils.parseaddr(x), ('Foo Bar', 'foo@example.com'))
+
def test_quote_dump(self):
self.assertEqual(
utils.formataddr(('A Silly; Person', 'person@dom.ain')),
diff --git a/Lib/ftplib.py b/Lib/ftplib.py
index 85e3cc9..cdc893b 100644
--- a/Lib/ftplib.py
+++ b/Lib/ftplib.py
@@ -76,9 +76,15 @@ class FTP:
'''An FTP client class.
- To create a connection, call the class using these argument:
- host, user, passwd, acct
- These are all strings, and have default value ''.
+ To create a connection, call the class using these arguments:
+ host, user, passwd, acct, timeout
+
+ The first four arguments are all strings, and have default value ''.
+ timeout must be numeric and defaults to None if not passed,
+ meaning that no timeout will be set on any ftp socket(s)
+ If a timeout is passed, then this is now the default timeout for all ftp
+ socket operations for this instance.
+
Then use self.connect() with optional host and port argument.
To download a file, use ftp.retrlines('RETR ' + filename),
@@ -102,33 +108,26 @@ class FTP:
# Initialize host to localhost, port to standard ftp port
# Optional arguments are host (for connect()),
# and user, passwd, acct (for login())
- def __init__(self, host='', user='', passwd='', acct=''):
+ def __init__(self, host='', user='', passwd='', acct='', timeout=None):
+ self.timeout = timeout
if host:
self.connect(host)
- if user: self.login(user, passwd, acct)
+ if user:
+ self.login(user, passwd, acct)
- def connect(self, host = '', port = 0):
+ def connect(self, host='', port=0, timeout=None):
'''Connect to host. Arguments are:
- - host: hostname to connect to (string, default previous host)
- - port: port to connect to (integer, default previous port)'''
- if host: self.host = host
- if port: self.port = port
- msg = "getaddrinfo returns an empty list"
- for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- self.sock = socket.socket(af, socktype, proto)
- self.sock.connect(sa)
- except socket.error as err:
- msg = err
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error, msg
- self.af = af
+ - host: hostname to connect to (string, default previous host)
+ - port: port to connect to (integer, default previous port)
+ '''
+ if host != '':
+ self.host = host
+ if port > 0:
+ self.port = port
+ if timeout is not None:
+ self.timeout = timeout
+ self.sock = socket.create_connection((self.host, self.port), self.timeout)
+ self.af = self.sock.family
self.file = self.sock.makefile('rb')
self.welcome = self.getresp()
return self.welcome
diff --git a/Lib/genericpath.py b/Lib/genericpath.py
index 1574cef..e2bc7da 100644
--- a/Lib/genericpath.py
+++ b/Lib/genericpath.py
@@ -75,3 +75,32 @@ def commonprefix(m):
if s1[i] != s2[i]:
return s1[:i]
return s1[:n]
+
+# Split a path in root and extension.
+# The extension is everything starting at the last dot in the last
+# pathname component; the root is everything before that.
+# It is always true that root + ext == p.
+
+# Generic implementation of splitext, to be parametrized with
+# the separators
+def _splitext(p, sep, altsep, extsep):
+ """Split the extension from a pathname.
+
+ Extension is everything from the last dot to the end, ignoring
+ leading dots. Returns "(root, ext)"; ext may be empty."""
+
+ sepIndex = p.rfind(sep)
+ if altsep:
+ altsepIndex = p.rfind(altsep)
+ sepIndex = max(sepIndex, altsepIndex)
+
+ dotIndex = p.rfind(extsep)
+ if dotIndex > sepIndex:
+ # skip all leading dots
+ filenameIndex = sepIndex + 1
+ while filenameIndex < dotIndex:
+ if p[filenameIndex] != extsep:
+ return p[:dotIndex], p[dotIndex:]
+ filenameIndex += 1
+
+ return p, ''
diff --git a/Lib/glob.py b/Lib/glob.py
index 95656cc..75d7bf9 100644
--- a/Lib/glob.py
+++ b/Lib/glob.py
@@ -1,8 +1,9 @@
"""Filename globbing utility."""
+import sys
import os
-import fnmatch
import re
+import fnmatch
__all__ = ["glob", "iglob"]
@@ -48,13 +49,16 @@ def iglob(pathname):
def glob1(dirname, pattern):
if not dirname:
dirname = os.curdir
+ if isinstance(pattern, unicode) and not isinstance(dirname, unicode):
+ dirname = unicode(dirname, sys.getfilesystemencoding() or
+ sys.getdefaultencoding())
try:
names = os.listdir(dirname)
except os.error:
return []
- if pattern[0]!='.':
- names=filter(lambda x: x[0]!='.',names)
- return fnmatch.filter(names,pattern)
+ if pattern[0] != '.':
+ names = filter(lambda x: x[0] != '.', names)
+ return fnmatch.filter(names, pattern)
def glob0(dirname, basename):
if basename == '':
diff --git a/Lib/heapq.py b/Lib/heapq.py
index 6ee26d1..d34ea3b 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -311,7 +311,7 @@ except ImportError:
def merge(*iterables):
'''Merge multiple sorted inputs into a single sorted output.
- Similar to sorted(itertools.chain(*iterables)) but returns an iterable,
+ Similar to sorted(itertools.chain(*iterables)) but returns a generator,
does not pull the data into memory all at once, and assumes that each of
the input streams is already sorted (smallest to largest).
diff --git a/Lib/httplib.py b/Lib/httplib.py
index 89d5392..84401ac 100644
--- a/Lib/httplib.py
+++ b/Lib/httplib.py
@@ -625,7 +625,8 @@ class HTTPConnection:
debuglevel = 0
strict = 0
- def __init__(self, host, port=None, strict=None):
+ def __init__(self, host, port=None, strict=None, timeout=None):
+ self.timeout = timeout
self.sock = None
self._buffer = []
self.__response = None
@@ -658,25 +659,8 @@ class HTTPConnection:
def connect(self):
"""Connect to the host and port specified in __init__."""
- msg = "getaddrinfo returns an empty list"
- for res in socket.getaddrinfo(self.host, self.port, 0,
- socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- self.sock = socket.socket(af, socktype, proto)
- if self.debuglevel > 0:
- print("connect: (%s, %s)" % (self.host, self.port))
- self.sock.connect(sa)
- except socket.error as msg:
- if self.debuglevel > 0:
- print('connect fail:', (self.host, self.port))
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error, msg
+ self.sock = socket.create_connection((self.host,self.port),
+ self.timeout)
def close(self):
"""Close the connection to the HTTP server."""
@@ -948,8 +932,8 @@ class HTTPConnection:
self.__state = _CS_IDLE
if response.will_close:
- # Pass the socket to the response
- self.sock = None
+ # this effectively passes the connection to the response
+ self.close()
else:
# remember this, so we can tell when it is complete
self.__response = response
diff --git a/Lib/idlelib/MultiCall.py b/Lib/idlelib/MultiCall.py
index 61730b8..4311999 100644
--- a/Lib/idlelib/MultiCall.py
+++ b/Lib/idlelib/MultiCall.py
@@ -350,6 +350,8 @@ def MultiCallCreator(widget):
triplets.append(triplet)
def event_delete(self, virtual, *sequences):
+ if virtual not in self.__eventinfo:
+ return
func, triplets = self.__eventinfo[virtual]
for seq in sequences:
triplet = _parse_sequence(seq)
diff --git a/Lib/imaplib.py b/Lib/imaplib.py
index fcf68ef..2df533f 100644
--- a/Lib/imaplib.py
+++ b/Lib/imaplib.py
@@ -746,8 +746,10 @@ class IMAP4:
if not command in Commands:
raise self.error("Unknown IMAP4 UID command: %s" % command)
if self.state not in Commands[command]:
- raise self.error('command %s illegal in state %s'
- % (command, self.state))
+ raise self.error("command %s illegal in state %s, "
+ "only allowed in states %s" %
+ (command, self.state,
+ ', '.join(Commands[command])))
name = 'UID'
typ, dat = self._simple_command(name, command, *args)
if command in ('SEARCH', 'SORT'):
@@ -811,8 +813,10 @@ class IMAP4:
if self.state not in Commands[name]:
self.literal = None
- raise self.error(
- 'command %s illegal in state %s' % (name, self.state))
+ raise self.error("command %s illegal in state %s, "
+ "only allowed in states %s" %
+ (name, self.state,
+ ', '.join(Commands[name])))
for typ in ('OK', 'NO', 'BAD'):
if typ in self.untagged_responses:
diff --git a/Lib/logging/handlers.py b/Lib/logging/handlers.py
index 83bf3e3..71ec9c3 100644
--- a/Lib/logging/handlers.py
+++ b/Lib/logging/handlers.py
@@ -365,12 +365,14 @@ class SocketHandler(logging.Handler):
self.retryMax = 30.0
self.retryFactor = 2.0
- def makeSocket(self):
+ def makeSocket(self, timeout=1):
"""
A factory method which allows subclasses to define the precise
type of socket they want.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ if hasattr(s, 'settimeout'):
+ s.settimeout(timeout)
s.connect((self.host, self.port))
return s
diff --git a/Lib/macpath.py b/Lib/macpath.py
index d389d70..f54ffa0 100644
--- a/Lib/macpath.py
+++ b/Lib/macpath.py
@@ -2,6 +2,7 @@
import os
from stat import *
+import genericpath
from genericpath import *
__all__ = ["normcase","isabs","join","splitdrive","split","splitext",
@@ -69,17 +70,8 @@ def split(s):
def splitext(p):
- """Split a path into root and extension.
- The extension is everything starting at the last dot in the last
- pathname component; the root is everything before that.
- It is always true that root + ext == p."""
-
- i = p.rfind('.')
- if i<=p.rfind(':'):
- return p, ''
- else:
- return p[:i], p[i:]
-
+ return genericpath._splitext(p, sep, altsep, extsep)
+splitext.__doc__ = genericpath._splitext.__doc__
def splitdrive(p):
"""Split a pathname into a drive specification and the rest of the
diff --git a/Lib/ntpath.py b/Lib/ntpath.py
index 23d5127..99d7a4a 100644
--- a/Lib/ntpath.py
+++ b/Lib/ntpath.py
@@ -8,6 +8,7 @@ module as os.path.
import os
import stat
import sys
+import genericpath
from genericpath import *
__all__ = ["normcase","isabs","join","splitdrive","split","splitext",
@@ -15,7 +16,7 @@ __all__ = ["normcase","isabs","join","splitdrive","split","splitext",
"getatime","getctime", "islink","exists","lexists","isdir","isfile",
"ismount","walk","expanduser","expandvars","normpath","abspath",
"splitunc","curdir","pardir","sep","pathsep","defpath","altsep",
- "extsep","devnull","realpath","supports_unicode_filenames"]
+ "extsep","devnull","realpath","supports_unicode_filenames","relpath"]
# strings representing various path-related bits and pieces
curdir = '.'
@@ -182,16 +183,8 @@ def split(p):
# It is always true that root + ext == p.
def splitext(p):
- """Split the extension from a pathname.
-
- Extension is everything from the last dot to the end.
- Return (root, ext), either part may be empty."""
-
- i = p.rfind('.')
- if i<=max(p.rfind('/'), p.rfind('\\')):
- return p, ''
- else:
- return p[:i], p[i:]
+ return genericpath._splitext(p, sep, altsep, extsep)
+splitext.__doc__ = genericpath._splitext.__doc__
# Return the tail (basename) part of a path.
@@ -285,36 +278,44 @@ def expanduser(path):
i, n = 1, len(path)
while i < n and path[i] not in '/\\':
i = i + 1
- if i == 1:
- if 'HOME' in os.environ:
- userhome = os.environ['HOME']
- elif not 'HOMEPATH' in os.environ:
- return path
- else:
- try:
- drive = os.environ['HOMEDRIVE']
- except KeyError:
- drive = ''
- userhome = join(drive, os.environ['HOMEPATH'])
- else:
+
+ if 'HOME' in os.environ:
+ userhome = os.environ['HOME']
+ elif 'USERPROFILE' in os.environ:
+ userhome = os.environ['USERPROFILE']
+ elif not 'HOMEPATH' in os.environ:
return path
+ else:
+ try:
+ drive = os.environ['HOMEDRIVE']
+ except KeyError:
+ drive = ''
+ userhome = join(drive, os.environ['HOMEPATH'])
+
+ if i != 1: #~user
+ userhome = join(dirname(userhome), path[1:i])
+
return userhome + path[i:]
# Expand paths containing shell variable substitutions.
# The following rules apply:
# - no expansion within single quotes
-# - no escape character, except for '$$' which is translated into '$'
+# - '$$' is translated into '$'
+# - '%%' is translated into '%' if '%%' are not seen in %var1%%var2%
# - ${varname} is accepted.
-# - varnames can be made out of letters, digits and the character '_'
+# - $varname is accepted.
+# - %varname% is accepted.
+# - varnames can be made out of letters, digits and the characters '_-'
+# (though is not verifed in the ${varname} and %varname% cases)
# XXX With COMMAND.COM you can use any characters in a variable name,
# XXX except '^|<>='.
def expandvars(path):
- """Expand shell variables of form $var and ${var}.
+ """Expand shell variables of the forms $var, ${var} and %var%.
Unknown variables are left unchanged."""
- if '$' not in path:
+ if '$' not in path and '%' not in path:
return path
import string
varchars = string.ascii_letters + string.digits + '_-'
@@ -332,6 +333,24 @@ def expandvars(path):
except ValueError:
res = res + path
index = pathlen - 1
+ elif c == '%': # variable or '%'
+ if path[index + 1:index + 2] == '%':
+ res = res + c
+ index = index + 1
+ else:
+ path = path[index+1:]
+ pathlen = len(path)
+ try:
+ index = path.index('%')
+ except ValueError:
+ res = res + '%' + path
+ index = pathlen - 1
+ else:
+ var = path[:index]
+ if var in os.environ:
+ res = res + os.environ[var]
+ else:
+ res = res + '%' + var + '%'
elif c == '$': # variable or '$$'
if path[index + 1:index + 2] == '$':
res = res + c
@@ -446,3 +465,29 @@ realpath = abspath
# Win9x family and earlier have no Unicode filename support.
supports_unicode_filenames = (hasattr(sys, "getwindowsversion") and
sys.getwindowsversion()[3] >= 2)
+
+def relpath(path, start=curdir):
+ """Return a relative version of a path"""
+
+ if not path:
+ raise ValueError("no path specified")
+ start_list = abspath(start).split(sep)
+ path_list = abspath(path).split(sep)
+ if start_list[0].lower() != path_list[0].lower():
+ unc_path, rest = splitunc(path)
+ unc_start, rest = splitunc(start)
+ if bool(unc_path) ^ bool(unc_start):
+ raise ValueError("Cannot mix UNC and non-UNC paths (%s and %s)"
+ % (path, start))
+ else:
+ raise ValueError("path is on drive %s, start on drive %s"
+ % (path_list[0], start_list[0]))
+ # Work out how much of the filepath is shared by start and path.
+ for i in range(min(len(start_list), len(path_list))):
+ if start_list[i].lower() != path_list[i].lower():
+ break
+ else:
+ i += 1
+
+ rel_list = [pardir] * (len(start_list)-i) + path_list[i:]
+ return join(*rel_list)
diff --git a/Lib/os.py b/Lib/os.py
index 9599cf3..fdb3bed 100644
--- a/Lib/os.py
+++ b/Lib/os.py
@@ -221,7 +221,7 @@ def renames(old, new):
__all__.extend(["makedirs", "removedirs", "renames"])
-def walk(top, topdown=True, onerror=None):
+def walk(top, topdown=True, onerror=None, followlinks=False):
"""Directory tree generator.
For each directory in the directory tree rooted at top (including top
@@ -257,6 +257,10 @@ def walk(top, topdown=True, onerror=None):
to abort the walk. Note that the filename is available as the
filename attribute of the exception object.
+ By default, os.walk does not follow symbolic links to subdirectories on
+ systems that support them. In order to get this functionality, set the
+ optional argument 'followlinks' to true.
+
Caution: if you pass a relative pathname for top, don't change the
current working directory between resumptions of walk. walk never
changes the current directory, and assumes that the client doesn't
@@ -300,8 +304,8 @@ def walk(top, topdown=True, onerror=None):
yield top, dirs, nondirs
for name in dirs:
path = join(top, name)
- if not islink(path):
- for x in walk(path, topdown, onerror):
+ if followlinks or not islink(path):
+ for x in walk(path, topdown, onerror, followlinks):
yield x
if not topdown:
yield top, dirs, nondirs
diff --git a/Lib/pdb.doc b/Lib/pdb.doc
index 81df323..c513954 100644
--- a/Lib/pdb.doc
+++ b/Lib/pdb.doc
@@ -131,6 +131,12 @@ n(ext)
r(eturn)
Continue execution until the current function returns.
+run [args...]
+ Restart the debugged python program. If a string is supplied it is
+ splitted with "shlex", and the result is used as the new sys.argv.
+ History, breakpoints, actions and debugger options are preserved.
+ "restart" is an alias for "run".
+
c(ont(inue))
Continue execution, only stop when a breakpoint is encountered.
diff --git a/Lib/pdb.py b/Lib/pdb.py
index 4eba7bb..3d06b0a 100755
--- a/Lib/pdb.py
+++ b/Lib/pdb.py
@@ -13,6 +13,12 @@ import os
import re
import pprint
import traceback
+
+
+class Restart(Exception):
+ """Causes a debugger to be restarted for the debugged python program."""
+ pass
+
# Create a custom safe Repr instance and increase its maxstring.
# The default of 30 truncates error messages too easily.
_repr = Repr()
@@ -484,11 +490,16 @@ class Pdb(bdb.Bdb, cmd.Cmd):
except ValueError:
# something went wrong
print('Breakpoint index %r is not a number' % args[0], file=self.stdout)
+ return
try:
cond = args[1]
except:
cond = None
- bp = bdb.Breakpoint.bpbynumber[bpnum]
+ try:
+ bp = bdb.Breakpoint.bpbynumber[bpnum]
+ except IndexError:
+ print >>self.stdout, 'Breakpoint index %r is not valid' % args[0]
+ return
if bp:
bp.cond = cond
if not cond:
@@ -503,11 +514,16 @@ class Pdb(bdb.Bdb, cmd.Cmd):
except ValueError:
# something went wrong
print('Breakpoint index %r is not a number' % args[0], file=self.stdout)
+ return
try:
count = int(args[1].strip())
except:
count = 0
- bp = bdb.Breakpoint.bpbynumber[bpnum]
+ try:
+ bp = bdb.Breakpoint.bpbynumber[bpnum]
+ except IndexError:
+ print >>self.stdout, 'Breakpoint index %r is not valid' % args[0]
+ return
if bp:
bp.ignore = count
if count > 0:
@@ -601,6 +617,18 @@ class Pdb(bdb.Bdb, cmd.Cmd):
return 1
do_n = do_next
+ def do_run(self, arg):
+ """Restart program by raising an exception to be caught in the main debugger
+ loop. If arguments were given, set them in sys.argv."""
+ if arg:
+ import shlex
+ argv0 = sys.argv[0:1]
+ sys.argv = shlex.split(arg)
+ sys.argv[:0] = argv0
+ raise Restart
+
+ do_restart = do_run
+
def do_return(self, arg):
self.set_return(self.curframe)
return 1
@@ -1005,6 +1033,15 @@ command with a 'global' command, e.g.:
(Pdb) global list_options; list_options = ['-l']
(Pdb)""", file=self.stdout)
+ def help_run(self):
+ print("""run [args...]
+Restart the debugged python program. If a string is supplied, it is
+splitted with "shlex" and the result is used as the new sys.argv.
+History, breakpoints, actions and debugger options are preserved.
+"restart" is an alias for "run".""")
+
+ help_restart = help_run
+
def help_quit(self):
self.help_q()
@@ -1113,11 +1150,17 @@ see no sign that the breakpoint was reached.
return None
def _runscript(self, filename):
- # Start with fresh empty copy of globals and locals and tell the script
- # that it's being run as __main__ to avoid scripts being able to access
- # the pdb.py namespace.
- globals_ = {"__name__" : "__main__"}
- locals_ = globals_
+ # The script has to run in __main__ namespace (or imports from
+ # __main__ will break).
+ #
+ # So we clear up the __main__ and set several special variables
+ # (this gets rid of pdb's globals and cleans old variables on restarts).
+ import __main__
+ __main__.__dict__.clear()
+ __main__.__dict__.update({"__name__" : "__main__",
+ "__file__" : filename,
+ "__builtins__": __builtins__,
+ })
# When bdb sets tracing, a number of call and line events happens
# BEFORE debugger even reaches user's code (and the exact sequence of
@@ -1128,7 +1171,7 @@ see no sign that the breakpoint was reached.
self.mainpyfile = self.canonic(filename)
self._user_requested_quit = 0
statement = 'execfile( "%s")' % filename
- self.run(statement, globals=globals_, locals=locals_)
+ self.run(statement)
# Simplified interface
@@ -1197,9 +1240,8 @@ def main():
# Note on saving/restoring sys.argv: it's a good idea when sys.argv was
# modified by the script being debugged. It's a bad idea when it was
- # changed by the user from the command line. The best approach would be to
- # have a "restart" command which would allow explicit specification of
- # command line arguments.
+ # changed by the user from the command line. There is a "restart" command which
+ # allows explicit specification of command line arguments.
pdb = Pdb()
while 1:
try:
@@ -1207,6 +1249,9 @@ def main():
if pdb._user_requested_quit:
break
print("The program finished and will be restarted")
+ except Restart:
+ print("Restarting", mainpyfile, "with arguments:")
+ print("\t" + " ".join(sys.argv[1:]))
except SystemExit:
# In most cases SystemExit does not warrant a post-mortem session.
print("The program exited via sys.exit(). Exit status: ", end=' ')
@@ -1223,5 +1268,6 @@ def main():
# When invoked as main program, invoke the debugger on a script
-if __name__=='__main__':
- main()
+if __name__ == '__main__':
+ import pdb
+ pdb.main()
diff --git a/Lib/popen2.py b/Lib/popen2.py
index 3618487..ab30463 100644
--- a/Lib/popen2.py
+++ b/Lib/popen2.py
@@ -200,45 +200,3 @@ else:
return inst.fromchild, inst.tochild
__all__.extend(["Popen3", "Popen4"])
-
-def _test():
- # When the test runs, there shouldn't be any open pipes
- _cleanup()
- assert not _active, "Active pipes when test starts " + repr([c.cmd for c in _active])
- cmd = "cat"
- teststr = "ab cd\n"
- if os.name == "nt":
- cmd = "more"
- # "more" doesn't act the same way across Windows flavors,
- # sometimes adding an extra newline at the start or the
- # end. So we strip whitespace off both ends for comparison.
- expected = teststr.strip()
- print("testing popen2...")
- r, w = popen2(cmd)
- w.write(teststr)
- w.close()
- got = r.read()
- if got.strip() != expected:
- raise ValueError("wrote %r read %r" % (teststr, got))
- print("testing popen3...")
- try:
- r, w, e = popen3([cmd])
- except:
- r, w, e = popen3(cmd)
- w.write(teststr)
- w.close()
- got = r.read()
- if got.strip() != expected:
- raise ValueError("wrote %r read %r" % (teststr, got))
- got = e.read()
- if got:
- raise ValueError("unexpected %r on stderr" % (got,))
- for inst in _active[:]:
- inst.wait()
- _cleanup()
- if _active:
- raise ValueError("_active not empty")
- print("All OK")
-
-if __name__ == '__main__':
- _test()
diff --git a/Lib/poplib.py b/Lib/poplib.py
index adf784f..0caed18 100644
--- a/Lib/poplib.py
+++ b/Lib/poplib.py
@@ -76,24 +76,10 @@ class POP3:
"""
- def __init__(self, host, port = POP3_PORT):
+ def __init__(self, host, port=POP3_PORT, timeout=None):
self.host = host
self.port = port
- msg = "getaddrinfo returns an empty list"
- self.sock = None
- for res in socket.getaddrinfo(self.host, self.port, 0, socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- self.sock = socket.socket(af, socktype, proto)
- self.sock.connect(sa)
- except socket.error as msg:
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error, msg
+ self.sock = socket.create_connection((host, port), timeout)
self.file = self.sock.makefile('rb')
self._debugging = 0
self.welcome = self._getresp()
diff --git a/Lib/posixpath.py b/Lib/posixpath.py
index 1521236..6f15d48 100644
--- a/Lib/posixpath.py
+++ b/Lib/posixpath.py
@@ -12,6 +12,7 @@ for manipulation of the pathname component of URLs.
import os
import stat
+import genericpath
from genericpath import *
__all__ = ["normcase","isabs","join","splitdrive","split","splitext",
@@ -20,7 +21,7 @@ __all__ = ["normcase","isabs","join","splitdrive","split","splitext",
"ismount","walk","expanduser","expandvars","normpath","abspath",
"samefile","sameopenfile","samestat",
"curdir","pardir","sep","pathsep","defpath","altsep","extsep",
- "devnull","realpath","supports_unicode_filenames"]
+ "devnull","realpath","supports_unicode_filenames","relpath"]
# strings representing various path-related bits and pieces
curdir = '.'
@@ -88,14 +89,8 @@ def split(p):
# It is always true that root + ext == p.
def splitext(p):
- """Split the extension from a pathname. Extension is everything from the
- last dot to the end. Returns "(root, ext)", either part may be empty."""
- i = p.rfind('.')
- if i<=p.rfind('/'):
- return p, ''
- else:
- return p[:i], p[i:]
-
+ return genericpath._splitext(p, sep, altsep, extsep)
+splitext.__doc__ = genericpath._splitext.__doc__
# Split a pathname into a drive specification and the rest of the
# path. Useful on DOS/Windows/NT; on Unix, the drive is always empty.
@@ -387,3 +382,18 @@ def _resolve_link(path):
return path
supports_unicode_filenames = False
+
+def relpath(path, start=curdir):
+ """Return a relative version of a path"""
+
+ if not path:
+ raise ValueError("no path specified")
+
+ start_list = abspath(start).split(sep)
+ path_list = abspath(path).split(sep)
+
+ # Work out how much of the filepath is shared by start and path.
+ i = len(commonprefix([start_list, path_list]))
+
+ rel_list = [pardir] * (len(start_list)-i) + path_list[i:]
+ return join(*rel_list)
diff --git a/Lib/pydoc.py b/Lib/pydoc.py
index 59c4593..6a272ce 100755
--- a/Lib/pydoc.py
+++ b/Lib/pydoc.py
@@ -854,7 +854,7 @@ class HTMLDoc(Doc):
if imclass is not cl:
note = ' from ' + self.classlink(imclass, mod)
else:
- if object.im_self:
+ if object.im_self is not None:
note = ' method of %s instance' % self.classlink(
object.im_self.__class__, mod)
else:
@@ -1232,7 +1232,7 @@ class TextDoc(Doc):
if imclass is not cl:
note = ' from ' + classname(imclass, mod)
else:
- if object.im_self:
+ if object.im_self is not None:
note = ' method of %s instance' % classname(
object.im_self.__class__, mod)
else:
@@ -1468,6 +1468,27 @@ def resolve(thing, forceload=0):
else:
return thing, getattr(thing, '__name__', None)
+def render_doc(thing, title='Python Library Documentation: %s', forceload=0):
+ """Render text documentation, given an object or a path to an object."""
+ object, name = resolve(thing, forceload)
+ desc = describe(object)
+ module = inspect.getmodule(object)
+ if name and '.' in name:
+ desc += ' in ' + name[:name.rfind('.')]
+ elif module and module is not object:
+ desc += ' in module ' + module.__name__
+ elif not (inspect.ismodule(object) or
+ inspect.isclass(object) or
+ inspect.isroutine(object) or
+ inspect.isgetsetdescriptor(object) or
+ inspect.ismemberdescriptor(object) or
+ isinstance(object, property)):
+ # If the passed object is a piece of data or an instance,
+ # document its available methods instead of its value.
+ object = type(object)
+ desc += ' object'
+ return title % desc + '\n\n' + text.document(object, name)
+
def doc(thing, title='Python Library Documentation: %s', forceload=0):
"""Display text documentation, given an object or a path to an object."""
try:
@@ -1488,7 +1509,7 @@ def doc(thing, title='Python Library Documentation: %s', forceload=0):
# document its available methods instead of its value.
object = type(object)
desc += ' object'
- pager(title % desc + '\n\n' + text.document(object, name))
+ pager(render_doc(thing, title, forceload))
except (ImportError, ErrorDuringImport) as value:
print(value)
@@ -1519,6 +1540,7 @@ def raw_input(prompt):
class Helper:
keywords = {
'and': 'BOOLEAN',
+ 'as': 'with',
'assert': ('ref/assert', ''),
'break': ('ref/break', 'while for'),
'class': ('ref/class', 'CLASSES SPECIALMETHODS'),
@@ -1546,6 +1568,7 @@ class Helper:
'return': ('ref/return', 'FUNCTIONS'),
'try': ('ref/try', 'EXCEPTIONS'),
'while': ('ref/while', 'break continue if TRUTHVALUE'),
+ 'with': ('ref/with', 'CONTEXTMANAGERS EXCEPTIONS yield'),
'yield': ('ref/yield', ''),
}
@@ -1626,6 +1649,7 @@ class Helper:
'LOOPING': ('ref/compound', 'for while break continue'),
'TRUTHVALUE': ('lib/truth', 'if while and or not BASICMETHODS'),
'DEBUGGING': ('lib/module-pdb', 'pdb'),
+ 'CONTEXTMANAGERS': ('ref/context-managers', 'with'),
}
def __init__(self, input, output):
@@ -1634,16 +1658,21 @@ class Helper:
self.docdir = None
execdir = os.path.dirname(sys.executable)
homedir = os.environ.get('PYTHONHOME')
+ join = os.path.join
for dir in [os.environ.get('PYTHONDOCS'),
homedir and os.path.join(homedir, 'doc'),
- os.path.join(execdir, 'doc'),
- '/usr/doc/python-docs-' + sys.version.split()[0],
- '/usr/doc/python-' + sys.version.split()[0],
- '/usr/doc/python-docs-' + sys.version[:3],
- '/usr/doc/python-' + sys.version[:3],
- os.path.join(sys.prefix, 'Resources/English.lproj/Documentation')]:
- if dir and os.path.isdir(os.path.join(dir, 'lib')):
+ join(execdir, 'doc'), # for Windows
+ join(sys.prefix, 'doc/python-docs-' + sys.version.split()[0]),
+ join(sys.prefix, 'doc/python-' + sys.version.split()[0]),
+ join(sys.prefix, 'doc/python-docs-' + sys.version[:3]),
+ join(sys.prefix, 'doc/python-' + sys.version[:3]),
+ join(sys.prefix, 'Resources/English.lproj/Documentation')]:
+ if dir and os.path.isdir(join(dir, 'lib')):
self.docdir = dir
+ break
+ if dir and os.path.isdir(join(dir, 'html', 'lib')):
+ self.docdir = join(dir, 'html')
+ break
def __repr__(self):
if inspect.stack()[1][3] == '?':
diff --git a/Lib/rexec.py b/Lib/rexec.py
index e5ceb72..c4ce1d0 100644
--- a/Lib/rexec.py
+++ b/Lib/rexec.py
@@ -29,7 +29,7 @@ __all__ = ["RExec"]
class FileBase:
ok_file_methods = ('fileno', 'flush', 'isatty', 'read', 'readline',
- 'readlines', 'seek', 'tell', 'write', 'writelines',
+ 'readlines', 'seek', 'tell', 'write', 'writelines',
'__iter__')
@@ -181,7 +181,7 @@ class RExec(ihooks._Verbose):
"""
- raise RuntimeError, "This code is not secure in Python 2.2 and 2.3"
+ raise RuntimeError, "This code is not secure in Python 2.2 and later"
ihooks._Verbose.__init__(self, verbose)
# XXX There's a circular reference here:
diff --git a/Lib/robotparser.py b/Lib/robotparser.py
index 0edb55f..32aba46 100644
--- a/Lib/robotparser.py
+++ b/Lib/robotparser.py
@@ -65,7 +65,7 @@ class RobotFileParser:
lines.append(line.strip())
line = f.readline()
self.errcode = opener.errcode
- if self.errcode == 401 or self.errcode == 403:
+ if self.errcode in (401, 403):
self.disallow_all = True
_debug("disallow all")
elif self.errcode >= 400:
@@ -168,10 +168,7 @@ class RobotFileParser:
def __str__(self):
- ret = ""
- for entry in self.entries:
- ret = ret + str(entry) + "\n"
- return ret
+ return ''.join([str(entry) + "\n" for entry in self.entries])
class RuleLine:
@@ -198,12 +195,12 @@ class Entry:
self.rulelines = []
def __str__(self):
- ret = ""
+ ret = []
for agent in self.useragents:
- ret = ret + "User-agent: "+agent+"\n"
+ ret.extend(["User-agent: ", agent, "\n"])
for line in self.rulelines:
- ret = ret + str(line) + "\n"
- return ret
+ ret.extend([str(line), "\n"])
+ return ''.join(ret)
def applies_to(self, useragent):
"""check if this entry applies to the specified agent"""
diff --git a/Lib/sched.py b/Lib/sched.py
index 2f8df05..7c3235e 100644
--- a/Lib/sched.py
+++ b/Lib/sched.py
@@ -72,7 +72,7 @@ class scheduler:
def empty(self):
"""Check whether the queue is empty."""
- return not not self.queue
+ return not self.queue
def run(self):
"""Execute events until the queue is empty.
diff --git a/Lib/site.py b/Lib/site.py
index 48cf385..c4e0d51 100644
--- a/Lib/site.py
+++ b/Lib/site.py
@@ -134,7 +134,7 @@ def addpackage(sitedir, name, known_paths):
for line in f:
if line.startswith("#"):
continue
- if line.startswith("import"):
+ if line.startswith("import ") or line.startswith("import\t"):
exec(line)
continue
line = line.rstrip()
diff --git a/Lib/smtplib.py b/Lib/smtplib.py
index 9851d08..299a70d 100755
--- a/Lib/smtplib.py
+++ b/Lib/smtplib.py
@@ -226,10 +226,11 @@ class SMTP:
debuglevel = 0
file = None
helo_resp = None
+ ehlo_msg = "ehlo"
ehlo_resp = None
does_esmtp = 0
- def __init__(self, host = '', port = 0, local_hostname = None):
+ def __init__(self, host='', port=0, local_hostname=None, timeout=None):
"""Initialize a new instance.
If specified, `host' is the name of the remote host to which to
@@ -240,6 +241,7 @@ class SMTP:
the local hostname is found using socket.getfqdn().
"""
+ self.timeout = timeout
self.esmtp_features = {}
self.default_port = SMTP_PORT
if host:
@@ -273,12 +275,11 @@ class SMTP:
"""
self.debuglevel = debuglevel
- def _get_socket(self,af, socktype, proto,sa):
+ def _get_socket(self, port, host, timeout):
# This makes it simpler for SMTP_SSL to use the SMTP connect code
# and just alter the socket connection bit.
- self.sock = socket.socket(af, socktype, proto)
if self.debuglevel > 0: print('connect:', (host, port), file=stderr)
- self.sock.connect(sa)
+ return socket.create_connection((port, host), timeout)
def connect(self, host='localhost', port = 0):
"""Connect to a host on a given port.
@@ -297,24 +298,10 @@ class SMTP:
host, port = host[:i], host[i+1:]
try: port = int(port)
except ValueError:
- raise socket.error, "nonnumeric port"
+ raise socket.error("nonnumeric port")
if not port: port = self.default_port
if self.debuglevel > 0: print('connect:', (host, port), file=stderr)
- msg = "getaddrinfo returns an empty list"
- self.sock = None
- for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- self._get_socket(af,socktype,proto,sa)
- except socket.error as msg:
- if self.debuglevel > 0: print('connect fail:', msg, file=stderr)
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error, msg
+ self.sock = self._get_socket(host, port, self.timeout)
(code, msg) = self.getreply()
if self.debuglevel > 0: print("connect:", msg, file=stderr)
return (code, msg)
@@ -401,7 +388,7 @@ class SMTP:
host.
"""
self.esmtp_features = {}
- self.putcmd("ehlo", name or self.local_hostname)
+ self.putcmd(self.ehlo_msg, name or self.local_hostname)
(code,msg)=self.getreply()
# According to RFC1869 some (badly written)
# MTA's will disconnect on an ehlo. Toss an exception if
@@ -731,21 +718,64 @@ class SMTP_SSL(SMTP):
are also optional - they can contain a PEM formatted private key and
certificate chain file for the SSL connection.
"""
- def __init__(self, host = '', port = 0, local_hostname = None,
- keyfile = None, certfile = None):
+ def __init__(self, host='', port=0, local_hostname=None,
+ keyfile=None, certfile=None, timeout=None):
self.keyfile = keyfile
self.certfile = certfile
- SMTP.__init__(self,host,port,local_hostname)
+ SMTP.__init__(self, host, port, local_hostname, timeout)
self.default_port = SMTP_SSL_PORT
- def _get_socket(self,af, socktype, proto,sa):
- self.sock = socket.socket(af, socktype, proto)
+ def _get_socket(self, host, port, timeout):
if self.debuglevel > 0: print('connect:', (host, port), file=stderr)
- self.sock.connect(sa)
+ self.sock = socket.create_connection((host, port), timeout)
sslobj = socket.ssl(self.sock, self.keyfile, self.certfile)
self.sock = SSLFakeSocket(self.sock, sslobj)
self.file = SSLFakeFile(sslobj)
+#
+# LMTP extension
+#
+LMTP_PORT = 2003
+
+class LMTP(SMTP):
+ """LMTP - Local Mail Transfer Protocol
+
+ The LMTP protocol, which is very similar to ESMTP, is heavily based
+ on the standard SMTP client. It's common to use Unix sockets for LMTP,
+ so our connect() method must support that as well as a regular
+ host:port server. To specify a Unix socket, you must use an absolute
+ path as the host, starting with a '/'.
+
+ Authentication is supported, using the regular SMTP mechanism. When
+ using a Unix socket, LMTP generally don't support or require any
+ authentication, but your mileage might vary."""
+
+ ehlo_msg = "lhlo"
+
+ def __init__(self, host = '', port = LMTP_PORT, local_hostname = None):
+ """Initialize a new instance."""
+ SMTP.__init__(self, host, port, local_hostname)
+
+ def connect(self, host = 'localhost', port = 0):
+ """Connect to the LMTP daemon, on either a Unix or a TCP socket."""
+ if host[0] != '/':
+ return SMTP.connect(self, host, port)
+
+ # Handle Unix-domain sockets.
+ try:
+ self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+ self.sock.connect(host)
+ except socket.error as msg:
+ if self.debuglevel > 0: print>>stderr, 'connect fail:', host
+ if self.sock:
+ self.sock.close()
+ self.sock = None
+ raise socket.error(msg)
+ (code, msg) = self.getreply()
+ if self.debuglevel > 0: print>>stderr, "connect:", msg
+ return (code, msg)
+
+
# Test the sendmail method, which tests most of the others.
# Note: This always sends to localhost.
if __name__ == '__main__':
diff --git a/Lib/socket.py b/Lib/socket.py
index 3fe6ec5..8dd2383 100644
--- a/Lib/socket.py
+++ b/Lib/socket.py
@@ -24,6 +24,7 @@ inet_ntoa() -- convert 32-bit packed format IP to string (123.45.67.89)
ssl() -- secure socket layer support (only available if configured)
socket.getdefaulttimeout() -- get the default timeout value
socket.setdefaulttimeout() -- set the default timeout value
+create_connection() -- connects to an address, with an optional timeout
[*] not available on all platforms!
@@ -139,8 +140,6 @@ class _closedsocket(object):
__slots__ = []
def _dummy(*args):
raise error(EBADF, 'Bad file descriptor')
- def close(self):
- pass
# All _delegate_methods must also be initialized here.
send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
__getattr__ = _dummy
@@ -159,7 +158,6 @@ class _socketobject(object):
setattr(self, method, getattr(_sock, method))
def close(self):
- self._sock.close()
self._sock = _closedsocket()
dummy = self._sock._dummy
for method in _delegate_methods:
@@ -414,3 +412,32 @@ class _fileobject(object):
if not line:
raise StopIteration
return line
+
+
+def create_connection(address, timeout=None):
+ """Connect to address (host, port) with an optional timeout.
+
+ Provides access to socketobject timeout for higher-level
+ protocols. Passing a timeout will set the timeout on the
+ socket instance (if not present, or passed as None, the
+ default global timeout setting will be used).
+ """
+
+ msg = "getaddrinfo returns an empty list"
+ host, port = address
+ for res in getaddrinfo(host, port, 0, SOCK_STREAM):
+ af, socktype, proto, canonname, sa = res
+ sock = None
+ try:
+ sock = socket(af, socktype, proto)
+ if timeout is not None:
+ sock.settimeout(timeout)
+ sock.connect(sa)
+ return sock
+
+ except error as err:
+ msg = err
+ if sock is not None:
+ sock.close()
+
+ raise error(msg)
diff --git a/Lib/sre.py b/Lib/sre.py
index 390094a..c04576b 100644
--- a/Lib/sre.py
+++ b/Lib/sre.py
@@ -8,3 +8,6 @@ warnings.warn("The sre module is deprecated, please import re.",
from re import *
from re import __all__
+
+# old pickles expect the _compile() reconstructor in this module
+from re import _compile
diff --git a/Lib/subprocess.py b/Lib/subprocess.py
index e9c9a0e..2aa02ae 100644
--- a/Lib/subprocess.py
+++ b/Lib/subprocess.py
@@ -597,7 +597,7 @@ class Popen(object):
# either have to redirect all three or none. If the subprocess
# user has only redirected one or two handles, we are
# automatically creating PIPEs for the rest. We should close
- # these after the process is started. See bug #1124861.
+ # these after the process is started. See bug #1124861.
if mswindows:
if stdin is None and p2cwrite is not None:
os.close(p2cwrite)
@@ -629,7 +629,7 @@ class Popen(object):
return data
- def __del__(self):
+ def __del__(self, sys=sys):
if not self._child_created:
# We didn't get to successfully create a child process.
return
diff --git a/Lib/tarfile.py b/Lib/tarfile.py
index 963127c..efade27 100644
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -33,7 +33,7 @@
__version__ = "$Revision$"
# $Source$
-version = "0.8.0"
+version = "0.9.0"
__author__ = "Lars Gustäbel (lars@gustaebel.de)"
__date__ = "$Date$"
__cvsid__ = "$Id$"
@@ -50,6 +50,7 @@ import errno
import time
import struct
import copy
+import re
if sys.platform == 'mac':
# This module needs work for MacOS9, especially in the area of pathname
@@ -71,42 +72,60 @@ from __builtin__ import open as _open # Since 'open' is TarFile.open
#---------------------------------------------------------
# tar constants
#---------------------------------------------------------
-NUL = "\0" # the null character
-BLOCKSIZE = 512 # length of processing blocks
+NUL = "\0" # the null character
+BLOCKSIZE = 512 # length of processing blocks
RECORDSIZE = BLOCKSIZE * 20 # length of records
-MAGIC = "ustar" # magic tar string
-VERSION = "00" # version number
+GNU_MAGIC = "ustar \0" # magic gnu tar string
+POSIX_MAGIC = "ustar\x0000" # magic posix tar string
-LENGTH_NAME = 100 # maximum length of a filename
-LENGTH_LINK = 100 # maximum length of a linkname
-LENGTH_PREFIX = 155 # maximum length of the prefix field
-MAXSIZE_MEMBER = 077777777777 # maximum size of a file (11 octal digits)
+LENGTH_NAME = 100 # maximum length of a filename
+LENGTH_LINK = 100 # maximum length of a linkname
+LENGTH_PREFIX = 155 # maximum length of the prefix field
-REGTYPE = "0" # regular file
+REGTYPE = "0" # regular file
AREGTYPE = "\0" # regular file
-LNKTYPE = "1" # link (inside tarfile)
-SYMTYPE = "2" # symbolic link
-CHRTYPE = "3" # character special device
-BLKTYPE = "4" # block special device
-DIRTYPE = "5" # directory
+LNKTYPE = "1" # link (inside tarfile)
+SYMTYPE = "2" # symbolic link
+CHRTYPE = "3" # character special device
+BLKTYPE = "4" # block special device
+DIRTYPE = "5" # directory
FIFOTYPE = "6" # fifo special device
CONTTYPE = "7" # contiguous file
-GNUTYPE_LONGNAME = "L" # GNU tar extension for longnames
-GNUTYPE_LONGLINK = "K" # GNU tar extension for longlink
-GNUTYPE_SPARSE = "S" # GNU tar extension for sparse file
+GNUTYPE_LONGNAME = "L" # GNU tar longname
+GNUTYPE_LONGLINK = "K" # GNU tar longlink
+GNUTYPE_SPARSE = "S" # GNU tar sparse file
+
+XHDTYPE = "x" # POSIX.1-2001 extended header
+XGLTYPE = "g" # POSIX.1-2001 global header
+SOLARIS_XHDTYPE = "X" # Solaris extended header
+
+USTAR_FORMAT = 0 # POSIX.1-1988 (ustar) format
+GNU_FORMAT = 1 # GNU tar format
+PAX_FORMAT = 2 # POSIX.1-2001 (pax) format
+DEFAULT_FORMAT = GNU_FORMAT
#---------------------------------------------------------
# tarfile constants
#---------------------------------------------------------
-SUPPORTED_TYPES = (REGTYPE, AREGTYPE, LNKTYPE, # file types that tarfile
- SYMTYPE, DIRTYPE, FIFOTYPE, # can cope with.
+# File types that tarfile supports:
+SUPPORTED_TYPES = (REGTYPE, AREGTYPE, LNKTYPE,
+ SYMTYPE, DIRTYPE, FIFOTYPE,
CONTTYPE, CHRTYPE, BLKTYPE,
GNUTYPE_LONGNAME, GNUTYPE_LONGLINK,
GNUTYPE_SPARSE)
-REGULAR_TYPES = (REGTYPE, AREGTYPE, # file types that somehow
- CONTTYPE, GNUTYPE_SPARSE) # represent regular files
+# File types that will be treated as a regular file.
+REGULAR_TYPES = (REGTYPE, AREGTYPE,
+ CONTTYPE, GNUTYPE_SPARSE)
+
+# File types that are part of the GNU tar format.
+GNU_TYPES = (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK,
+ GNUTYPE_SPARSE)
+
+# Fields from a pax header that override a TarInfo attribute.
+PAX_FIELDS = ("path", "linkpath", "size", "mtime",
+ "uid", "gid", "uname", "gname")
#---------------------------------------------------------
# Bits used in the mode field, values in octal.
@@ -133,6 +152,13 @@ TOWRITE = 0002 # write by other
TOEXEC = 0001 # execute/search by other
#---------------------------------------------------------
+# initialization
+#---------------------------------------------------------
+ENCODING = sys.getfilesystemencoding()
+if ENCODING is None:
+ ENCODING = "ascii"
+
+#---------------------------------------------------------
# Some useful functions
#---------------------------------------------------------
@@ -141,6 +167,15 @@ def stn(s, length):
"""
return s[:length] + (length - len(s)) * NUL
+def nts(s):
+ """Convert a null-terminated string field to a python string.
+ """
+ # Use the string up to the first null char.
+ p = s.find("\0")
+ if p == -1:
+ return s
+ return s[:p]
+
def nti(s):
"""Convert a number field to a python number.
"""
@@ -148,7 +183,7 @@ def nti(s):
# itn() below.
if s[0] != chr(0200):
try:
- n = int(s.rstrip(NUL + " ") or "0", 8)
+ n = int(nts(s) or "0", 8)
except ValueError:
raise HeaderError("invalid header")
else:
@@ -158,7 +193,7 @@ def nti(s):
n += ord(s[i + 1])
return n
-def itn(n, digits=8, posix=False):
+def itn(n, digits=8, format=DEFAULT_FORMAT):
"""Convert a python number to a number field.
"""
# POSIX 1003.1-1988 requires numbers to be encoded as a string of
@@ -170,7 +205,7 @@ def itn(n, digits=8, posix=False):
if 0 <= n < 8 ** (digits - 1):
s = "%0*o" % (digits - 1, n) + NUL
else:
- if posix:
+ if format != GNU_FORMAT or n >= 256 ** (digits - 1):
raise ValueError("overflow in number field")
if n < 0:
@@ -516,7 +551,10 @@ class _Stream:
buf = self.__read(self.bufsize)
if not buf:
break
- buf = self.cmp.decompress(buf)
+ try:
+ buf = self.cmp.decompress(buf)
+ except IOError:
+ raise ReadError("invalid compressed data")
t.append(buf)
c += len(buf)
t = "".join(t)
@@ -577,6 +615,7 @@ class _BZ2Proxy(object):
def __init__(self, fileobj, mode):
self.fileobj = fileobj
self.mode = mode
+ self.name = getattr(self.fileobj, "name", None)
self.init()
def init(self):
@@ -849,8 +888,8 @@ class TarInfo(object):
"""Construct a TarInfo object. name is the optional name
of the member.
"""
- self.name = name # member name (dirnames must end with '/')
- self.mode = 0666 # file permissions
+ self.name = name # member name
+ self.mode = 0644 # file permissions
self.uid = 0 # user id
self.gid = 0 # group id
self.size = 0 # file size
@@ -858,17 +897,274 @@ class TarInfo(object):
self.chksum = 0 # header checksum
self.type = REGTYPE # member type
self.linkname = "" # link name
- self.uname = "user" # user name
- self.gname = "group" # group name
+ self.uname = "root" # user name
+ self.gname = "root" # group name
self.devmajor = 0 # device major number
self.devminor = 0 # device minor number
self.offset = 0 # the tar header starts here
self.offset_data = 0 # the file's data starts here
+ self.pax_headers = {} # pax header information
+
+ # In pax headers the "name" and "linkname" field are called
+ # "path" and "linkpath".
+ def _getpath(self):
+ return self.name
+ def _setpath(self, name):
+ self.name = name
+ path = property(_getpath, _setpath)
+
+ def _getlinkpath(self):
+ return self.linkname
+ def _setlinkpath(self, linkname):
+ self.linkname = linkname
+ linkpath = property(_getlinkpath, _setlinkpath)
+
def __repr__(self):
return "<%s %r at %#x>" % (self.__class__.__name__,self.name,id(self))
+ def get_info(self):
+ """Return the TarInfo's attributes as a dictionary.
+ """
+ info = {
+ "name": normpath(self.name),
+ "mode": self.mode & 07777,
+ "uid": self.uid,
+ "gid": self.gid,
+ "size": self.size,
+ "mtime": self.mtime,
+ "chksum": self.chksum,
+ "type": self.type,
+ "linkname": normpath(self.linkname) if self.linkname else "",
+ "uname": self.uname,
+ "gname": self.gname,
+ "devmajor": self.devmajor,
+ "devminor": self.devminor
+ }
+
+ if info["type"] == DIRTYPE and not info["name"].endswith("/"):
+ info["name"] += "/"
+
+ return info
+
+ def tobuf(self, format=DEFAULT_FORMAT, encoding=ENCODING):
+ """Return a tar header as a string of 512 byte blocks.
+ """
+ if format == USTAR_FORMAT:
+ return self.create_ustar_header()
+ elif format == GNU_FORMAT:
+ return self.create_gnu_header()
+ elif format == PAX_FORMAT:
+ return self.create_pax_header(encoding)
+ else:
+ raise ValueError("invalid format")
+
+ def create_ustar_header(self):
+ """Return the object as a ustar header block.
+ """
+ info = self.get_info()
+ info["magic"] = POSIX_MAGIC
+
+ if len(info["linkname"]) > LENGTH_LINK:
+ raise ValueError("linkname is too long")
+
+ if len(info["name"]) > LENGTH_NAME:
+ info["prefix"], info["name"] = self._posix_split_name(info["name"])
+
+ return self._create_header(info, USTAR_FORMAT)
+
+ def create_gnu_header(self):
+ """Return the object as a GNU header block sequence.
+ """
+ info = self.get_info()
+ info["magic"] = GNU_MAGIC
+
+ buf = ""
+ if len(info["linkname"]) > LENGTH_LINK:
+ buf += self._create_gnu_long_header(info["linkname"], GNUTYPE_LONGLINK)
+
+ if len(info["name"]) > LENGTH_NAME:
+ buf += self._create_gnu_long_header(info["name"], GNUTYPE_LONGNAME)
+
+ return buf + self._create_header(info, GNU_FORMAT)
+
+ def create_pax_header(self, encoding):
+ """Return the object as a ustar header block. If it cannot be
+ represented this way, prepend a pax extended header sequence
+ with supplement information.
+ """
+ info = self.get_info()
+ info["magic"] = POSIX_MAGIC
+ pax_headers = self.pax_headers.copy()
+
+ # Test string fields for values that exceed the field length or cannot
+ # be represented in ASCII encoding.
+ for name, hname, length in (
+ ("name", "path", LENGTH_NAME), ("linkname", "linkpath", LENGTH_LINK),
+ ("uname", "uname", 32), ("gname", "gname", 32)):
+
+ val = info[name].decode(encoding)
+
+ # Try to encode the string as ASCII.
+ try:
+ val.encode("ascii")
+ except UnicodeEncodeError:
+ pax_headers[hname] = val
+ continue
+
+ if len(val) > length:
+ if name == "name":
+ # Try to squeeze a longname in the prefix and name fields as in
+ # ustar format.
+ try:
+ info["prefix"], info["name"] = self._posix_split_name(info["name"])
+ except ValueError:
+ pax_headers[hname] = val
+ else:
+ continue
+ else:
+ pax_headers[hname] = val
+
+ # Test number fields for values that exceed the field limit or values
+ # that like to be stored as float.
+ for name, digits in (("uid", 8), ("gid", 8), ("size", 12), ("mtime", 12)):
+ val = info[name]
+ if not 0 <= val < 8 ** (digits - 1) or isinstance(val, float):
+ pax_headers[name] = unicode(val)
+ info[name] = 0
+
+ if pax_headers:
+ buf = self._create_pax_generic_header(pax_headers)
+ else:
+ buf = ""
+
+ return buf + self._create_header(info, USTAR_FORMAT)
+
+ @classmethod
+ def create_pax_global_header(cls, pax_headers, encoding):
+ """Return the object as a pax global header block sequence.
+ """
+ new_headers = {}
+ for key, val in pax_headers.items():
+ key = cls._to_unicode(key, encoding)
+ val = cls._to_unicode(val, encoding)
+ new_headers[key] = val
+ return cls._create_pax_generic_header(new_headers, type=XGLTYPE)
+
+ @staticmethod
+ def _to_unicode(value, encoding):
+ if isinstance(value, unicode):
+ return value
+ elif isinstance(value, (int, float)):
+ return unicode(value)
+ elif isinstance(value, str):
+ return unicode(value, encoding)
+ else:
+ raise ValueError("unable to convert to unicode: %r" % value)
+
+ def _posix_split_name(self, name):
+ """Split a name longer than 100 chars into a prefix
+ and a name part.
+ """
+ prefix = name[:LENGTH_PREFIX + 1]
+ while prefix and prefix[-1] != "/":
+ prefix = prefix[:-1]
+
+ name = name[len(prefix):]
+ prefix = prefix[:-1]
+
+ if not prefix or len(name) > LENGTH_NAME:
+ raise ValueError("name is too long")
+ return prefix, name
+
+ @staticmethod
+ def _create_header(info, format):
+ """Return a header block. info is a dictionary with file
+ information, format must be one of the *_FORMAT constants.
+ """
+ parts = [
+ stn(info.get("name", ""), 100),
+ itn(info.get("mode", 0) & 07777, 8, format),
+ itn(info.get("uid", 0), 8, format),
+ itn(info.get("gid", 0), 8, format),
+ itn(info.get("size", 0), 12, format),
+ itn(info.get("mtime", 0), 12, format),
+ " ", # checksum field
+ info.get("type", REGTYPE),
+ stn(info.get("linkname", ""), 100),
+ stn(info.get("magic", ""), 8),
+ stn(info.get("uname", ""), 32),
+ stn(info.get("gname", ""), 32),
+ itn(info.get("devmajor", 0), 8, format),
+ itn(info.get("devminor", 0), 8, format),
+ stn(info.get("prefix", ""), 155)
+ ]
+
+ buf = struct.pack("%ds" % BLOCKSIZE, "".join(parts))
+ chksum = calc_chksums(buf[-BLOCKSIZE:])[0]
+ buf = buf[:-364] + "%06o\0" % chksum + buf[-357:]
+ return buf
+
+ @staticmethod
+ def _create_payload(payload):
+ """Return the string payload filled with zero bytes
+ up to the next 512 byte border.
+ """
+ blocks, remainder = divmod(len(payload), BLOCKSIZE)
+ if remainder > 0:
+ payload += (BLOCKSIZE - remainder) * NUL
+ return payload
+
+ @classmethod
+ def _create_gnu_long_header(cls, name, type):
+ """Return a GNUTYPE_LONGNAME or GNUTYPE_LONGLINK sequence
+ for name.
+ """
+ name += NUL
+
+ info = {}
+ info["name"] = "././@LongLink"
+ info["type"] = type
+ info["size"] = len(name)
+ info["magic"] = GNU_MAGIC
+
+ # create extended header + name blocks.
+ return cls._create_header(info, USTAR_FORMAT) + \
+ cls._create_payload(name)
+
+ @classmethod
+ def _create_pax_generic_header(cls, pax_headers, type=XHDTYPE):
+ """Return a POSIX.1-2001 extended or global header sequence
+ that contains a list of keyword, value pairs. The values
+ must be unicode objects.
+ """
+ records = []
+ for keyword, value in pax_headers.items():
+ keyword = keyword.encode("utf8")
+ value = value.encode("utf8")
+ l = len(keyword) + len(value) + 3 # ' ' + '=' + '\n'
+ n = p = 0
+ while True:
+ n = l + len(str(p))
+ if n == p:
+ break
+ p = n
+ records.append("%d %s=%s\n" % (p, keyword, value))
+ records = "".join(records)
+
+ # We use a hardcoded "././@PaxHeader" name like star does
+ # instead of the one that POSIX recommends.
+ info = {}
+ info["name"] = "././@PaxHeader"
+ info["type"] = type
+ info["size"] = len(records)
+ info["magic"] = POSIX_MAGIC
+
+ # Create pax header + record blocks.
+ return cls._create_header(info, USTAR_FORMAT) + \
+ cls._create_payload(records)
+
@classmethod
def frombuf(cls, buf):
"""Construct a TarInfo object from a 512 byte string buffer.
@@ -882,125 +1178,251 @@ class TarInfo(object):
if chksum not in calc_chksums(buf):
raise HeaderError("bad checksum")
- tarinfo = cls()
- tarinfo.buf = buf
- tarinfo.name = buf[0:100].rstrip(NUL)
- tarinfo.mode = nti(buf[100:108])
- tarinfo.uid = nti(buf[108:116])
- tarinfo.gid = nti(buf[116:124])
- tarinfo.size = nti(buf[124:136])
- tarinfo.mtime = nti(buf[136:148])
- tarinfo.chksum = chksum
- tarinfo.type = buf[156:157]
- tarinfo.linkname = buf[157:257].rstrip(NUL)
- tarinfo.uname = buf[265:297].rstrip(NUL)
- tarinfo.gname = buf[297:329].rstrip(NUL)
- tarinfo.devmajor = nti(buf[329:337])
- tarinfo.devminor = nti(buf[337:345])
- prefix = buf[345:500].rstrip(NUL)
-
- if prefix and not tarinfo.issparse():
- tarinfo.name = prefix + "/" + tarinfo.name
+ obj = cls()
+ obj.buf = buf
+ obj.name = nts(buf[0:100])
+ obj.mode = nti(buf[100:108])
+ obj.uid = nti(buf[108:116])
+ obj.gid = nti(buf[116:124])
+ obj.size = nti(buf[124:136])
+ obj.mtime = nti(buf[136:148])
+ obj.chksum = chksum
+ obj.type = buf[156:157]
+ obj.linkname = nts(buf[157:257])
+ obj.uname = nts(buf[265:297])
+ obj.gname = nts(buf[297:329])
+ obj.devmajor = nti(buf[329:337])
+ obj.devminor = nti(buf[337:345])
+ prefix = nts(buf[345:500])
+
+ # Old V7 tar format represents a directory as a regular
+ # file with a trailing slash.
+ if obj.type == AREGTYPE and obj.name.endswith("/"):
+ obj.type = DIRTYPE
- return tarinfo
+ # Remove redundant slashes from directories.
+ if obj.isdir():
+ obj.name = obj.name.rstrip("/")
- def tobuf(self, posix=False):
- """Return a tar header as a string of 512 byte blocks.
- """
- buf = ""
- type = self.type
- prefix = ""
+ # Reconstruct a ustar longname.
+ if prefix and obj.type not in GNU_TYPES:
+ obj.name = prefix + "/" + obj.name
+ return obj
- if self.name.endswith("/"):
- type = DIRTYPE
+ @classmethod
+ def fromtarfile(cls, tarfile):
+ """Return the next TarInfo object from TarFile object
+ tarfile.
+ """
+ buf = tarfile.fileobj.read(BLOCKSIZE)
+ if not buf:
+ return
+ obj = cls.frombuf(buf)
+ obj.offset = tarfile.fileobj.tell() - BLOCKSIZE
+ return obj._proc_member(tarfile)
- if type in (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK):
- # Prevent "././@LongLink" from being normalized.
- name = self.name
+ #--------------------------------------------------------------------------
+ # The following are methods that are called depending on the type of a
+ # member. The entry point is _proc_member() which can be overridden in a
+ # subclass to add custom _proc_*() methods. A _proc_*() method MUST
+ # implement the following
+ # operations:
+ # 1. Set self.offset_data to the position where the data blocks begin,
+ # if there is data that follows.
+ # 2. Set tarfile.offset to the position where the next member's header will
+ # begin.
+ # 3. Return self or another valid TarInfo object.
+ def _proc_member(self, tarfile):
+ """Choose the right processing method depending on
+ the type and call it.
+ """
+ if self.type in (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK):
+ return self._proc_gnulong(tarfile)
+ elif self.type == GNUTYPE_SPARSE:
+ return self._proc_sparse(tarfile)
+ elif self.type in (XHDTYPE, XGLTYPE, SOLARIS_XHDTYPE):
+ return self._proc_pax(tarfile)
else:
- name = normpath(self.name)
+ return self._proc_builtin(tarfile)
- if type == DIRTYPE:
- # directories should end with '/'
- name += "/"
+ def _proc_builtin(self, tarfile):
+ """Process a builtin type or an unknown type which
+ will be treated as a regular file.
+ """
+ self.offset_data = tarfile.fileobj.tell()
+ offset = self.offset_data
+ if self.isreg() or self.type not in SUPPORTED_TYPES:
+ # Skip the following data blocks.
+ offset += self._block(self.size)
+ tarfile.offset = offset
- linkname = self.linkname
- if linkname:
- # if linkname is empty we end up with a '.'
- linkname = normpath(linkname)
+ # Patch the TarInfo object with saved extended
+ # header information.
+ for keyword, value in tarfile.pax_headers.items():
+ if keyword in PAX_FIELDS:
+ setattr(self, keyword, value)
+ self.pax_headers[keyword] = value
- if posix:
- if self.size > MAXSIZE_MEMBER:
- raise ValueError("file is too large (>= 8 GB)")
+ return self
- if len(self.linkname) > LENGTH_LINK:
- raise ValueError("linkname is too long (>%d)" % (LENGTH_LINK))
+ def _proc_gnulong(self, tarfile):
+ """Process the blocks that hold a GNU longname
+ or longlink member.
+ """
+ buf = tarfile.fileobj.read(self._block(self.size))
- if len(name) > LENGTH_NAME:
- prefix = name[:LENGTH_PREFIX + 1]
- while prefix and prefix[-1] != "/":
- prefix = prefix[:-1]
+ # Fetch the next header and process it.
+ b = tarfile.fileobj.read(BLOCKSIZE)
+ t = self.frombuf(b)
+ t.offset = self.offset
+ next = t._proc_member(tarfile)
- name = name[len(prefix):]
- prefix = prefix[:-1]
+ # Patch the TarInfo object from the next header with
+ # the longname information.
+ next.offset = self.offset
+ if self.type == GNUTYPE_LONGNAME:
+ next.name = buf.rstrip(NUL)
+ elif self.type == GNUTYPE_LONGLINK:
+ next.linkname = buf.rstrip(NUL)
- if not prefix or len(name) > LENGTH_NAME:
- raise ValueError("name is too long")
+ return next
- else:
- if len(self.linkname) > LENGTH_LINK:
- buf += self._create_gnulong(self.linkname, GNUTYPE_LONGLINK)
+ def _proc_sparse(self, tarfile):
+ """Process a GNU sparse header plus extra headers.
+ """
+ buf = self.buf
+ sp = _ringbuffer()
+ pos = 386
+ lastpos = 0
+ realpos = 0
+ # There are 4 possible sparse structs in the
+ # first header.
+ for i in xrange(4):
+ try:
+ offset = nti(buf[pos:pos + 12])
+ numbytes = nti(buf[pos + 12:pos + 24])
+ except ValueError:
+ break
+ if offset > lastpos:
+ sp.append(_hole(lastpos, offset - lastpos))
+ sp.append(_data(offset, numbytes, realpos))
+ realpos += numbytes
+ lastpos = offset + numbytes
+ pos += 24
- if len(name) > LENGTH_NAME:
- buf += self._create_gnulong(name, GNUTYPE_LONGNAME)
+ isextended = ord(buf[482])
+ origsize = nti(buf[483:495])
- parts = [
- stn(name, 100),
- itn(self.mode & 07777, 8, posix),
- itn(self.uid, 8, posix),
- itn(self.gid, 8, posix),
- itn(self.size, 12, posix),
- itn(self.mtime, 12, posix),
- " ", # checksum field
- type,
- stn(self.linkname, 100),
- stn(MAGIC, 6),
- stn(VERSION, 2),
- stn(self.uname, 32),
- stn(self.gname, 32),
- itn(self.devmajor, 8, posix),
- itn(self.devminor, 8, posix),
- stn(prefix, 155)
- ]
+ # If the isextended flag is given,
+ # there are extra headers to process.
+ while isextended == 1:
+ buf = tarfile.fileobj.read(BLOCKSIZE)
+ pos = 0
+ for i in xrange(21):
+ try:
+ offset = nti(buf[pos:pos + 12])
+ numbytes = nti(buf[pos + 12:pos + 24])
+ except ValueError:
+ break
+ if offset > lastpos:
+ sp.append(_hole(lastpos, offset - lastpos))
+ sp.append(_data(offset, numbytes, realpos))
+ realpos += numbytes
+ lastpos = offset + numbytes
+ pos += 24
+ isextended = ord(buf[504])
- buf += struct.pack("%ds" % BLOCKSIZE, "".join(parts))
- chksum = calc_chksums(buf[-BLOCKSIZE:])[0]
- buf = buf[:-364] + "%06o\0" % chksum + buf[-357:]
- self.buf = buf
- return buf
+ if lastpos < origsize:
+ sp.append(_hole(lastpos, origsize - lastpos))
+
+ self.sparse = sp
- def _create_gnulong(self, name, type):
- """Create a GNU longname/longlink header from name.
- It consists of an extended tar header, with the length
- of the longname as size, followed by data blocks,
- which contain the longname as a null terminated string.
+ self.offset_data = tarfile.fileobj.tell()
+ tarfile.offset = self.offset_data + self._block(self.size)
+ self.size = origsize
+
+ return self
+
+ def _proc_pax(self, tarfile):
+ """Process an extended or global header as described in
+ POSIX.1-2001.
"""
- name += NUL
+ # Read the header information.
+ buf = tarfile.fileobj.read(self._block(self.size))
- tarinfo = self.__class__()
- tarinfo.name = "././@LongLink"
- tarinfo.type = type
- tarinfo.mode = 0
- tarinfo.size = len(name)
-
- # create extended header
- buf = tarinfo.tobuf()
- # create name blocks
- buf += name
- blocks, remainder = divmod(len(name), BLOCKSIZE)
- if remainder > 0:
- buf += (BLOCKSIZE - remainder) * NUL
- return buf
+ # A pax header stores supplemental information for either
+ # the following file (extended) or all following files
+ # (global).
+ if self.type == XGLTYPE:
+ pax_headers = tarfile.pax_headers
+ else:
+ pax_headers = tarfile.pax_headers.copy()
+
+ # Fields in POSIX.1-2001 that are numbers, all other fields
+ # are treated as UTF-8 strings.
+ type_mapping = {
+ "atime": float,
+ "ctime": float,
+ "mtime": float,
+ "uid": int,
+ "gid": int,
+ "size": int
+ }
+
+ # Parse pax header information. A record looks like that:
+ # "%d %s=%s\n" % (length, keyword, value). length is the size
+ # of the complete record including the length field itself and
+ # the newline.
+ regex = re.compile(r"(\d+) ([^=]+)=", re.U)
+ pos = 0
+ while True:
+ match = regex.match(buf, pos)
+ if not match:
+ break
+
+ length, keyword = match.groups()
+ length = int(length)
+ value = buf[match.end(2) + 1:match.start(1) + length - 1]
+
+ keyword = keyword.decode("utf8")
+ keyword = keyword.encode(tarfile.encoding)
+
+ value = value.decode("utf8")
+ if keyword in type_mapping:
+ try:
+ value = type_mapping[keyword](value)
+ except ValueError:
+ value = 0
+ else:
+ value = value.encode(tarfile.encoding)
+
+ pax_headers[keyword] = value
+ pos += length
+
+ # Fetch the next header that will be patched with the
+ # supplement information from the pax header (extended
+ # only).
+ t = self.fromtarfile(tarfile)
+
+ if self.type != XGLTYPE and t is not None:
+ # Patch the TarInfo object from the next header with
+ # the pax header's information.
+ for keyword, value in pax_headers.items():
+ if keyword in PAX_FIELDS:
+ setattr(t, keyword, value)
+ pax_headers[keyword] = value
+ t.pax_headers = pax_headers.copy()
+
+ return t
+
+ def _block(self, count):
+ """Round up a byte count by BLOCKSIZE and return it,
+ e.g. _block(834) => 1024.
+ """
+ blocks, remainder = divmod(count, BLOCKSIZE)
+ if remainder:
+ blocks += 1
+ return blocks * BLOCKSIZE
def isreg(self):
return self.type in REGULAR_TYPES
@@ -1040,12 +1462,18 @@ class TarFile(object):
# messages (if debug >= 0). If > 0, errors
# are passed to the caller as exceptions.
- posix = False # If True, generates POSIX.1-1990-compliant
- # archives (no GNU extensions!)
+ format = DEFAULT_FORMAT # The format to use when creating an archive.
+
+ encoding = ENCODING # Transfer UTF-8 strings from POSIX.1-2001
+ # headers to this encoding.
+
+ tarinfo = TarInfo # The default TarInfo class to use.
- fileobject = ExFileObject
+ fileobject = ExFileObject # The default ExFileObject class to use.
- def __init__(self, name=None, mode="r", fileobj=None):
+ def __init__(self, name=None, mode="r", fileobj=None, format=None,
+ tarinfo=None, dereference=None, ignore_zeros=None, encoding=None,
+ pax_headers=None, debug=None, errorlevel=None):
"""Open an (uncompressed) tar archive `name'. `mode' is either 'r' to
read from an existing archive, 'a' to append data to an existing
file or 'w' to create a new file overwriting an existing one. `mode'
@@ -1054,58 +1482,86 @@ class TarFile(object):
can be determined, `mode' is overridden by `fileobj's mode.
`fileobj' is not closed, when TarFile is closed.
"""
- self.name = os.path.abspath(name)
-
if len(mode) > 1 or mode not in "raw":
raise ValueError("mode must be 'r', 'a' or 'w'")
- self._mode = mode
- self.mode = {"r": "rb", "a": "r+b", "w": "wb"}[mode]
+ self.mode = mode
+ self._mode = {"r": "rb", "a": "r+b", "w": "wb"}[mode]
if not fileobj:
- if self._mode == "a" and not os.path.exists(self.name):
+ if self.mode == "a" and not os.path.exists(name):
# Create nonexistent files in append mode.
- self._mode = "w"
- self.mode = "wb"
- fileobj = _open(self.name, self.mode)
+ self.mode = "w"
+ self._mode = "wb"
+ fileobj = _open(name, self._mode)
self._extfileobj = False
else:
- if self.name is None and hasattr(fileobj, "name"):
- self.name = os.path.abspath(fileobj.name)
+ if name is None and hasattr(fileobj, "name"):
+ name = fileobj.name
if hasattr(fileobj, "mode"):
- self.mode = fileobj.mode
+ self._mode = fileobj.mode
self._extfileobj = True
+ self.name = os.path.abspath(name)
self.fileobj = fileobj
- # Init datastructures
+ # Init attributes.
+ if format is not None:
+ self.format = format
+ if tarinfo is not None:
+ self.tarinfo = tarinfo
+ if dereference is not None:
+ self.dereference = dereference
+ if ignore_zeros is not None:
+ self.ignore_zeros = ignore_zeros
+ if encoding is not None:
+ self.encoding = encoding
+ if debug is not None:
+ self.debug = debug
+ if errorlevel is not None:
+ self.errorlevel = errorlevel
+
+ # Init datastructures.
self.closed = False
self.members = [] # list of members as TarInfo objects
self._loaded = False # flag if all members have been read
self.offset = 0 # current position in the archive file
self.inodes = {} # dictionary caching the inodes of
# archive members already added
+ self.pax_headers = {} # save contents of global pax headers
- if self._mode == "r":
+ if self.mode == "r":
self.firstmember = None
self.firstmember = self.next()
- if self._mode == "a":
+ if self.mode == "a":
# Move to the end of the archive,
# before the first empty block.
self.firstmember = None
while True:
- try:
- tarinfo = self.next()
- except ReadError:
- self.fileobj.seek(0)
- break
- if tarinfo is None:
+ if self.next() is None:
if self.offset > 0:
self.fileobj.seek(- BLOCKSIZE, 1)
break
- if self._mode in "aw":
+ if self.mode in "aw":
self._loaded = True
+ if pax_headers:
+ buf = self.tarinfo.create_pax_global_header(
+ pax_headers.copy(), self.encoding)
+ self.fileobj.write(buf)
+ self.offset += len(buf)
+
+ def _getposix(self):
+ return self.format == USTAR_FORMAT
+ def _setposix(self, value):
+ import warnings
+ warnings.warn("use the format attribute instead", DeprecationWarning)
+ if value:
+ self.format = USTAR_FORMAT
+ else:
+ self.format = GNU_FORMAT
+ posix = property(_getposix, _setposix)
+
#--------------------------------------------------------------------------
# Below are the classmethods which act as alternate constructors to the
# TarFile class. The open() method is the only one that is needed for
@@ -1118,7 +1574,7 @@ class TarFile(object):
# by adding it to the mapping in OPEN_METH.
@classmethod
- def open(cls, name=None, mode="r", fileobj=None, bufsize=20*512):
+ def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs):
"""Open a tar archive for reading, writing or appending. Return
an appropriate TarFile class.
@@ -1151,8 +1607,8 @@ class TarFile(object):
if fileobj is not None:
saved_pos = fileobj.tell()
try:
- return func(name, "r", fileobj)
- except (ReadError, CompressionError):
+ return func(name, "r", fileobj, **kwargs)
+ except (ReadError, CompressionError) as e:
if fileobj is not None:
fileobj.seek(saved_pos)
continue
@@ -1169,7 +1625,7 @@ class TarFile(object):
func = getattr(cls, cls.OPEN_METH[comptype])
else:
raise CompressionError("unknown compression type %r" % comptype)
- return func(name, filemode, fileobj)
+ return func(name, filemode, fileobj, **kwargs)
elif "|" in mode:
filemode, comptype = mode.split("|", 1)
@@ -1180,25 +1636,26 @@ class TarFile(object):
raise ValueError("mode must be 'r' or 'w'")
t = cls(name, filemode,
- _Stream(name, filemode, comptype, fileobj, bufsize))
+ _Stream(name, filemode, comptype, fileobj, bufsize),
+ **kwargs)
t._extfileobj = False
return t
elif mode in "aw":
- return cls.taropen(name, mode, fileobj)
+ return cls.taropen(name, mode, fileobj, **kwargs)
raise ValueError("undiscernible mode")
@classmethod
- def taropen(cls, name, mode="r", fileobj=None):
+ def taropen(cls, name, mode="r", fileobj=None, **kwargs):
"""Open uncompressed tar archive name for reading or writing.
"""
if len(mode) > 1 or mode not in "raw":
raise ValueError("mode must be 'r', 'a' or 'w'")
- return cls(name, mode, fileobj)
+ return cls(name, mode, fileobj, **kwargs)
@classmethod
- def gzopen(cls, name, mode="r", fileobj=None, compresslevel=9):
+ def gzopen(cls, name, mode="r", fileobj=None, compresslevel=9, **kwargs):
"""Open gzip compressed tar archive name for reading or writing.
Appending is not allowed.
"""
@@ -1216,14 +1673,15 @@ class TarFile(object):
try:
t = cls.taropen(name, mode,
- gzip.GzipFile(name, mode, compresslevel, fileobj))
+ gzip.GzipFile(name, mode, compresslevel, fileobj),
+ **kwargs)
except IOError:
raise ReadError("not a gzip file")
t._extfileobj = False
return t
@classmethod
- def bz2open(cls, name, mode="r", fileobj=None, compresslevel=9):
+ def bz2open(cls, name, mode="r", fileobj=None, compresslevel=9, **kwargs):
"""Open bzip2 compressed tar archive name for reading or writing.
Appending is not allowed.
"""
@@ -1241,7 +1699,7 @@ class TarFile(object):
fileobj = bz2.BZ2File(name, mode, compresslevel=compresslevel)
try:
- t = cls.taropen(name, mode, fileobj)
+ t = cls.taropen(name, mode, fileobj, **kwargs)
except IOError:
raise ReadError("not a bzip2 file")
t._extfileobj = False
@@ -1264,7 +1722,7 @@ class TarFile(object):
if self.closed:
return
- if self._mode in "aw":
+ if self.mode in "aw":
self.fileobj.write(NUL * (BLOCKSIZE * 2))
self.offset += (BLOCKSIZE * 2)
# fill up the end with zero-blocks
@@ -1330,7 +1788,8 @@ class TarFile(object):
# Now, fill the TarInfo object with
# information specific for the file.
- tarinfo = TarInfo()
+ tarinfo = self.tarinfo()
+ tarinfo.tarfile = self
# Use os.stat or os.lstat, depending on platform
# and if symlinks shall be resolved.
@@ -1346,8 +1805,8 @@ class TarFile(object):
stmd = statres.st_mode
if stat.S_ISREG(stmd):
inode = (statres.st_ino, statres.st_dev)
- if not self.dereference and \
- statres.st_nlink > 1 and inode in self.inodes:
+ if not self.dereference and statres.st_nlink > 1 and \
+ inode in self.inodes and arcname != self.inodes[inode]:
# Is it a hardlink to an already
# archived file?
type = LNKTYPE
@@ -1424,7 +1883,7 @@ class TarFile(object):
print("%d-%02d-%02d %02d:%02d:%02d" \
% time.localtime(tarinfo.mtime)[:6], end=' ')
- print(tarinfo.name, end=' ')
+ print(tarinfo.name + ("/" if tarinfo.isdir() else ""), end=' ')
if verbose:
if tarinfo.issym():
@@ -1456,7 +1915,7 @@ class TarFile(object):
if recursive:
if arcname == ".":
arcname = ""
- for f in os.listdir("."):
+ for f in os.listdir(name):
self.add(f, os.path.join(arcname, f))
return
@@ -1495,7 +1954,7 @@ class TarFile(object):
tarinfo = copy.copy(tarinfo)
- buf = tarinfo.tobuf(self.posix)
+ buf = tarinfo.tobuf(self.format, self.encoding)
self.fileobj.write(buf)
self.offset += len(buf)
@@ -1527,7 +1986,7 @@ class TarFile(object):
# Extract directory with a safe mode, so that
# all files below can be extracted as well.
try:
- os.makedirs(os.path.join(path, tarinfo.name), 0777)
+ os.makedirs(os.path.join(path, tarinfo.name), 0700)
except EnvironmentError:
pass
directories.append(tarinfo)
@@ -1559,10 +2018,10 @@ class TarFile(object):
"""
self._check("r")
- if isinstance(member, TarInfo):
- tarinfo = member
- else:
+ if isinstance(member, basestring):
tarinfo = self.getmember(member)
+ else:
+ tarinfo = member
# Prepare the link target for makelink().
if tarinfo.islnk():
@@ -1595,10 +2054,10 @@ class TarFile(object):
"""
self._check("r")
- if isinstance(member, TarInfo):
- tarinfo = member
- else:
+ if isinstance(member, basestring):
tarinfo = self.getmember(member)
+ else:
+ tarinfo = member
if tarinfo.isreg():
return self.fileobject(self, tarinfo)
@@ -1811,20 +2270,11 @@ class TarFile(object):
# Read the next block.
self.fileobj.seek(self.offset)
while True:
- buf = self.fileobj.read(BLOCKSIZE)
- if not buf:
- return None
-
try:
- tarinfo = TarInfo.frombuf(buf)
-
- # Set the TarInfo object's offset to the current position of the
- # TarFile and set self.offset to the position where the data blocks
- # should begin.
- tarinfo.offset = self.offset
- self.offset += BLOCKSIZE
-
- tarinfo = self.proc_member(tarinfo)
+ tarinfo = self.tarinfo.fromtarfile(self)
+ if tarinfo is None:
+ return
+ self.members.append(tarinfo)
except HeaderError as e:
if self.ignore_zeros:
@@ -1837,149 +2287,11 @@ class TarFile(object):
return None
break
- # Some old tar programs represent a directory as a regular
- # file with a trailing slash.
- if tarinfo.isreg() and tarinfo.name.endswith("/"):
- tarinfo.type = DIRTYPE
-
- # Directory names should have a '/' at the end.
- if tarinfo.isdir():
- tarinfo.name += "/"
-
- self.members.append(tarinfo)
- return tarinfo
-
- #--------------------------------------------------------------------------
- # The following are methods that are called depending on the type of a
- # member. The entry point is proc_member() which is called with a TarInfo
- # object created from the header block from the current offset. The
- # proc_member() method can be overridden in a subclass to add custom
- # proc_*() methods. A proc_*() method MUST implement the following
- # operations:
- # 1. Set tarinfo.offset_data to the position where the data blocks begin,
- # if there is data that follows.
- # 2. Set self.offset to the position where the next member's header will
- # begin.
- # 3. Return tarinfo or another valid TarInfo object.
- def proc_member(self, tarinfo):
- """Choose the right processing method for tarinfo depending
- on its type and call it.
- """
- if tarinfo.type in (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK):
- return self.proc_gnulong(tarinfo)
- elif tarinfo.type == GNUTYPE_SPARSE:
- return self.proc_sparse(tarinfo)
- else:
- return self.proc_builtin(tarinfo)
-
- def proc_builtin(self, tarinfo):
- """Process a builtin type member or an unknown member
- which will be treated as a regular file.
- """
- tarinfo.offset_data = self.offset
- if tarinfo.isreg() or tarinfo.type not in SUPPORTED_TYPES:
- # Skip the following data blocks.
- self.offset += self._block(tarinfo.size)
- return tarinfo
-
- def proc_gnulong(self, tarinfo):
- """Process the blocks that hold a GNU longname
- or longlink member.
- """
- buf = ""
- count = tarinfo.size
- while count > 0:
- block = self.fileobj.read(BLOCKSIZE)
- buf += block
- self.offset += BLOCKSIZE
- count -= BLOCKSIZE
-
- # Fetch the next header and process it.
- b = self.fileobj.read(BLOCKSIZE)
- t = TarInfo.frombuf(b)
- t.offset = self.offset
- self.offset += BLOCKSIZE
- next = self.proc_member(t)
-
- # Patch the TarInfo object from the next header with
- # the longname information.
- next.offset = tarinfo.offset
- if tarinfo.type == GNUTYPE_LONGNAME:
- next.name = buf.rstrip(NUL)
- elif tarinfo.type == GNUTYPE_LONGLINK:
- next.linkname = buf.rstrip(NUL)
-
- return next
-
- def proc_sparse(self, tarinfo):
- """Process a GNU sparse header plus extra headers.
- """
- buf = tarinfo.buf
- sp = _ringbuffer()
- pos = 386
- lastpos = 0
- realpos = 0
- # There are 4 possible sparse structs in the
- # first header.
- for i in xrange(4):
- try:
- offset = nti(buf[pos:pos + 12])
- numbytes = nti(buf[pos + 12:pos + 24])
- except ValueError:
- break
- if offset > lastpos:
- sp.append(_hole(lastpos, offset - lastpos))
- sp.append(_data(offset, numbytes, realpos))
- realpos += numbytes
- lastpos = offset + numbytes
- pos += 24
-
- isextended = ord(buf[482])
- origsize = nti(buf[483:495])
-
- # If the isextended flag is given,
- # there are extra headers to process.
- while isextended == 1:
- buf = self.fileobj.read(BLOCKSIZE)
- self.offset += BLOCKSIZE
- pos = 0
- for i in xrange(21):
- try:
- offset = nti(buf[pos:pos + 12])
- numbytes = nti(buf[pos + 12:pos + 24])
- except ValueError:
- break
- if offset > lastpos:
- sp.append(_hole(lastpos, offset - lastpos))
- sp.append(_data(offset, numbytes, realpos))
- realpos += numbytes
- lastpos = offset + numbytes
- pos += 24
- isextended = ord(buf[504])
-
- if lastpos < origsize:
- sp.append(_hole(lastpos, origsize - lastpos))
-
- tarinfo.sparse = sp
-
- tarinfo.offset_data = self.offset
- self.offset += self._block(tarinfo.size)
- tarinfo.size = origsize
-
return tarinfo
#--------------------------------------------------------------------------
# Little helper methods:
- def _block(self, count):
- """Round up a byte count by BLOCKSIZE and return it,
- e.g. _block(834) => 1024.
- """
- blocks, remainder = divmod(count, BLOCKSIZE)
- if remainder:
- blocks += 1
- return blocks * BLOCKSIZE
-
def _getmember(self, name, tarinfo=None):
"""Find an archive member by name from bottom to top.
If tarinfo is given, it is used as the starting point.
@@ -2012,8 +2324,8 @@ class TarFile(object):
"""
if self.closed:
raise IOError("%s is closed" % self.__class__.__name__)
- if mode is not None and self._mode not in mode:
- raise IOError("bad operation for mode %r" % self._mode)
+ if mode is not None and self.mode not in mode:
+ raise IOError("bad operation for mode %r" % self.mode)
def __iter__(self):
"""Provide an iterator object.
diff --git a/Lib/telnetlib.py b/Lib/telnetlib.py
index dd263ae..1040e3c 100644
--- a/Lib/telnetlib.py
+++ b/Lib/telnetlib.py
@@ -184,7 +184,7 @@ class Telnet:
"""
- def __init__(self, host=None, port=0):
+ def __init__(self, host=None, port=0, timeout=None):
"""Constructor.
When called without arguments, create an unconnected instance.
@@ -195,6 +195,7 @@ class Telnet:
self.debuglevel = DEBUGLEVEL
self.host = host
self.port = port
+ self.timeout = timeout
self.sock = None
self.rawq = ''
self.irawq = 0
@@ -205,9 +206,9 @@ class Telnet:
self.sbdataq = ''
self.option_callback = None
if host is not None:
- self.open(host, port)
+ self.open(host, port, timeout)
- def open(self, host, port=0):
+ def open(self, host, port=0, timeout=None):
"""Connect to a host.
The optional second argument is the port number, which
@@ -221,20 +222,9 @@ class Telnet:
port = TELNET_PORT
self.host = host
self.port = port
- msg = "getaddrinfo returns an empty list"
- for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- self.sock = socket.socket(af, socktype, proto)
- self.sock.connect(sa)
- except socket.error as msg:
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error, msg
+ if timeout is not None:
+ self.timeout = timeout
+ self.sock = socket.create_connection((host, port), self.timeout)
def __del__(self):
"""Destructor -- close the connection."""
@@ -661,7 +651,7 @@ def test():
port = socket.getservbyname(portstr, 'tcp')
tn = Telnet()
tn.set_debuglevel(debuglevel)
- tn.open(host, port)
+ tn.open(host, port, timeout=0.5)
tn.interact()
tn.close()
diff --git a/Lib/tempfile.py b/Lib/tempfile.py
index 0ebf6b4..b63a46a 100644
--- a/Lib/tempfile.py
+++ b/Lib/tempfile.py
@@ -19,6 +19,7 @@ This module also provides some data items to the user:
__all__ = [
"NamedTemporaryFile", "TemporaryFile", # high level safe interfaces
+ "SpooledTemporaryFile",
"mkstemp", "mkdtemp", # low level safe interfaces
"mktemp", # deprecated unsafe interface
"TMP_MAX", "gettempprefix", # constants
@@ -37,6 +38,11 @@ if _os.name == 'mac':
import Carbon.Folders as _Folders
try:
+ from cStringIO import StringIO as _StringIO
+except:
+ from StringIO import StringIO as _StringIO
+
+try:
import fcntl as _fcntl
except ImportError:
def _set_cloexec(fd):
@@ -114,7 +120,7 @@ class _RandomNameSequence:
characters = ("abcdefghijklmnopqrstuvwxyz" +
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
- "0123456789-_")
+ "0123456789_")
def __init__(self):
self.mutex = _allocate_lock()
@@ -372,10 +378,11 @@ class _TemporaryFileWrapper:
remove the file when it is no longer needed.
"""
- def __init__(self, file, name):
+ def __init__(self, file, name, delete=True):
self.file = file
self.name = name
self.close_called = False
+ self.delete = delete
def __getattr__(self, name):
file = self.__dict__['file']
@@ -400,23 +407,25 @@ class _TemporaryFileWrapper:
if not self.close_called:
self.close_called = True
self.file.close()
- self.unlink(self.name)
+ if self.delete:
+ self.unlink(self.name)
def __del__(self):
self.close()
def NamedTemporaryFile(mode='w+b', bufsize=-1, suffix="",
- prefix=template, dir=None):
+ prefix=template, dir=None, delete=True):
"""Create and return a temporary file.
Arguments:
'prefix', 'suffix', 'dir' -- as for mkstemp.
'mode' -- the mode argument to os.fdopen (default "w+b").
'bufsize' -- the buffer size argument to os.fdopen (default -1).
+ 'delete' -- whether the file is deleted on close (default True).
The file is created as mkstemp() would do it.
Returns an object with a file-like interface; the name of the file
is accessible as file.name. The file will be automatically deleted
- when it is closed.
+ when it is closed unless the 'delete' argument is set to False.
"""
if dir is None:
@@ -429,12 +438,12 @@ def NamedTemporaryFile(mode='w+b', bufsize=-1, suffix="",
# Setting O_TEMPORARY in the flags causes the OS to delete
# the file when it is closed. This is only supported by Windows.
- if _os.name == 'nt':
+ if _os.name == 'nt' and delete:
flags |= _os.O_TEMPORARY
(fd, name) = _mkstemp_inner(dir, prefix, suffix, flags)
file = _os.fdopen(fd, mode, bufsize)
- return _TemporaryFileWrapper(file, name)
+ return _TemporaryFileWrapper(file, name, delete)
if _os.name != 'posix' or _os.sys.platform == 'cygwin':
# On non-POSIX and Cygwin systems, assume that we cannot unlink a file
@@ -470,3 +479,111 @@ else:
except:
_os.close(fd)
raise
+
+class SpooledTemporaryFile:
+ """Temporary file wrapper, specialized to switch from
+ StringIO to a real file when it exceeds a certain size or
+ when a fileno is needed.
+ """
+ _rolled = False
+
+ def __init__(self, max_size=0, mode='w+b', bufsize=-1,
+ suffix="", prefix=template, dir=None):
+ self._file = _StringIO()
+ self._max_size = max_size
+ self._rolled = False
+ self._TemporaryFileArgs = (mode, bufsize, suffix, prefix, dir)
+
+ def _check(self, file):
+ if self._rolled: return
+ max_size = self._max_size
+ if max_size and file.tell() > max_size:
+ self.rollover()
+
+ def rollover(self):
+ if self._rolled: return
+ file = self._file
+ newfile = self._file = TemporaryFile(*self._TemporaryFileArgs)
+ del self._TemporaryFileArgs
+
+ newfile.write(file.getvalue())
+ newfile.seek(file.tell(), 0)
+
+ self._rolled = True
+
+ # file protocol
+ def __iter__(self):
+ return self._file.__iter__()
+
+ def close(self):
+ self._file.close()
+
+ @property
+ def closed(self):
+ return self._file.closed
+
+ @property
+ def encoding(self):
+ return self._file.encoding
+
+ def fileno(self):
+ self.rollover()
+ return self._file.fileno()
+
+ def flush(self):
+ self._file.flush()
+
+ def isatty(self):
+ return self._file.isatty()
+
+ @property
+ def mode(self):
+ return self._file.mode
+
+ @property
+ def name(self):
+ return self._file.name
+
+ @property
+ def newlines(self):
+ return self._file.newlines
+
+ def next(self):
+ return self._file.next
+
+ def read(self, *args):
+ return self._file.read(*args)
+
+ def readline(self, *args):
+ return self._file.readline(*args)
+
+ def readlines(self, *args):
+ return self._file.readlines(*args)
+
+ def seek(self, *args):
+ self._file.seek(*args)
+
+ @property
+ def softspace(self):
+ return self._file.softspace
+
+ def tell(self):
+ return self._file.tell()
+
+ def truncate(self):
+ self._file.truncate()
+
+ def write(self, s):
+ file = self._file
+ rv = file.write(s)
+ self._check(file)
+ return rv
+
+ def writelines(self, iterable):
+ file = self._file
+ rv = file.writelines(iterable)
+ self._check(file)
+ return rv
+
+ def xreadlines(self, *args):
+ return self._file.xreadlines(*args)
diff --git a/Lib/test/README b/Lib/test/README
index 27f696c..747d842 100644
--- a/Lib/test/README
+++ b/Lib/test/README
@@ -15,7 +15,7 @@ testing facility provided with Python; any particular test should use only
one of these options. Each option requires writing a test module using the
conventions of the selected option:
- - PyUnit_ based tests
+ - unittest_ based tests
- doctest_ based tests
- "traditional" Python test modules
@@ -28,31 +28,34 @@ your test cases to exercise it more completely. In particular, you will be
able to refer to the C and Python code in the CVS repository when writing
your regression test cases.
-.. _PyUnit:
.. _unittest: http://www.python.org/doc/current/lib/module-unittest.html
.. _doctest: http://www.python.org/doc/current/lib/module-doctest.html
-PyUnit based tests
+unittest-based tests
------------------
-The PyUnit_ framework is based on the ideas of unit testing as espoused
+The unittest_ framework is based on the ideas of unit testing as espoused
by Kent Beck and the `Extreme Programming`_ (XP) movement. The specific
interface provided by the framework is tightly based on the JUnit_
Java implementation of Beck's original SmallTalk test framework. Please
see the documentation of the unittest_ module for detailed information on
-the interface and general guidelines on writing PyUnit based tests.
-
-The test_support helper module provides two functions for use by
-PyUnit based tests in the Python regression testing framework:
-
-- ``run_unittest()`` takes a ``unittest.TestCase`` derived class as a
- parameter and runs the tests defined in that class
+the interface and general guidelines on writing unittest-based tests.
+
+The test_support helper module provides a function for use by
+unittest-based tests in the Python regression testing framework,
+``run_unittest()``. This is the primary way of running tests in the
+standard library. You can pass it any number of the following:
+
+- classes derived from or instances of ``unittest.TestCase`` or
+ ``unittest.TestSuite``. These will be handed off to unittest for
+ converting into a proper TestSuite instance.
+
+- a string; this must be a key in sys.modules. The module associated with
+ that string will be scanned by ``unittest.TestLoader.loadTestsFromModule``.
+ This is usually seen as ``test_support.run_unittest(__name__)`` in a test
+ module's ``test_main()`` function. This has the advantage of picking up
+ new tests automatically, without you having to add each new test case
+ manually.
-- ``run_suite()`` takes a populated ``TestSuite`` instance and runs the
- tests
-
-``run_suite()`` is preferred because unittest files typically grow multiple
-test classes, and you might as well be prepared.
-
All test methods in the Python regression framework have names that
start with "``test_``" and use lower-case names with words separated with
underscores.
@@ -63,7 +66,7 @@ and the full class name. When there's a problem with a test, the
latter information makes it easier to find the source for the test
than the docstring.
-All PyUnit-based tests in the Python test suite use boilerplate that
+All unittest-based tests in the Python test suite use boilerplate that
looks like this (with minor variations)::
import unittest
@@ -97,11 +100,7 @@ looks like this (with minor variations)::
...etc...
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(MyTestCase1))
- suite.addTest(unittest.makeSuite(MyTestCase2))
- ...add more suites...
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
@@ -415,7 +414,7 @@ Some Non-Obvious regrtest Features
This is rarely required with the "traditional" Python tests, and
you shouldn't create a module global with name test_main unless
you're specifically exploiting this gimmick. This usage does
- prove useful with PyUnit-based tests as well, however; defining
+ prove useful with unittest-based tests as well, however; defining
a ``test_main()`` which is run by regrtest and a script-stub in the
test module ("``if __name__ == '__main__': test_main()``") allows
the test to be used like any other Python test and also work
diff --git a/Lib/test/crashers/modify_dict_attr.py b/Lib/test/crashers/modify_dict_attr.py
index ac1f0a8..be675c1 100644
--- a/Lib/test/crashers/modify_dict_attr.py
+++ b/Lib/test/crashers/modify_dict_attr.py
@@ -4,15 +4,16 @@
class Y(object):
pass
-class type_with_modifiable_dict(Y, type):
+class type_with_modifiable_dict(type, Y):
pass
class MyClass(object, metaclass=type_with_modifiable_dict):
- """This class has its __dict__ attribute completely exposed:
- user code can read, reassign and even delete it.
+ """This class has its __dict__ attribute indirectly
+ exposed via the __dict__ getter/setter of Y.
"""
if __name__ == '__main__':
- del MyClass.__dict__ # if we set tp_dict to NULL,
+ dictattr = Y.__dict__['__dict__']
+ dictattr.__delete__(MyClass) # if we set tp_dict to NULL,
print(MyClass) # doing anything with MyClass segfaults
diff --git a/Lib/test/infinite_reload.py b/Lib/test/infinite_reload.py
new file mode 100644
index 0000000..bfbec91
--- /dev/null
+++ b/Lib/test/infinite_reload.py
@@ -0,0 +1,7 @@
+# For testing http://python.org/sf/742342, which reports that Python
+# segfaults (infinite recursion in C) in the presence of infinite
+# reload()ing. This module is imported by test_import.py:test_infinite_reload
+# to make sure this doesn't happen any more.
+
+import infinite_reload
+reload(infinite_reload)
diff --git a/Lib/test/output/test_operations b/Lib/test/output/test_operations
deleted file mode 100644
index 309cd5b..0000000
--- a/Lib/test/output/test_operations
+++ /dev/null
@@ -1,19 +0,0 @@
-test_operations
-3. Operations
-XXX Mostly not yet implemented
-3.1 Dictionary lookups fail if __cmp__() raises an exception
-raising error
-d[x2] = 2: caught the RuntimeError outside
-raising error
-z = d[x2]: caught the RuntimeError outside
-raising error
-x2 in d: caught the RuntimeError outside
-raising error
-d.get(x2): caught the RuntimeError outside
-raising error
-d.setdefault(x2, 42): caught the RuntimeError outside
-raising error
-d.pop(x2): caught the RuntimeError outside
-raising error
-d.update({x2: 2}): caught the RuntimeError outside
-resize bugs not triggered.
diff --git a/Lib/test/output/test_popen2 b/Lib/test/output/test_popen2
deleted file mode 100644
index a66cde9..0000000
--- a/Lib/test/output/test_popen2
+++ /dev/null
@@ -1,9 +0,0 @@
-test_popen2
-Test popen2 module:
-testing popen2...
-testing popen3...
-All OK
-Testing os module:
-testing popen2...
-testing popen3...
-All OK
diff --git a/Lib/test/output/test_pty b/Lib/test/output/test_pty
deleted file mode 100644
index b6e0e32..0000000
--- a/Lib/test/output/test_pty
+++ /dev/null
@@ -1,3 +0,0 @@
-test_pty
-I wish to buy a fish license.
-For my pet fish, Eric.
diff --git a/Lib/test/output/test_pyexpat b/Lib/test/output/test_pyexpat
deleted file mode 100644
index 61fe81d..0000000
--- a/Lib/test/output/test_pyexpat
+++ /dev/null
@@ -1,110 +0,0 @@
-test_pyexpat
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-PI:
- 'xml-stylesheet' 'href="stylesheet.css"'
-Comment:
- ' comment data '
-Notation declared: ('notation', None, 'notation.jpeg', None)
-Unparsed entity decl:
- ('unparsed_entity', None, 'entity.file', None, 'notation')
-Start element:
- 'root' {'attr1': 'value1', 'attr2': 'value2\xe1\xbd\x80'}
-NS decl:
- 'myns' 'http://www.python.org/namespace'
-Start element:
- 'http://www.python.org/namespace!subelement' {}
-Character data:
- 'Contents of subelements'
-End element:
- 'http://www.python.org/namespace!subelement'
-End of NS decl:
- 'myns'
-Start element:
- 'sub2' {}
-Start of CDATA section
-Character data:
- 'contents of CDATA section'
-End of CDATA section
-End element:
- 'sub2'
-External entity ref: (None, 'entity.file', None)
-End element:
- 'root'
-PI:
- u'xml-stylesheet' u'href="stylesheet.css"'
-Comment:
- u' comment data '
-Notation declared: (u'notation', None, u'notation.jpeg', None)
-Unparsed entity decl:
- (u'unparsed_entity', None, u'entity.file', None, u'notation')
-Start element:
- u'root' {u'attr1': u'value1', u'attr2': u'value2\u1f40'}
-NS decl:
- u'myns' u'http://www.python.org/namespace'
-Start element:
- u'http://www.python.org/namespace!subelement' {}
-Character data:
- u'Contents of subelements'
-End element:
- u'http://www.python.org/namespace!subelement'
-End of NS decl:
- u'myns'
-Start element:
- u'sub2' {}
-Start of CDATA section
-Character data:
- u'contents of CDATA section'
-End of CDATA section
-End element:
- u'sub2'
-External entity ref: (None, u'entity.file', None)
-End element:
- u'root'
-PI:
- u'xml-stylesheet' u'href="stylesheet.css"'
-Comment:
- u' comment data '
-Notation declared: (u'notation', None, u'notation.jpeg', None)
-Unparsed entity decl:
- (u'unparsed_entity', None, u'entity.file', None, u'notation')
-Start element:
- u'root' {u'attr1': u'value1', u'attr2': u'value2\u1f40'}
-NS decl:
- u'myns' u'http://www.python.org/namespace'
-Start element:
- u'http://www.python.org/namespace!subelement' {}
-Character data:
- u'Contents of subelements'
-End element:
- u'http://www.python.org/namespace!subelement'
-End of NS decl:
- u'myns'
-Start element:
- u'sub2' {}
-Start of CDATA section
-Character data:
- u'contents of CDATA section'
-End of CDATA section
-End element:
- u'sub2'
-External entity ref: (None, u'entity.file', None)
-End element:
- u'root'
-
-Testing constructor for proper handling of namespace_separator values:
-Legal values tested o.k.
-Caught expected TypeError:
-ParserCreate() argument 2 must be string or None, not int
-Caught expected ValueError:
-namespace_separator must be at most one character, omitted, or None
diff --git a/Lib/test/output/test_threadedtempfile b/Lib/test/output/test_threadedtempfile
deleted file mode 100644
index 2552877..0000000
--- a/Lib/test/output/test_threadedtempfile
+++ /dev/null
@@ -1,5 +0,0 @@
-test_threadedtempfile
-Creating
-Starting
-Reaping
-Done: errors 0 ok 1000
diff --git a/Lib/test/output/xmltests b/Lib/test/output/xmltests
deleted file mode 100644
index c798f6e..0000000
--- a/Lib/test/output/xmltests
+++ /dev/null
@@ -1,364 +0,0 @@
-xmltests
-Passed testAAA
-Passed setAttribute() sets ownerDocument
-Passed setAttribute() sets ownerElement
-Test Succeeded testAAA
-Passed assertion: len(Node.allnodes) == 0
-Passed testAAB
-Test Succeeded testAAB
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Test Succeeded testAddAttr
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testAppendChild
-Passed assertion: len(Node.allnodes) == 0
-Passed appendChild(<fragment>)
-Test Succeeded testAppendChildFragment
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListItem
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListItemNS
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListItems
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListKeys
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListKeysNS
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListLength
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrListValues
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrList__getitem__
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testAttrList__setitem__
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testAttributeRepr
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Passed Test
-Test Succeeded testChangeAttr
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testChildNodes
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testCloneAttributeDeep
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testCloneAttributeShallow
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testCloneDocumentDeep
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testCloneDocumentShallow
-Passed assertion: len(Node.allnodes) == 0
-Passed clone of element has same attribute keys
-Passed clone of attribute node has proper attribute values
-Passed clone of attribute node correctly owned
-Passed testCloneElementDeep
-Test Succeeded testCloneElementDeep
-Passed assertion: len(Node.allnodes) == 0
-Passed clone of element has same attribute keys
-Passed clone of attribute node has proper attribute values
-Passed clone of attribute node correctly owned
-Passed testCloneElementShallow
-Test Succeeded testCloneElementShallow
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testClonePIDeep
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testClonePIShallow
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testComment
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testCreateAttributeNS
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testCreateElementNS
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Passed Test
-Test Succeeded testDeleteAttr
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testDocumentElement
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Test Succeeded testElement
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Test Succeeded testElementReprAndStr
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testFirstChild
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetAttrLength
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetAttrList
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetAttrValues
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetAttribute
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetAttributeNS
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetAttributeNode
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Test Succeeded testGetElementsByTagName
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Test Succeeded testGetElementsByTagNameNS
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testGetEmptyNodeListFromElementsByTagNameNS
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testHasChildNodes
-Passed assertion: len(Node.allnodes) == 0
-Passed testInsertBefore -- node properly placed in tree
-Passed testInsertBefore -- node properly placed in tree
-Passed testInsertBefore -- node properly placed in tree
-Test Succeeded testInsertBefore
-Passed assertion: len(Node.allnodes) == 0
-Passed insertBefore(<fragment>, None)
-Passed insertBefore(<fragment>, orig)
-Test Succeeded testInsertBeforeFragment
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testLegalChildren
-Passed assertion: len(Node.allnodes) == 0
-Passed NamedNodeMap.__setitem__() sets ownerDocument
-Passed NamedNodeMap.__setitem__() sets ownerElement
-Passed NamedNodeMap.__setitem__() sets value
-Passed NamedNodeMap.__setitem__() sets nodeValue
-Test Succeeded testNamedNodeMapSetItem
-Passed assertion: len(Node.allnodes) == 0
-Passed test NodeList.item()
-Test Succeeded testNodeListItem
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testNonZero
-Passed assertion: len(Node.allnodes) == 0
-Passed testNormalize -- preparation
-Passed testNormalize -- result
-Passed testNormalize -- single empty node removed
-Test Succeeded testNormalize
-Passed assertion: len(Node.allnodes) == 0
-Passed testParents
-Test Succeeded testParents
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParse
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParseAttributeNamespaces
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParseAttributes
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParseElement
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParseElementNamespaces
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Test Succeeded testParseFromFile
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParseProcessingInstructions
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testParseString
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testProcessingInstruction
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testProcessingInstructionRepr
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testRemoveAttr
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testRemoveAttrNS
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testRemoveAttributeNode
-Passed assertion: len(Node.allnodes) == 0
-Passed replaceChild(<fragment>)
-Test Succeeded testReplaceChildFragment
-Passed assertion: len(Node.allnodes) == 0
-Passed testSAX2DOM - siblings
-Passed testSAX2DOM - parents
-Test Succeeded testSAX2DOM
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testSetAttrValueandNodeValue
-Passed assertion: len(Node.allnodes) == 0
-Passed testSiblings
-Test Succeeded testSiblings
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testTextNodeRepr
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testTextRepr
-Passed assertion: len(Node.allnodes) == 0
-Caught expected exception when adding extra document element.
-Test Succeeded testTooManyDocumentElements
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testUnlink
-Passed assertion: len(Node.allnodes) == 0
-Test Succeeded testWriteText
-Passed assertion: len(Node.allnodes) == 0
-Passed Test
-Passed Test
-Test Succeeded testWriteXML
-Passed assertion: len(Node.allnodes) == 0
-All tests succeeded
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-OK.
-PI:
- 'xml-stylesheet' 'href="stylesheet.css"'
-Comment:
- ' comment data '
-Notation declared: ('notation', None, 'notation.jpeg', None)
-Unparsed entity decl:
- ('unparsed_entity', None, 'entity.file', None, 'notation')
-Start element:
- 'root' {'attr1': 'value1', 'attr2': 'value2\xe1\xbd\x80'}
-NS decl:
- 'myns' 'http://www.python.org/namespace'
-Start element:
- 'http://www.python.org/namespace!subelement' {}
-Character data:
- 'Contents of subelements'
-End element:
- 'http://www.python.org/namespace!subelement'
-End of NS decl:
- 'myns'
-Start element:
- 'sub2' {}
-Start of CDATA section
-Character data:
- 'contents of CDATA section'
-End of CDATA section
-End element:
- 'sub2'
-External entity ref: (None, 'entity.file', None)
-End element:
- 'root'
-PI:
- u'xml-stylesheet' u'href="stylesheet.css"'
-Comment:
- u' comment data '
-Notation declared: (u'notation', None, u'notation.jpeg', None)
-Unparsed entity decl:
- (u'unparsed_entity', None, u'entity.file', None, u'notation')
-Start element:
- u'root' {u'attr1': u'value1', u'attr2': u'value2\u1f40'}
-NS decl:
- u'myns' u'http://www.python.org/namespace'
-Start element:
- u'http://www.python.org/namespace!subelement' {}
-Character data:
- u'Contents of subelements'
-End element:
- u'http://www.python.org/namespace!subelement'
-End of NS decl:
- u'myns'
-Start element:
- u'sub2' {}
-Start of CDATA section
-Character data:
- u'contents of CDATA section'
-End of CDATA section
-End element:
- u'sub2'
-External entity ref: (None, u'entity.file', None)
-End element:
- u'root'
-PI:
- u'xml-stylesheet' u'href="stylesheet.css"'
-Comment:
- u' comment data '
-Notation declared: (u'notation', None, u'notation.jpeg', None)
-Unparsed entity decl:
- (u'unparsed_entity', None, u'entity.file', None, u'notation')
-Start element:
- u'root' {u'attr1': u'value1', u'attr2': u'value2\u1f40'}
-NS decl:
- u'myns' u'http://www.python.org/namespace'
-Start element:
- u'http://www.python.org/namespace!subelement' {}
-Character data:
- u'Contents of subelements'
-End element:
- u'http://www.python.org/namespace!subelement'
-End of NS decl:
- u'myns'
-Start element:
- u'sub2' {}
-Start of CDATA section
-Character data:
- u'contents of CDATA section'
-End of CDATA section
-End element:
- u'sub2'
-External entity ref: (None, u'entity.file', None)
-End element:
- u'root'
-
-Testing constructor for proper handling of namespace_separator values:
-Legal values tested o.k.
-Caught expected TypeError:
-ParserCreate() argument 2 must be string or None, not int
-Caught expected ValueError:
-namespace_separator must be at most one character, omitted, or None
-Passed test_attrs_empty
-Passed test_attrs_wattr
-Passed test_double_quoteattr
-Passed test_escape_all
-Passed test_escape_basic
-Passed test_escape_extra
-Passed test_expat_attrs_empty
-Passed test_expat_attrs_wattr
-Passed test_expat_dtdhandler
-Passed test_expat_entityresolver
-Passed test_expat_file
-Passed test_expat_incomplete
-Passed test_expat_incremental
-Passed test_expat_incremental_reset
-Passed test_expat_inpsource_filename
-Passed test_expat_inpsource_location
-Passed test_expat_inpsource_stream
-Passed test_expat_inpsource_sysid
-Passed test_expat_locator_noinfo
-Passed test_expat_locator_withinfo
-Passed test_expat_nsattrs_empty
-Passed test_expat_nsattrs_wattr
-Passed test_filter_basic
-Passed test_make_parser
-Passed test_make_parser2
-Passed test_nsattrs_empty
-Passed test_nsattrs_wattr
-Passed test_quoteattr_basic
-Passed test_single_double_quoteattr
-Passed test_single_quoteattr
-Passed test_xmlgen_attr_escape
-Passed test_xmlgen_basic
-Passed test_xmlgen_content
-Passed test_xmlgen_content_escape
-Passed test_xmlgen_ignorable
-Passed test_xmlgen_ns
-Passed test_xmlgen_pi
-37 tests, 0 failures
diff --git a/Lib/test/outstanding_bugs.py b/Lib/test/outstanding_bugs.py
index 04afcbd..7c6cd9e 100644
--- a/Lib/test/outstanding_bugs.py
+++ b/Lib/test/outstanding_bugs.py
@@ -10,13 +10,44 @@ import unittest
from test import test_support
#
-# No test cases for outstanding bugs at the moment.
+# One test case for outstanding bugs at the moment:
#
+class TestDifflibLongestMatch(unittest.TestCase):
+ # From Patch #1678339:
+ # The find_longest_match method in the difflib's SequenceMatcher has a bug.
+
+ # The bug is in turn caused by a problem with creating a b2j mapping which
+ # should contain a list of indices for each of the list elements in b.
+ # However, when the b2j mapping is being created (this is being done in
+ # __chain_b method in the SequenceMatcher) the mapping becomes broken. The
+ # cause of this is that for the frequently used elements the list of indices
+ # is removed and the element is being enlisted in the populardict mapping.
+
+ # The test case tries to match two strings like:
+ # abbbbbb.... and ...bbbbbbc
+
+ # The number of b is equal and the find_longest_match should have returned
+ # the proper amount. However, in case the number of "b"s is large enough, the
+ # method reports that the length of the longest common substring is 0. It
+ # simply can't find it.
+
+ # A bug was raised some time ago on this matter. It's ID is 1528074.
+
+ def test_find_longest_match(self):
+ import difflib
+ for i in (190, 200, 210):
+ text1 = "a" + "b"*i
+ text2 = "b"*i + "c"
+ m = difflib.SequenceMatcher(None, text1, text2)
+ (aptr, bptr, l) = m.find_longest_match(0, len(text1), 0, len(text2))
+ self.assertEquals(i, l)
+ self.assertEquals(aptr, 1)
+ self.assertEquals(bptr, 0)
+
def test_main():
- #test_support.run_unittest()
- pass
+ test_support.run_unittest(TestDifflibLongestMatch)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index b10c57d..4691e13 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -835,6 +835,24 @@ class AbstractPickleTests(unittest.TestCase):
y = self.loads(s)
self.assertEqual(y._proto, None)
+ def test_reduce_ex_calls_base(self):
+ for proto in 0, 1, 2:
+ x = REX_four()
+ self.assertEqual(x._proto, None)
+ s = self.dumps(x, proto)
+ self.assertEqual(x._proto, proto)
+ y = self.loads(s)
+ self.assertEqual(y._proto, proto)
+
+ def test_reduce_calls_base(self):
+ for proto in 0, 1, 2:
+ x = REX_five()
+ self.assertEqual(x._reduce_called, 0)
+ s = self.dumps(x, proto)
+ self.assertEqual(x._reduce_called, 1)
+ y = self.loads(s)
+ self.assertEqual(y._reduce_called, 1)
+
# Test classes for reduce_ex
class REX_one(object):
@@ -859,6 +877,20 @@ class REX_three(object):
def __reduce__(self):
raise TestFailed, "This __reduce__ shouldn't be called"
+class REX_four(object):
+ _proto = None
+ def __reduce_ex__(self, proto):
+ self._proto = proto
+ return object.__reduce_ex__(self, proto)
+ # Calling base class method should succeed
+
+class REX_five(object):
+ _reduce_called = 0
+ def __reduce__(self):
+ self._reduce_called = 1
+ return object.__reduce__(self)
+ # This one used to fail with infinite recursion
+
# Test classes for newobj
class MyInt(int):
diff --git a/Lib/test/regrtest.py b/Lib/test/regrtest.py
index 6ef6663..87d9dc9 100755
--- a/Lib/test/regrtest.py
+++ b/Lib/test/regrtest.py
@@ -474,7 +474,7 @@ def main(tests=None, testdir=None, verbose=0, quiet=False, generate=False,
STDTESTS = [
'test_grammar',
'test_opcodes',
- 'test_operations',
+ 'test_dict',
'test_builtin',
'test_exceptions',
'test_types',
diff --git a/Lib/test/ssl_cert.pem b/Lib/test/ssl_cert.pem
new file mode 100644
index 0000000..9d7ac23
--- /dev/null
+++ b/Lib/test/ssl_cert.pem
@@ -0,0 +1,14 @@
+-----BEGIN CERTIFICATE-----
+MIICLDCCAdYCAQAwDQYJKoZIhvcNAQEEBQAwgaAxCzAJBgNVBAYTAlBUMRMwEQYD
+VQQIEwpRdWVlbnNsYW5kMQ8wDQYDVQQHEwZMaXNib2ExFzAVBgNVBAoTDk5ldXJv
+bmlvLCBMZGEuMRgwFgYDVQQLEw9EZXNlbnZvbHZpbWVudG8xGzAZBgNVBAMTEmJy
+dXR1cy5uZXVyb25pby5wdDEbMBkGCSqGSIb3DQEJARYMc2FtcG9AaWtpLmZpMB4X
+DTk2MDkwNTAzNDI0M1oXDTk2MTAwNTAzNDI0M1owgaAxCzAJBgNVBAYTAlBUMRMw
+EQYDVQQIEwpRdWVlbnNsYW5kMQ8wDQYDVQQHEwZMaXNib2ExFzAVBgNVBAoTDk5l
+dXJvbmlvLCBMZGEuMRgwFgYDVQQLEw9EZXNlbnZvbHZpbWVudG8xGzAZBgNVBAMT
+EmJydXR1cy5uZXVyb25pby5wdDEbMBkGCSqGSIb3DQEJARYMc2FtcG9AaWtpLmZp
+MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBAL7+aty3S1iBA/+yxjxv4q1MUTd1kjNw
+L4lYKbpzzlmC5beaQXeQ2RmGMTXU+mDvuqItjVHOK3DvPK7lTcSGftUCAwEAATAN
+BgkqhkiG9w0BAQQFAANBAFqPEKFjk6T6CKTHvaQeEAsX0/8YHPHqH/9AnhSjrwuX
+9EBc0n6bVGhN7XaXd6sJ7dym9sbsWxb+pJdurnkxjx4=
+-----END CERTIFICATE-----
diff --git a/Lib/test/ssl_key.pem b/Lib/test/ssl_key.pem
new file mode 100644
index 0000000..239ad66
--- /dev/null
+++ b/Lib/test/ssl_key.pem
@@ -0,0 +1,9 @@
+-----BEGIN RSA PRIVATE KEY-----
+MIIBPAIBAAJBAL7+aty3S1iBA/+yxjxv4q1MUTd1kjNwL4lYKbpzzlmC5beaQXeQ
+2RmGMTXU+mDvuqItjVHOK3DvPK7lTcSGftUCAwEAAQJBALjkK+jc2+iihI98riEF
+oudmkNziSRTYjnwjx8mCoAjPWviB3c742eO3FG4/soi1jD9A5alihEOXfUzloenr
+8IECIQD3B5+0l+68BA/6d76iUNqAAV8djGTzvxnCxycnxPQydQIhAMXt4trUI3nc
+a+U8YL2HPFA3gmhBsSICbq2OptOCnM7hAiEA6Xi3JIQECob8YwkRj29DU3/4WYD7
+WLPgsQpwo1GuSpECICGsnWH5oaeD9t9jbFoSfhJvv0IZmxdcLpRcpslpeWBBAiEA
+6/5B8J0GHdJq89FHwEG/H2eVVUYu5y/aD6sgcm+0Avg=
+-----END RSA PRIVATE KEY-----
diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py
index b0852ee..2431262 100644
--- a/Lib/test/string_tests.py
+++ b/Lib/test/string_tests.py
@@ -1104,6 +1104,9 @@ class MixinStrStringUserStringTest:
self.checkequal('Abc', 'abc', 'translate', table)
self.checkequal('xyz', 'xyz', 'translate', table)
self.checkequal('yz', 'xyz', 'translate', table, 'x')
+ self.checkequal('yx', 'zyzzx', 'translate', None, 'z')
+ self.checkequal('zyzzx', 'zyzzx', 'translate', None, '')
+ self.checkequal('zyzzx', 'zyzzx', 'translate', None)
self.checkraises(ValueError, 'xyz', 'translate', 'too short', 'strip')
self.checkraises(ValueError, 'xyz', 'translate', 'too short')
diff --git a/Lib/test/test___all__.py b/Lib/test/test___all__.py
index bb1fd8d..6003733 100644
--- a/Lib/test/test___all__.py
+++ b/Lib/test/test___all__.py
@@ -1,7 +1,5 @@
import unittest
-from test import test_support
-
-from test.test_support import verify, verbose
+from test.test_support import verbose, run_unittest
import sys
import warnings
@@ -20,15 +18,15 @@ class AllTest(unittest.TestCase):
# Silent fail here seems the best route since some modules
# may not be available in all environments.
return
- verify(hasattr(sys.modules[modname], "__all__"),
- "%s has no __all__ attribute" % modname)
+ self.failUnless(hasattr(sys.modules[modname], "__all__"),
+ "%s has no __all__ attribute" % modname)
names = {}
exec("from %s import *" % modname, names)
if "__builtins__" in names:
del names["__builtins__"]
keys = set(names)
all = set(sys.modules[modname].__all__)
- verify(keys==all, "%s != %s" % (keys, all))
+ self.assertEqual(keys, all)
def test_all(self):
if not sys.platform.startswith('java'):
@@ -177,7 +175,7 @@ class AllTest(unittest.TestCase):
def test_main():
- test_support.run_unittest(AllTest)
+ run_unittest(AllTest)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_array.py b/Lib/test/test_array.py
index ae7156b..34b573f 100755
--- a/Lib/test/test_array.py
+++ b/Lib/test/test_array.py
@@ -111,6 +111,21 @@ class BaseTest(unittest.TestCase):
self.assertEqual(a.x, b.x)
self.assertEqual(type(a), type(b))
+ def test_pickle_for_empty_array(self):
+ for protocol in (0, 1, 2):
+ a = array.array(self.typecode)
+ b = loads(dumps(a, protocol))
+ self.assertNotEqual(id(a), id(b))
+ self.assertEqual(a, b)
+
+ a = ArraySubclass(self.typecode)
+ a.x = 10
+ b = loads(dumps(a, protocol))
+ self.assertNotEqual(id(a), id(b))
+ self.assertEqual(a, b)
+ self.assertEqual(a.x, b.x)
+ self.assertEqual(type(a), type(b))
+
def test_insert(self):
a = array.array(self.typecode, self.example)
a.insert(0, self.example[0])
@@ -713,7 +728,6 @@ class CharacterTest(StringTest):
return array.array.__new__(cls, 'c', s)
def __init__(self, s, color='blue'):
- array.array.__init__(self, 'c', s)
self.color = color
def strip(self):
diff --git a/Lib/test/test_atexit.py b/Lib/test/test_atexit.py
index 56077e7..142e543 100644
--- a/Lib/test/test_atexit.py
+++ b/Lib/test/test_atexit.py
@@ -28,7 +28,7 @@ class TestCase(unittest.TestCase):
self.stream = StringIO.StringIO()
sys.stdout = sys.stderr = self.stream
atexit._clear()
-
+
def tearDown(self):
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
@@ -50,63 +50,63 @@ class TestCase(unittest.TestCase):
atexit.register(h2)
atexit.register(h3)
atexit._run_exitfuncs()
-
+
self.assertEqual(self.stream.getvalue(), "h3\nh2\nh1\n")
def test_raise(self):
# be sure raises are handled properly
atexit.register(raise1)
atexit.register(raise2)
-
+
self.assertRaises(TypeError, atexit._run_exitfuncs)
-
+
def test_stress(self):
a = [0]
def inc():
a[0] += 1
-
+
for i in range(128):
atexit.register(inc)
atexit._run_exitfuncs()
-
+
self.assertEqual(a[0], 128)
-
+
def test_clear(self):
a = [0]
def inc():
a[0] += 1
-
+
atexit.register(inc)
atexit._clear()
atexit._run_exitfuncs()
-
+
self.assertEqual(a[0], 0)
-
+
def test_unregister(self):
a = [0]
def inc():
a[0] += 1
def dec():
a[0] -= 1
-
- for i in range(4):
+
+ for i in range(4):
atexit.register(inc)
atexit.register(dec)
atexit.unregister(inc)
atexit._run_exitfuncs()
-
+
self.assertEqual(a[0], -1)
-
+
def test_bound_methods(self):
l = []
atexit.register(l.append, 5)
atexit._run_exitfuncs()
self.assertEqual(l, [5])
-
+
atexit.unregister(l.append)
atexit._run_exitfuncs()
self.assertEqual(l, [5])
-
+
def test_main():
test_support.run_unittest(TestCase)
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
index 997a413..ff2c370 100644
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -183,16 +183,8 @@ class BaseXYTestCase(unittest.TestCase):
-def suite():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(LegacyBase64TestCase))
- suite.addTest(unittest.makeSuite(BaseXYTestCase))
- return suite
-
-
def test_main():
- test_support.run_suite(suite())
-
+ test_support.run_unittest(__name__)
if __name__ == '__main__':
- unittest.main(defaultTest='suite')
+ test_main()
diff --git a/Lib/test/test_binascii.py b/Lib/test/test_binascii.py
index 8272ad9..ea8be31 100755
--- a/Lib/test/test_binascii.py
+++ b/Lib/test/test_binascii.py
@@ -148,6 +148,15 @@ class BinASCIITest(unittest.TestCase):
"0"*75+"=\r\n=FF\r\n=FF\r\n=FF"
)
+ self.assertEqual(binascii.b2a_qp('\0\n'), '=00\n')
+ self.assertEqual(binascii.b2a_qp('\0\n', quotetabs=True), '=00\n')
+ self.assertEqual(binascii.b2a_qp('foo\tbar\t\n'), 'foo\tbar=09\n')
+ self.assertEqual(binascii.b2a_qp('foo\tbar\t\n', quotetabs=True), 'foo=09bar=09\n')
+
+ self.assertEqual(binascii.b2a_qp('.'), '=2E')
+ self.assertEqual(binascii.b2a_qp('.\n'), '=2E\n')
+ self.assertEqual(binascii.b2a_qp('a.\n'), 'a.\n')
+
def test_empty_string(self):
# A test for SF bug #1022953. Make sure SystemError is not raised.
for n in ['b2a_qp', 'a2b_hex', 'b2a_base64', 'a2b_uu', 'a2b_qp',
diff --git a/Lib/test/test_bool.py b/Lib/test/test_bool.py
index 1e19cf5..dd04b27 100644
--- a/Lib/test/test_bool.py
+++ b/Lib/test/test_bool.py
@@ -321,7 +321,7 @@ class BoolTest(unittest.TestCase):
self.assertEqual(pickle.dumps(False), "I00\n.")
self.assertEqual(pickle.dumps(True, True), "I01\n.")
self.assertEqual(pickle.dumps(False, True), "I00\n.")
-
+
try:
import cPickle
except ImportError:
diff --git a/Lib/test/test_bsddb3.py b/Lib/test/test_bsddb3.py
index 69e99c0..fe0469c 100644
--- a/Lib/test/test_bsddb3.py
+++ b/Lib/test/test_bsddb3.py
@@ -4,7 +4,7 @@ Run all test cases.
"""
import sys
import unittest
-from test.test_support import requires, verbose, run_suite, unlink
+from test.test_support import requires, verbose, run_unittest, unlink
# When running as a script instead of within the regrtest framework, skip the
# requires test, since it's obvious we want to run them.
@@ -58,9 +58,7 @@ def suite():
# For invocation through regrtest
def test_main():
- tests = suite()
- run_suite(tests)
-
+ run_unittest(suite())
# For invocation as a script
if __name__ == '__main__':
@@ -73,4 +71,4 @@ if __name__ == '__main__':
print('python version: %s' % sys.version)
print('-=' * 38)
- unittest.main(defaultTest='suite')
+ test_main()
diff --git a/Lib/test/test_builtin.py b/Lib/test/test_builtin.py
index 500516c..ffa74af 100644
--- a/Lib/test/test_builtin.py
+++ b/Lib/test/test_builtin.py
@@ -107,9 +107,12 @@ class BuiltinTest(unittest.TestCase):
__import__('sys')
__import__('time')
__import__('string')
+ __import__(name='sys')
+ __import__(name='time', level=0)
self.assertRaises(ImportError, __import__, 'spamspam')
self.assertRaises(TypeError, __import__, 1, 2, 3, 4)
self.assertRaises(ValueError, __import__, '')
+ self.assertRaises(TypeError, __import__, 'sys', name='sys')
def test_abs(self):
# int
@@ -207,15 +210,21 @@ class BuiltinTest(unittest.TestCase):
compile('print(1)\n', '', 'exec')
bom = '\xef\xbb\xbf'
compile(bom + 'print(1)\n', '', 'exec')
+ compile(source='pass', filename='?', mode='exec')
+ compile(dont_inherit=0, filename='tmp', source='0', mode='eval')
+ compile('pass', '?', dont_inherit=1, mode='exec')
self.assertRaises(TypeError, compile)
self.assertRaises(ValueError, compile, 'print(42)\n', '<string>', 'badmode')
self.assertRaises(ValueError, compile, 'print(42)\n', '<string>', 'single', 0xff)
self.assertRaises(TypeError, compile, chr(0), 'f', 'exec')
+ self.assertRaises(TypeError, compile, 'pass', '?', 'exec',
+ mode='eval', source='0', filename='tmp')
if have_unicode:
compile(unicode('print(u"\xc3\xa5")\n', 'utf8'), '', 'exec')
self.assertRaises(TypeError, compile, unichr(0), 'f', 'exec')
self.assertRaises(ValueError, compile, unicode('a = 1'), 'f', 'bad')
+
def test_delattr(self):
import sys
sys.spam = 1
@@ -1035,6 +1044,11 @@ class BuiltinTest(unittest.TestCase):
self.assertRaises(ValueError, int, '53', 40)
self.assertRaises(TypeError, int, 1, 12)
+ # SF patch #1638879: embedded NULs were not detected with
+ # explicit base
+ self.assertRaises(ValueError, int, '123\0', 10)
+ self.assertRaises(ValueError, int, '123\x00 245', 20)
+
self.assertEqual(int('100000000000000000000000000000000', 2),
4294967296)
self.assertEqual(int('102002022201221111211', 3), 4294967296)
@@ -1138,10 +1152,10 @@ class BuiltinTest(unittest.TestCase):
self.assertEqual(int(Foo0()), 42)
self.assertEqual(int(Foo1()), 42)
- # XXX invokes __int__ now
+ # XXX invokes __int__ now
# self.assertEqual(long(Foo2()), 42L)
self.assertEqual(int(Foo3()), 0)
- # XXX likewise
+ # XXX likewise
# self.assertEqual(long(Foo4()), 42)
# self.assertRaises(TypeError, long, Foo5())
diff --git a/Lib/test/test_cfgparser.py b/Lib/test/test_cfgparser.py
index 2295772..85dfa32 100644
--- a/Lib/test/test_cfgparser.py
+++ b/Lib/test/test_cfgparser.py
@@ -417,6 +417,18 @@ class SafeConfigParserTestCase(ConfigParserTestCase):
self.assertEqual(cf.get("section", "ok"), "xxx/%s")
self.assertEqual(cf.get("section", "not_ok"), "xxx/xxx/%s")
+ def test_set_malformatted_interpolation(self):
+ cf = self.fromstring("[sect]\n"
+ "option1=foo\n")
+
+ self.assertEqual(cf.get('sect', "option1"), "foo")
+
+ self.assertRaises(ValueError, cf.set, "sect", "option1", "%foo")
+ self.assertRaises(ValueError, cf.set, "sect", "option1", "foo%")
+ self.assertRaises(ValueError, cf.set, "sect", "option1", "f%oo")
+
+ self.assertEqual(cf.get('sect', "option1"), "foo")
+
def test_set_nonstring_types(self):
cf = self.fromstring("[sect]\n"
"option1=foo\n")
diff --git a/Lib/test/test_cmath.py b/Lib/test/test_cmath.py
index fd3a6bf..e091bd6 100755
--- a/Lib/test/test_cmath.py
+++ b/Lib/test/test_cmath.py
@@ -1,52 +1,196 @@
-#! /usr/bin/env python
-""" Simple test script for cmathmodule.c
- Roger E. Masse
-"""
+from test.test_support import run_unittest
+import unittest
import cmath, math
-from test.test_support import verbose, verify, TestFailed
-
-verify(abs(cmath.log(10) - math.log(10)) < 1e-9)
-verify(abs(cmath.log(10,2) - math.log(10,2)) < 1e-9)
-try:
- cmath.log('a')
-except TypeError:
- pass
-else:
- raise TestFailed
-
-try:
- cmath.log(10, 'a')
-except TypeError:
- pass
-else:
- raise TestFailed
-
-
-testdict = {'acos' : 1.0,
- 'acosh' : 1.0,
- 'asin' : 1.0,
- 'asinh' : 1.0,
- 'atan' : 0.2,
- 'atanh' : 0.2,
- 'cos' : 1.0,
- 'cosh' : 1.0,
- 'exp' : 1.0,
- 'log' : 1.0,
- 'log10' : 1.0,
- 'sin' : 1.0,
- 'sinh' : 1.0,
- 'sqrt' : 1.0,
- 'tan' : 1.0,
- 'tanh' : 1.0}
-
-for func in testdict.keys():
- f = getattr(cmath, func)
- r = f(testdict[func])
- if verbose:
- print('Calling %s(%f) = %f' % (func, testdict[func], abs(r)))
-
-p = cmath.pi
-e = cmath.e
-if verbose:
- print('PI = ', abs(p))
- print('E = ', abs(e))
+
+class CMathTests(unittest.TestCase):
+ # list of all functions in cmath
+ test_functions = [getattr(cmath, fname) for fname in [
+ 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh',
+ 'cos', 'cosh', 'exp', 'log', 'log10', 'sin', 'sinh',
+ 'sqrt', 'tan', 'tanh']]
+ # test first and second arguments independently for 2-argument log
+ test_functions.append(lambda x : cmath.log(x, 1729. + 0j))
+ test_functions.append(lambda x : cmath.log(14.-27j, x))
+
+ def cAssertAlmostEqual(self, a, b, rel_eps = 1e-10, abs_eps = 1e-100):
+ """Check that two complex numbers are almost equal."""
+ # the two complex numbers are considered almost equal if
+ # either the relative error is <= rel_eps or the absolute error
+ # is tiny, <= abs_eps.
+ if a == b == 0:
+ return
+ absolute_error = abs(a-b)
+ relative_error = absolute_error/max(abs(a), abs(b))
+ if relative_error > rel_eps and absolute_error > abs_eps:
+ self.fail("%s and %s are not almost equal" % (a, b))
+
+ def test_constants(self):
+ e_expected = 2.71828182845904523536
+ pi_expected = 3.14159265358979323846
+ self.assertAlmostEqual(cmath.pi, pi_expected, 9,
+ "cmath.pi is %s; should be %s" % (cmath.pi, pi_expected))
+ self.assertAlmostEqual(cmath.e, e_expected, 9,
+ "cmath.e is %s; should be %s" % (cmath.e, e_expected))
+
+ def test_user_object(self):
+ # Test automatic calling of __complex__ and __float__ by cmath
+ # functions
+
+ # some random values to use as test values; we avoid values
+ # for which any of the functions in cmath is undefined
+ # (i.e. 0., 1., -1., 1j, -1j) or would cause overflow
+ cx_arg = 4.419414439 + 1.497100113j
+ flt_arg = -6.131677725
+
+ # a variety of non-complex numbers, used to check that
+ # non-complex return values from __complex__ give an error
+ non_complexes = ["not complex", 1, 5, 2., None,
+ object(), NotImplemented]
+
+ # Now we introduce a variety of classes whose instances might
+ # end up being passed to the cmath functions
+
+ # usual case: new-style class implementing __complex__
+ class MyComplex(object):
+ def __init__(self, value):
+ self.value = value
+ def __complex__(self):
+ return self.value
+
+ # old-style class implementing __complex__
+ class MyComplexOS:
+ def __init__(self, value):
+ self.value = value
+ def __complex__(self):
+ return self.value
+
+ # classes for which __complex__ raises an exception
+ class SomeException(Exception):
+ pass
+ class MyComplexException(object):
+ def __complex__(self):
+ raise SomeException
+ class MyComplexExceptionOS:
+ def __complex__(self):
+ raise SomeException
+
+ # some classes not providing __float__ or __complex__
+ class NeitherComplexNorFloat(object):
+ pass
+ class NeitherComplexNorFloatOS:
+ pass
+ class MyInt(object):
+ def __int__(self): return 2
+ def __long__(self): return 2
+ def __index__(self): return 2
+ class MyIntOS:
+ def __int__(self): return 2
+ def __long__(self): return 2
+ def __index__(self): return 2
+
+ # other possible combinations of __float__ and __complex__
+ # that should work
+ class FloatAndComplex(object):
+ def __float__(self):
+ return flt_arg
+ def __complex__(self):
+ return cx_arg
+ class FloatAndComplexOS:
+ def __float__(self):
+ return flt_arg
+ def __complex__(self):
+ return cx_arg
+ class JustFloat(object):
+ def __float__(self):
+ return flt_arg
+ class JustFloatOS:
+ def __float__(self):
+ return flt_arg
+
+ for f in self.test_functions:
+ # usual usage
+ self.cAssertAlmostEqual(f(MyComplex(cx_arg)), f(cx_arg))
+ self.cAssertAlmostEqual(f(MyComplexOS(cx_arg)), f(cx_arg))
+ # other combinations of __float__ and __complex__
+ self.cAssertAlmostEqual(f(FloatAndComplex()), f(cx_arg))
+ self.cAssertAlmostEqual(f(FloatAndComplexOS()), f(cx_arg))
+ self.cAssertAlmostEqual(f(JustFloat()), f(flt_arg))
+ self.cAssertAlmostEqual(f(JustFloatOS()), f(flt_arg))
+ # TypeError should be raised for classes not providing
+ # either __complex__ or __float__, even if they provide
+ # __int__, __long__ or __index__. An old-style class
+ # currently raises AttributeError instead of a TypeError;
+ # this could be considered a bug.
+ self.assertRaises(TypeError, f, NeitherComplexNorFloat())
+ self.assertRaises(TypeError, f, MyInt())
+ self.assertRaises(Exception, f, NeitherComplexNorFloatOS())
+ self.assertRaises(Exception, f, MyIntOS())
+ # non-complex return value from __complex__ -> TypeError
+ for bad_complex in non_complexes:
+ self.assertRaises(TypeError, f, MyComplex(bad_complex))
+ self.assertRaises(TypeError, f, MyComplexOS(bad_complex))
+ # exceptions in __complex__ should be propagated correctly
+ self.assertRaises(SomeException, f, MyComplexException())
+ self.assertRaises(SomeException, f, MyComplexExceptionOS())
+
+ def test_input_type(self):
+ # ints and longs should be acceptable inputs to all cmath
+ # functions, by virtue of providing a __float__ method
+ for f in self.test_functions:
+ for arg in [2, 2.]:
+ self.cAssertAlmostEqual(f(arg), f(arg.__float__()))
+
+ # but strings should give a TypeError
+ for f in self.test_functions:
+ for arg in ["a", "long_string", "0", "1j", ""]:
+ self.assertRaises(TypeError, f, arg)
+
+ def test_cmath_matches_math(self):
+ # check that corresponding cmath and math functions are equal
+ # for floats in the appropriate range
+
+ # test_values in (0, 1)
+ test_values = [0.01, 0.1, 0.2, 0.5, 0.9, 0.99]
+
+ # test_values for functions defined on [-1., 1.]
+ unit_interval = test_values + [-x for x in test_values] + \
+ [0., 1., -1.]
+
+ # test_values for log, log10, sqrt
+ positive = test_values + [1.] + [1./x for x in test_values]
+ nonnegative = [0.] + positive
+
+ # test_values for functions defined on the whole real line
+ real_line = [0.] + positive + [-x for x in positive]
+
+ test_functions = {
+ 'acos' : unit_interval,
+ 'asin' : unit_interval,
+ 'atan' : real_line,
+ 'cos' : real_line,
+ 'cosh' : real_line,
+ 'exp' : real_line,
+ 'log' : positive,
+ 'log10' : positive,
+ 'sin' : real_line,
+ 'sinh' : real_line,
+ 'sqrt' : nonnegative,
+ 'tan' : real_line,
+ 'tanh' : real_line}
+
+ for fn, values in test_functions.items():
+ float_fn = getattr(math, fn)
+ complex_fn = getattr(cmath, fn)
+ for v in values:
+ self.cAssertAlmostEqual(float_fn(v), complex_fn(v))
+
+ # test two-argument version of log with various bases
+ for base in [0.5, 2., 10.]:
+ for v in positive:
+ self.cAssertAlmostEqual(cmath.log(v, base), math.log(v, base))
+
+def test_main():
+ run_unittest(CMathTests)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_cmd_line.py b/Lib/test/test_cmd_line.py
index 5e89863..cacae7a 100644
--- a/Lib/test/test_cmd_line.py
+++ b/Lib/test/test_cmd_line.py
@@ -6,7 +6,7 @@ import subprocess
class CmdLineTest(unittest.TestCase):
def start_python(self, cmd_line):
- outfp, infp = popen2.popen4('%s %s' % (sys.executable, cmd_line))
+ outfp, infp = popen2.popen4('"%s" %s' % (sys.executable, cmd_line))
infp.close()
data = outfp.read()
outfp.close()
diff --git a/Lib/test/test_codecencodings_cn.py b/Lib/test/test_codecencodings_cn.py
index c558f1b..96b0d77 100644
--- a/Lib/test/test_codecencodings_cn.py
+++ b/Lib/test/test_codecencodings_cn.py
@@ -51,11 +51,7 @@ class Test_GB18030(test_multibytecodec_support.TestBase, unittest.TestCase):
has_iso10646 = True
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(Test_GB2312))
- suite.addTest(unittest.makeSuite(Test_GBK))
- suite.addTest(unittest.makeSuite(Test_GB18030))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecencodings_hk.py b/Lib/test/test_codecencodings_hk.py
index 1cd020f..b1c2606 100644
--- a/Lib/test/test_codecencodings_hk.py
+++ b/Lib/test/test_codecencodings_hk.py
@@ -21,9 +21,7 @@ class Test_Big5HKSCS(test_multibytecodec_support.TestBase, unittest.TestCase):
)
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(Test_Big5HKSCS))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecencodings_jp.py b/Lib/test/test_codecencodings_jp.py
index 558598a..5f81f41 100644
--- a/Lib/test/test_codecencodings_jp.py
+++ b/Lib/test/test_codecencodings_jp.py
@@ -99,13 +99,7 @@ class Test_SJISX0213(test_multibytecodec_support.TestBase, unittest.TestCase):
)
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(Test_CP932))
- suite.addTest(unittest.makeSuite(Test_EUC_JISX0213))
- suite.addTest(unittest.makeSuite(Test_EUC_JP_COMPAT))
- suite.addTest(unittest.makeSuite(Test_SJIS_COMPAT))
- suite.addTest(unittest.makeSuite(Test_SJISX0213))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecencodings_kr.py b/Lib/test/test_codecencodings_kr.py
index 8139f76..a30eaf9 100644
--- a/Lib/test/test_codecencodings_kr.py
+++ b/Lib/test/test_codecencodings_kr.py
@@ -45,11 +45,7 @@ class Test_JOHAB(test_multibytecodec_support.TestBase, unittest.TestCase):
)
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(Test_CP949))
- suite.addTest(unittest.makeSuite(Test_EUCKR))
- suite.addTest(unittest.makeSuite(Test_JOHAB))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecencodings_tw.py b/Lib/test/test_codecencodings_tw.py
index 7c59478..983d06f 100644
--- a/Lib/test/test_codecencodings_tw.py
+++ b/Lib/test/test_codecencodings_tw.py
@@ -21,9 +21,7 @@ class Test_Big5(test_multibytecodec_support.TestBase, unittest.TestCase):
)
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(Test_Big5))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecmaps_cn.py b/Lib/test/test_codecmaps_cn.py
index 8cbee76..75541ac 100644
--- a/Lib/test/test_codecmaps_cn.py
+++ b/Lib/test/test_codecmaps_cn.py
@@ -20,10 +20,7 @@ class TestGBKMap(test_multibytecodec_support.TestBase_Mapping,
'MICSFT/WINDOWS/CP936.TXT'
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(TestGB2312Map))
- suite.addTest(unittest.makeSuite(TestGBKMap))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecmaps_hk.py b/Lib/test/test_codecmaps_hk.py
index e7f7b96..1068d0b 100644
--- a/Lib/test/test_codecmaps_hk.py
+++ b/Lib/test/test_codecmaps_hk.py
@@ -14,9 +14,7 @@ class TestBig5HKSCSMap(test_multibytecodec_support.TestBase_Mapping,
mapfileurl = 'http://people.freebsd.org/~perky/i18n/BIG5HKSCS.TXT'
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(TestBig5HKSCSMap))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecmaps_jp.py b/Lib/test/test_codecmaps_jp.py
index 08052d4..5466a98 100644
--- a/Lib/test/test_codecmaps_jp.py
+++ b/Lib/test/test_codecmaps_jp.py
@@ -61,13 +61,7 @@ class TestSJISX0213Map(test_multibytecodec_support.TestBase_Mapping,
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(TestCP932Map))
- suite.addTest(unittest.makeSuite(TestEUCJPCOMPATMap))
- suite.addTest(unittest.makeSuite(TestSJISCOMPATMap))
- suite.addTest(unittest.makeSuite(TestEUCJISX0213Map))
- suite.addTest(unittest.makeSuite(TestSJISX0213Map))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecmaps_kr.py b/Lib/test/test_codecmaps_kr.py
index 7484a66..1b350b9 100644
--- a/Lib/test/test_codecmaps_kr.py
+++ b/Lib/test/test_codecmaps_kr.py
@@ -34,11 +34,7 @@ class TestJOHABMap(test_multibytecodec_support.TestBase_Mapping,
pass_dectest = [('\\', u'\u20a9')]
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(TestCP949Map))
- suite.addTest(unittest.makeSuite(TestEUCKRMap))
- suite.addTest(unittest.makeSuite(TestJOHABMap))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_codecmaps_tw.py b/Lib/test/test_codecmaps_tw.py
index 0b195f4..143ae23 100644
--- a/Lib/test/test_codecmaps_tw.py
+++ b/Lib/test/test_codecmaps_tw.py
@@ -25,10 +25,7 @@ class TestCP950Map(test_multibytecodec_support.TestBase_Mapping,
]
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(TestBIG5Map))
- suite.addTest(unittest.makeSuite(TestCP950Map))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
new file mode 100644
index 0000000..a139129
--- /dev/null
+++ b/Lib/test/test_collections.py
@@ -0,0 +1,57 @@
+import unittest
+from test import test_support
+from collections import NamedTuple
+
+class TestNamedTuple(unittest.TestCase):
+
+ def test_factory(self):
+ Point = NamedTuple('Point', 'x y')
+ self.assertEqual(Point.__name__, 'Point')
+ self.assertEqual(Point.__doc__, 'Point(x, y)')
+ self.assertEqual(Point.__slots__, ())
+ self.assertEqual(Point.__module__, __name__)
+ self.assertEqual(Point.__getitem__, tuple.__getitem__)
+ self.assert_('__getitem__' in Point.__dict__) # superclass methods localized
+
+ def test_instance(self):
+ Point = NamedTuple('Point', 'x y')
+ p = Point(11, 22)
+ self.assertEqual(p, Point(x=11, y=22))
+ self.assertEqual(p, Point(11, y=22))
+ self.assertEqual(p, Point(y=22, x=11))
+ self.assertEqual(p, Point(*(11, 22)))
+ self.assertEqual(p, Point(**dict(x=11, y=22)))
+ self.assertRaises(TypeError, Point, 1) # too few args
+ self.assertRaises(TypeError, Point, 1, 2, 3) # too many args
+ self.assertRaises(TypeError, eval, 'Point(XXX=1, y=2)', locals()) # wrong keyword argument
+ self.assertRaises(TypeError, eval, 'Point(x=1)', locals()) # missing keyword argument
+ self.assertEqual(repr(p), 'Point(x=11, y=22)')
+ self.assert_('__dict__' not in dir(p)) # verify instance has no dict
+ self.assert_('__weakref__' not in dir(p))
+
+ def test_tupleness(self):
+ Point = NamedTuple('Point', 'x y')
+ p = Point(11, 22)
+
+ self.assert_(isinstance(p, tuple))
+ self.assertEqual(p, (11, 22)) # matches a real tuple
+ self.assertEqual(tuple(p), (11, 22)) # coercable to a real tuple
+ self.assertEqual(list(p), [11, 22]) # coercable to a list
+ self.assertEqual(max(p), 22) # iterable
+ self.assertEqual(max(*p), 22) # star-able
+ x, y = p
+ self.assertEqual(p, (x, y)) # unpacks like a tuple
+ self.assertEqual((p[0], p[1]), (11, 22)) # indexable like a tuple
+ self.assertRaises(IndexError, p.__getitem__, 3)
+
+ self.assertEqual(p.x, x)
+ self.assertEqual(p.y, y)
+ self.assertRaises(AttributeError, eval, 'p.z', locals())
+
+
+def test_main(verbose=None):
+ test_classes = [TestNamedTuple]
+ test_support.run_unittest(*test_classes)
+
+if __name__ == "__main__":
+ test_main(verbose=True)
diff --git a/Lib/test/test_commands.py b/Lib/test/test_commands.py
index b72a1b9..d899d66 100644
--- a/Lib/test/test_commands.py
+++ b/Lib/test/test_commands.py
@@ -4,6 +4,10 @@
'''
import unittest
import os, tempfile, re
+import warnings
+
+warnings.filterwarnings('ignore', r".*commands.getstatus.. is deprecated",
+ DeprecationWarning)
from test.test_support import TestSkipped, run_unittest, reap_children
from commands import *
diff --git a/Lib/test/test_compile.py b/Lib/test/test_compile.py
index d5fda13..ae55485 100644
--- a/Lib/test/test_compile.py
+++ b/Lib/test/test_compile.py
@@ -399,13 +399,26 @@ if 1:
# is the max. Ensure the result of too many annotations is a
# SyntaxError.
s = "def f((%s)): pass"
- s %= ', '.join('a%d:%d' % (i,i) for i in xrange(65535))
+ s %= ', '.join('a%d:%d' % (i,i) for i in xrange(65535))
self.assertRaises(SyntaxError, compile, s, '?', 'exec')
# Test that the max # of annotations compiles.
s = "def f((%s)): pass"
s %= ', '.join('a%d:%d' % (i,i) for i in xrange(65534))
compile(s, '?', 'exec')
-
+
+ def test_mangling(self):
+ class A:
+ def f():
+ __mangled = 1
+ __not_mangled__ = 2
+ import __mangled_mod
+ import __package__.module
+
+ self.assert_("_A__mangled" in A.f.__code__.co_varnames)
+ self.assert_("__not_mangled__" in A.f.__code__.co_varnames)
+ self.assert_("_A__mangled_mod" in A.f.__code__.co_varnames)
+ self.assert_("__package__" in A.f.__code__.co_varnames)
+
def test_main():
test_support.run_unittest(TestSpecifics)
diff --git a/Lib/test/test_compiler.py b/Lib/test/test_compiler.py
index 4fb6cc1..c55dc0e 100644
--- a/Lib/test/test_compiler.py
+++ b/Lib/test/test_compiler.py
@@ -190,7 +190,7 @@ class CompilerTest(unittest.TestCase):
def testBytesLiteral(self):
c = compiler.compile("b'foo'", '<string>', 'eval')
b = eval(c)
-
+
c = compiler.compile('def f(b=b"foo"):\n'
' b[0] += 1\n'
' return b\n'
@@ -200,7 +200,7 @@ class CompilerTest(unittest.TestCase):
dct = {}
exec(c, dct)
self.assertEquals(dct.get('result'), b"ioo")
-
+
c = compiler.compile('def f():\n'
' b = b"foo"\n'
' b[0] += 1\n'
diff --git a/Lib/test/test_complex.py b/Lib/test/test_complex.py
index 91f074b..0d034f5 100644
--- a/Lib/test/test_complex.py
+++ b/Lib/test/test_complex.py
@@ -208,6 +208,8 @@ class ComplexTest(unittest.TestCase):
self.assertAlmostEqual(complex(), 0)
self.assertAlmostEqual(complex("-1"), -1)
self.assertAlmostEqual(complex("+1"), +1)
+ self.assertAlmostEqual(complex("(1+2j)"), 1+2j)
+ self.assertAlmostEqual(complex("(1.3+2.2j)"), 1.3+2.2j)
class complex2(complex): pass
self.assertAlmostEqual(complex(complex2(1+1j)), 1+1j)
@@ -237,12 +239,17 @@ class ComplexTest(unittest.TestCase):
self.assertRaises(ValueError, complex, "")
self.assertRaises(TypeError, complex, None)
self.assertRaises(ValueError, complex, "\0")
+ self.assertRaises(ValueError, complex, "3\09")
self.assertRaises(TypeError, complex, "1", "2")
self.assertRaises(TypeError, complex, "1", 42)
self.assertRaises(TypeError, complex, 1, "2")
self.assertRaises(ValueError, complex, "1+")
self.assertRaises(ValueError, complex, "1+1j+1j")
self.assertRaises(ValueError, complex, "--")
+ self.assertRaises(ValueError, complex, "(1+2j")
+ self.assertRaises(ValueError, complex, "1+2j)")
+ self.assertRaises(ValueError, complex, "1+(2j)")
+ self.assertRaises(ValueError, complex, "(1+2j)123")
if test_support.have_unicode:
self.assertRaises(ValueError, complex, unicode("1"*500))
self.assertRaises(ValueError, complex, unicode("x"))
@@ -305,6 +312,11 @@ class ComplexTest(unittest.TestCase):
self.assertNotEqual(repr(-(1+0j)), '(-1+-0j)')
+ self.assertEqual(1-6j,complex(repr(1-6j)))
+ self.assertEqual(1+6j,complex(repr(1+6j)))
+ self.assertEqual(-6j,complex(repr(-6j)))
+ self.assertEqual(6j,complex(repr(6j)))
+
def test_neg(self):
self.assertEqual(-(1+6j), -1-6j)
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
index 9142428..a3a9a5b 100644
--- a/Lib/test/test_contextlib.py
+++ b/Lib/test/test_contextlib.py
@@ -9,7 +9,7 @@ import tempfile
import unittest
import threading
from contextlib import * # Tests __all__
-from test.test_support import run_suite
+from test import test_support
class ContextManagerTestCase(unittest.TestCase):
@@ -332,9 +332,7 @@ class LockContextTestCase(unittest.TestCase):
# This is needed to make the test actually run under regrtest.py!
def test_main():
- run_suite(
- unittest.defaultTestLoader.loadTestsFromModule(sys.modules[__name__])
- )
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_crypt.py b/Lib/test/test_crypt.py
index 7336711..a9c28cd 100755
--- a/Lib/test/test_crypt.py
+++ b/Lib/test/test_crypt.py
@@ -3,7 +3,7 @@
Roger E. Masse
"""
-from test.test_support import verify, verbose
+from test.test_support import verbose
import crypt
c = crypt.crypt('mypassword', 'ab')
diff --git a/Lib/test/test_csv.py b/Lib/test/test_csv.py
index 980f3fc..0ca6b78 100644
--- a/Lib/test/test_csv.py
+++ b/Lib/test/test_csv.py
@@ -484,12 +484,16 @@ class TestDialectExcel(TestCsvBase):
self.readerAssertEqual('a"b"c', [['a"b"c']])
def test_quotes_and_more(self):
+ # Excel would never write a field containing '"a"b', but when
+ # reading one, it will return 'ab'.
self.readerAssertEqual('"a"b', [['ab']])
def test_lone_quote(self):
self.readerAssertEqual('a"b', [['a"b']])
def test_quote_and_quote(self):
+ # Excel would never write a field containing '"a" "b"', but when
+ # reading one, it will return 'a "b"'.
self.readerAssertEqual('"a" "b"', [['a "b"']])
def test_space_and_quote(self):
diff --git a/Lib/test/test_ctypes.py b/Lib/test/test_ctypes.py
index fd2032b..7a81ab4 100644
--- a/Lib/test/test_ctypes.py
+++ b/Lib/test/test_ctypes.py
@@ -1,12 +1,12 @@
import unittest
-from test.test_support import run_suite
+from test.test_support import run_unittest
import ctypes.test
def test_main():
skipped, testcases = ctypes.test.get_tests(ctypes.test, "test_*.py", verbosity=0)
suites = [unittest.makeSuite(t) for t in testcases]
- run_suite(unittest.TestSuite(suites))
+ run_unittest(unittest.TestSuite(suites))
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_curses.py b/Lib/test/test_curses.py
index fd8e6ed..ee679e7 100644
--- a/Lib/test/test_curses.py
+++ b/Lib/test/test_curses.py
@@ -129,6 +129,12 @@ def window_funcs(stdscr):
stdscr.touchline(5,5,0)
stdscr.vline('a', 3)
stdscr.vline('a', 3, curses.A_STANDOUT)
+ stdscr.chgat(5, 2, 3, curses.A_BLINK)
+ stdscr.chgat(3, curses.A_BOLD)
+ stdscr.chgat(5, 8, curses.A_UNDERLINE)
+ stdscr.chgat(curses.A_BLINK)
+ stdscr.refresh()
+
stdscr.vline(1,1, 'a', 3)
stdscr.vline(1,1, 'a', 3, curses.A_STANDOUT)
@@ -241,12 +247,21 @@ def test_userptr_without_set(stdscr):
except curses.panel.error:
pass
+def test_resize_term(stdscr):
+ if hasattr(curses, 'resizeterm'):
+ lines, cols = curses.LINES, curses.COLS
+ curses.resizeterm(lines - 1, cols + 1)
+
+ if curses.LINES != lines - 1 or curses.COLS != cols + 1:
+ raise RuntimeError, "Expected resizeterm to update LINES and COLS"
+
def main(stdscr):
curses.savetty()
try:
module_funcs(stdscr)
window_funcs(stdscr)
test_userptr_without_set(stdscr)
+ test_resize_term(stdscr)
finally:
curses.resetty()
diff --git a/Lib/test/test_datetime.py b/Lib/test/test_datetime.py
index 8a7c1b6..287585e 100644
--- a/Lib/test/test_datetime.py
+++ b/Lib/test/test_datetime.py
@@ -3,6 +3,7 @@
See http://www.zope.org/Members/fdrake/DateTimeWiki/TestCases
"""
+import os
import sys
import pickle
import unittest
@@ -135,7 +136,7 @@ class TestTZInfo(unittest.TestCase):
# Base clase for testing a particular aspect of timedelta, time, date and
# datetime comparisons.
-class HarmlessMixedComparison(unittest.TestCase):
+class HarmlessMixedComparison:
# Test that __eq__ and __ne__ don't complain for mixed-type comparisons.
# Subclasses must define 'theclass', and theclass(1, 1, 1) must be a
@@ -174,7 +175,7 @@ class HarmlessMixedComparison(unittest.TestCase):
#############################################################################
# timedelta tests
-class TestTimeDelta(HarmlessMixedComparison):
+class TestTimeDelta(HarmlessMixedComparison, unittest.TestCase):
theclass = timedelta
@@ -521,7 +522,7 @@ class TestDateOnly(unittest.TestCase):
class SubclassDate(date):
sub_var = 1
-class TestDate(HarmlessMixedComparison):
+class TestDate(HarmlessMixedComparison, unittest.TestCase):
# Tests here should pass for both dates and datetimes, except for a
# few tests that TestDateTime overrides.
@@ -1452,6 +1453,21 @@ class TestDateTime(TestDate):
self.assertRaises(ValueError, self.theclass.utcfromtimestamp,
insane)
+ def test_negative_float_fromtimestamp(self):
+ # Windows doesn't accept negative timestamps
+ if os.name == "nt":
+ return
+ # The result is tz-dependent; at least test that this doesn't
+ # fail (like it did before bug 1646728 was fixed).
+ self.theclass.fromtimestamp(-1.05)
+
+ def test_negative_float_utcfromtimestamp(self):
+ # Windows doesn't accept negative timestamps
+ if os.name == "nt":
+ return
+ d = self.theclass.utcfromtimestamp(-1.05)
+ self.assertEquals(d, self.theclass(1969, 12, 31, 23, 59, 58, 950000))
+
def test_utcnow(self):
import time
@@ -1607,7 +1623,7 @@ class TestDateTime(TestDate):
class SubclassTime(time):
sub_var = 1
-class TestTime(HarmlessMixedComparison):
+class TestTime(HarmlessMixedComparison, unittest.TestCase):
theclass = time
@@ -1890,7 +1906,7 @@ class TestTime(HarmlessMixedComparison):
# A mixin for classes with a tzinfo= argument. Subclasses must define
# theclass as a class atribute, and theclass(1, 1, 1, tzinfo=whatever)
# must be legit (which is true for time and datetime).
-class TZInfoBase(unittest.TestCase):
+class TZInfoBase:
def test_argument_passing(self):
cls = self.theclass
@@ -2050,7 +2066,7 @@ class TZInfoBase(unittest.TestCase):
# Testing time objects with a non-None tzinfo.
-class TestTimeTZ(TestTime, TZInfoBase):
+class TestTimeTZ(TestTime, TZInfoBase, unittest.TestCase):
theclass = time
def test_empty(self):
@@ -2298,7 +2314,7 @@ class TestTimeTZ(TestTime, TZInfoBase):
# Testing datetime objects with a non-None tzinfo.
-class TestDateTimeTZ(TestDateTime, TZInfoBase):
+class TestDateTimeTZ(TestDateTime, TZInfoBase, unittest.TestCase):
theclass = datetime
def test_trivial(self):
@@ -3259,45 +3275,8 @@ class Oddballs(unittest.TestCase):
self.assertEqual(as_datetime, datetime_sc)
self.assertEqual(datetime_sc, as_datetime)
-def test_suite():
- allsuites = [unittest.makeSuite(klass, 'test')
- for klass in (TestModule,
- TestTZInfo,
- TestTimeDelta,
- TestDateOnly,
- TestDate,
- TestDateTime,
- TestTime,
- TestTimeTZ,
- TestDateTimeTZ,
- TestTimezoneConversions,
- Oddballs,
- )
- ]
- return unittest.TestSuite(allsuites)
-
def test_main():
- import gc
- import sys
-
- thesuite = test_suite()
- lastrc = None
- while True:
- test_support.run_suite(thesuite)
- if 1: # change to 0, under a debug build, for some leak detection
- break
- gc.collect()
- if gc.garbage:
- raise SystemError("gc.garbage not empty after test run: %r" %
- gc.garbage)
- if hasattr(sys, 'gettotalrefcount'):
- thisrc = sys.gettotalrefcount()
- print('*' * 10, 'total refs:', thisrc, end=' ', file=sys.stderr)
- if lastrc:
- print('delta:', thisrc - lastrc, file=sys.stderr)
- else:
- print(file=sys.stderr)
- lastrc = thisrc
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py
index c8ae2be..a8dfdb1 100644
--- a/Lib/test/test_defaultdict.py
+++ b/Lib/test/test_defaultdict.py
@@ -132,6 +132,15 @@ class TestDefaultDict(unittest.TestCase):
self.assertEqual(d2.default_factory, list)
self.assertEqual(d2, d1)
+ def test_keyerror_without_factory(self):
+ d1 = defaultdict()
+ try:
+ d1[(1,)]
+ except KeyError as err:
+ self.assertEqual(err.message, (1,))
+ else:
+ self.fail("expected KeyError")
+
def test_main():
test_support.run_unittest(TestDefaultDict)
diff --git a/Lib/test/test_deque.py b/Lib/test/test_deque.py
index 7df31b7..4f7e60c 100644
--- a/Lib/test/test_deque.py
+++ b/Lib/test/test_deque.py
@@ -575,7 +575,7 @@ deque(['a', 'b', 'd', 'e', 'f'])
>>> for value in roundrobin('abc', 'd', 'efgh'):
... print(value)
-...
+...
a
d
e
diff --git a/Lib/test/test_descr.py b/Lib/test/test_descr.py
index eec4e40..aba1c74 100644
--- a/Lib/test/test_descr.py
+++ b/Lib/test/test_descr.py
@@ -773,6 +773,22 @@ def metaclass():
except TypeError: pass
else: raise TestFailed, "calling object w/o call method should raise TypeError"
+ # Testing code to find most derived baseclass
+ class A(type):
+ def __new__(*args, **kwargs):
+ return type.__new__(*args, **kwargs)
+
+ class B(object):
+ pass
+
+ class C(object, metaclass=A):
+ pass
+
+ # The most derived metaclass of D is A rather than type.
+ class D(B, C):
+ pass
+
+
def pymods():
if verbose: print("Testing Python subclass of module...")
log = []
@@ -1072,6 +1088,45 @@ def slots():
raise TestFailed, "[''] slots not caught"
class C(object):
__slots__ = ["a", "a_b", "_a", "A0123456789Z"]
+ # XXX(nnorwitz): was there supposed to be something tested
+ # from the class above?
+
+ # Test a single string is not expanded as a sequence.
+ class C(object):
+ __slots__ = "abc"
+ c = C()
+ c.abc = 5
+ vereq(c.abc, 5)
+
+ # Test unicode slot names
+ try:
+ unicode
+ except NameError:
+ pass
+ else:
+ # Test a single unicode string is not expanded as a sequence.
+ class C(object):
+ __slots__ = unicode("abc")
+ c = C()
+ c.abc = 5
+ vereq(c.abc, 5)
+
+ # _unicode_to_string used to modify slots in certain circumstances
+ slots = (unicode("foo"), unicode("bar"))
+ class C(object):
+ __slots__ = slots
+ x = C()
+ x.foo = 5
+ vereq(x.foo, 5)
+ veris(type(slots[0]), unicode)
+ # this used to leak references
+ try:
+ class C(object):
+ __slots__ = [unichr(128)]
+ except (TypeError, UnicodeEncodeError):
+ pass
+ else:
+ raise TestFailed, "[unichr(128)] slots not caught"
# Test leaks
class Counted(object):
@@ -1318,6 +1373,22 @@ def errors():
else:
verify(0, "__slots__ = [1] should be illegal")
+ class M1(type):
+ pass
+ class M2(type):
+ pass
+ class A1(object, metaclass=M1):
+ pass
+ class A2(object, metaclass=M2):
+ pass
+ try:
+ class B(A1, A2):
+ pass
+ except TypeError:
+ pass
+ else:
+ verify(0, "finding the most derived metaclass should have failed")
+
def classmethods():
if verbose: print("Testing class methods...")
class C(object):
@@ -2092,7 +2163,6 @@ def inherits():
__slots__ = ['prec']
def __init__(self, value=0.0, prec=12):
self.prec = int(prec)
- float.__init__(self, value)
def __repr__(self):
return "%.*g" % (self.prec, self)
vereq(repr(precfloat(1.1)), "1.1")
@@ -2644,6 +2714,51 @@ def setclass():
cant(o, type(1))
cant(o, type(None))
del o
+ class G(object):
+ __slots__ = ["a", "b"]
+ class H(object):
+ __slots__ = ["b", "a"]
+ try:
+ unicode
+ except NameError:
+ class I(object):
+ __slots__ = ["a", "b"]
+ else:
+ class I(object):
+ __slots__ = [unicode("a"), unicode("b")]
+ class J(object):
+ __slots__ = ["c", "b"]
+ class K(object):
+ __slots__ = ["a", "b", "d"]
+ class L(H):
+ __slots__ = ["e"]
+ class M(I):
+ __slots__ = ["e"]
+ class N(J):
+ __slots__ = ["__weakref__"]
+ class P(J):
+ __slots__ = ["__dict__"]
+ class Q(J):
+ pass
+ class R(J):
+ __slots__ = ["__dict__", "__weakref__"]
+
+ for cls, cls2 in ((G, H), (G, I), (I, H), (Q, R), (R, Q)):
+ x = cls()
+ x.a = 1
+ x.__class__ = cls2
+ verify(x.__class__ is cls2,
+ "assigning %r as __class__ for %r silently failed" % (cls2, x))
+ vereq(x.a, 1)
+ x.__class__ = cls
+ verify(x.__class__ is cls,
+ "assigning %r as __class__ for %r silently failed" % (cls, x))
+ vereq(x.a, 1)
+ for cls in G, J, K, L, M, N, P, R, list, Int:
+ for cls2 in G, J, K, L, M, N, P, R, list, Int:
+ if cls is cls2:
+ continue
+ cant(cls(), cls2)
def setdict():
if verbose: print("Testing __dict__ assignment...")
@@ -3999,6 +4114,19 @@ def notimplemented():
check(iexpr, c, N1)
check(iexpr, c, N2)
+def test_assign_slice():
+ # ceval.c's assign_slice used to check for
+ # tp->tp_as_sequence->sq_slice instead of
+ # tp->tp_as_sequence->sq_ass_slice
+
+ class C(object):
+ def __setslice__(self, start, stop, value):
+ self.value = value
+
+ c = C()
+ c[1:2] = 3
+ vereq(c.value, 3)
+
def test_main():
weakref_segfault() # Must be first, somehow
wrapper_segfault()
@@ -4094,6 +4222,7 @@ def test_main():
test_init()
methodwrapper()
notimplemented()
+ test_assign_slice()
if verbose: print("All OK")
diff --git a/Lib/test/test_descrtut.py b/Lib/test/test_descrtut.py
index 001aa49..5b11666 100644
--- a/Lib/test/test_descrtut.py
+++ b/Lib/test/test_descrtut.py
@@ -243,7 +243,7 @@ methods. Static methods are easy to describe: they behave pretty much like
static methods in C++ or Java. Here's an example:
>>> class C:
- ...
+ ...
... @staticmethod
... def foo(x, y):
... print("staticmethod", x, y)
diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py
index d8dfd7e..d98c607 100644
--- a/Lib/test/test_dict.py
+++ b/Lib/test/test_dict.py
@@ -425,7 +425,7 @@ class DictTest(unittest.TestCase):
except RuntimeError as err:
self.assertEqual(err.args, (42,))
else:
- self.fail_("e[42] didn't raise RuntimeError")
+ self.fail("e[42] didn't raise RuntimeError")
class F(dict):
def __init__(self):
# An instance variable __missing__ should have no effect
@@ -436,7 +436,7 @@ class DictTest(unittest.TestCase):
except KeyError as err:
self.assertEqual(err.args, (42,))
else:
- self.fail_("f[42] didn't raise KeyError")
+ self.fail("f[42] didn't raise KeyError")
class G(dict):
pass
g = G()
@@ -445,7 +445,7 @@ class DictTest(unittest.TestCase):
except KeyError as err:
self.assertEqual(err.args, (42,))
else:
- self.fail_("g[42] didn't raise KeyError")
+ self.fail("g[42] didn't raise KeyError")
def test_tuple_keyerror(self):
# SF #1576657
@@ -457,6 +457,76 @@ class DictTest(unittest.TestCase):
else:
self.fail("missing KeyError")
+ def test_bad_key(self):
+ # Dictionary lookups should fail if __cmp__() raises an exception.
+ class CustomException(Exception):
+ pass
+
+ class BadDictKey:
+ def __hash__(self):
+ return hash(self.__class__)
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ raise CustomException
+ return other
+
+ d = {}
+ x1 = BadDictKey()
+ x2 = BadDictKey()
+ d[x1] = 1
+ for stmt in ['d[x2] = 2',
+ 'z = d[x2]',
+ 'x2 in d',
+ 'd.get(x2)',
+ 'd.setdefault(x2, 42)',
+ 'd.pop(x2)',
+ 'd.update({x2: 2})']:
+ try:
+ exec(stmt, locals())
+ except CustomException:
+ pass
+ else:
+ self.fail("Statement %r didn't raise exception" % stmt)
+
+ def test_resize1(self):
+ # Dict resizing bug, found by Jack Jansen in 2.2 CVS development.
+ # This version got an assert failure in debug build, infinite loop in
+ # release build. Unfortunately, provoking this kind of stuff requires
+ # a mix of inserts and deletes hitting exactly the right hash codes in
+ # exactly the right order, and I can't think of a randomized approach
+ # that would be *likely* to hit a failing case in reasonable time.
+
+ d = {}
+ for i in range(5):
+ d[i] = i
+ for i in range(5):
+ del d[i]
+ for i in range(5, 9): # i==8 was the problem
+ d[i] = i
+
+ def test_resize2(self):
+ # Another dict resizing bug (SF bug #1456209).
+ # This caused Segmentation faults or Illegal instructions.
+
+ class X(object):
+ def __hash__(self):
+ return 5
+ def __eq__(self, other):
+ if resizing:
+ d.clear()
+ return False
+ d = {}
+ resizing = False
+ d[X()] = 1
+ d[X()] = 2
+ d[X()] = 3
+ d[X()] = 4
+ d[X()] = 5
+ # now trigger a resize
+ resizing = True
+ d[9] = 6
+
from test import mapping_tests
diff --git a/Lib/test/test_dis.py b/Lib/test/test_dis.py
index 8e97a6a..5d82cd7 100644
--- a/Lib/test/test_dis.py
+++ b/Lib/test/test_dis.py
@@ -1,11 +1,11 @@
-from test.test_support import verify, verbose, TestFailed, run_unittest
+# Minimal tests for dis module
+
+from test.test_support import verbose, run_unittest
+import unittest
import sys
import dis
import StringIO
-# Minimal tests for dis module
-
-import unittest
def _f(a):
print(a)
diff --git a/Lib/test/test_doctest.py b/Lib/test/test_doctest.py
index 087b340..60079a6 100644
--- a/Lib/test/test_doctest.py
+++ b/Lib/test/test_doctest.py
@@ -35,7 +35,7 @@ class SampleClass:
>>> for i in range(10):
... sc = sc.double()
... print(sc.get(), end=' ')
- 6 12 24 48 96 192 384 768 1536 3072
+ 6 12 24 48 96 192 384 768 1536 3072
"""
def __init__(self, val):
"""
@@ -571,7 +571,7 @@ DocTestFinder finds the line number of each example:
...
... >>> for x in range(10):
... ... print(x, end=' ')
- ... 0 1 2 3 4 5 6 7 8 9
+ ... 0 1 2 3 4 5 6 7 8 9
... >>> x//2
... 6
... '''
@@ -1461,11 +1461,11 @@ at the end of any line:
>>> def f(x): r'''
... >>> for x in range(10): # doctest: +ELLIPSIS
... ... print(x, end=' ')
- ... 0 1 2 ... 9
+ ... 0 1 2 ... 9
...
... >>> for x in range(10):
... ... print(x, end=' ') # doctest: +ELLIPSIS
- ... 0 1 2 ... 9
+ ... 0 1 2 ... 9
... '''
>>> test = doctest.DocTestFinder().find(f)[0]
>>> doctest.DocTestRunner(verbose=False).run(test)
@@ -1478,7 +1478,7 @@ option directive, then they are combined:
... Should fail (option directive not on the last line):
... >>> for x in range(10): # doctest: +ELLIPSIS
... ... print(x, end=' ') # doctest: +NORMALIZE_WHITESPACE
- ... 0 1 2...9
+ ... 0 1 2...9
... '''
>>> test = doctest.DocTestFinder().find(f)[0]
>>> doctest.DocTestRunner(verbose=False).run(test)
diff --git a/Lib/test/test_email.py b/Lib/test/test_email.py
index de0eee3..f609968 100644
--- a/Lib/test/test_email.py
+++ b/Lib/test/test_email.py
@@ -4,10 +4,10 @@
import unittest
# The specific tests now live in Lib/email/test
from email.test.test_email import suite
-from test.test_support import run_suite
+from test import test_support
def test_main():
- run_suite(suite())
+ test_support.run_unittest(suite())
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_email_codecs.py b/Lib/test/test_email_codecs.py
index c550a6f..8951f81 100644
--- a/Lib/test/test_email_codecs.py
+++ b/Lib/test/test_email_codecs.py
@@ -9,7 +9,7 @@ from test import test_support
def test_main():
suite = test_email_codecs.suite()
suite.addTest(test_email_codecs_renamed.suite())
- test_support.run_suite(suite)
+ test_support.run_unittest(suite)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_email_renamed.py b/Lib/test/test_email_renamed.py
index c3af598..163e791 100644
--- a/Lib/test/test_email_renamed.py
+++ b/Lib/test/test_email_renamed.py
@@ -4,10 +4,10 @@
import unittest
# The specific tests now live in Lib/email/test
from email.test.test_email_renamed import suite
-from test.test_support import run_suite
+from test import test_support
def test_main():
- run_suite(suite())
+ test_support.run_unittest(suite())
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_exceptions.py b/Lib/test/test_exceptions.py
index c0a5094..5a22297 100644
--- a/Lib/test/test_exceptions.py
+++ b/Lib/test/test_exceptions.py
@@ -229,6 +229,9 @@ class ExceptionTests(unittest.TestCase):
(EnvironmentError, (1, 'strErrorStr', 'filenameStr'),
{'message' : '', 'args' : (1, 'strErrorStr'), 'errno' : 1,
'strerror' : 'strErrorStr', 'filename' : 'filenameStr'}),
+ (SyntaxError, (), {'message' : '', 'msg' : None, 'text' : None,
+ 'filename' : None, 'lineno' : None, 'offset' : None,
+ 'print_file_and_line' : None}),
(SyntaxError, ('msgStr',),
{'message' : 'msgStr', 'args' : ('msgStr',), 'text' : None,
'print_file_and_line' : None, 'msg' : 'msgStr',
@@ -337,7 +340,7 @@ class ExceptionTests(unittest.TestCase):
def testExceptionCleanup(self):
# Make sure "except V as N" exceptions are cleaned up properly
-
+
try:
raise Exception()
except Exception as e:
diff --git a/Lib/test/test_fileinput.py b/Lib/test/test_fileinput.py
index 17ca944..10d3cfc 100644
--- a/Lib/test/test_fileinput.py
+++ b/Lib/test/test_fileinput.py
@@ -3,7 +3,9 @@ Tests for fileinput module.
Nick Mathewson
'''
-from test.test_support import verify, verbose, TESTFN, TestFailed
+import unittest
+from test.test_support import verbose, TESTFN, run_unittest
+from test.test_support import unlink as safe_unlink
import sys, os, re
from StringIO import StringIO
from fileinput import FileInput, hook_encoded
@@ -18,211 +20,206 @@ from fileinput import FileInput, hook_encoded
def writeTmp(i, lines, mode='w'): # opening in text mode is the default
name = TESTFN + str(i)
f = open(name, mode)
- for line in lines:
- f.write(line)
+ f.writelines(lines)
f.close()
return name
-pat = re.compile(r'LINE (\d+) OF FILE (\d+)')
-
def remove_tempfiles(*names):
for name in names:
- try:
- os.unlink(name)
- except:
- pass
-
-def runTests(t1, t2, t3, t4, bs=0, round=0):
- start = 1 + round*6
- if verbose:
- print('%s. Simple iteration (bs=%s)' % (start+0, bs))
- fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
- lines = list(fi)
- fi.close()
- verify(len(lines) == 31)
- verify(lines[4] == 'Line 5 of file 1\n')
- verify(lines[30] == 'Line 1 of file 4\n')
- verify(fi.lineno() == 31)
- verify(fi.filename() == t4)
-
- if verbose:
- print('%s. Status variables (bs=%s)' % (start+1, bs))
- fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
- s = "x"
- while s and s != 'Line 6 of file 2\n':
- s = fi.readline()
- verify(fi.filename() == t2)
- verify(fi.lineno() == 21)
- verify(fi.filelineno() == 6)
- verify(not fi.isfirstline())
- verify(not fi.isstdin())
-
- if verbose:
- print('%s. Nextfile (bs=%s)' % (start+2, bs))
- fi.nextfile()
- verify(fi.readline() == 'Line 1 of file 3\n')
- verify(fi.lineno() == 22)
- fi.close()
-
- if verbose:
- print('%s. Stdin (bs=%s)' % (start+3, bs))
- fi = FileInput(files=(t1, t2, t3, t4, '-'), bufsize=bs)
- savestdin = sys.stdin
- try:
- sys.stdin = StringIO("Line 1 of stdin\nLine 2 of stdin\n")
+ safe_unlink(name)
+
+class BufferSizesTests(unittest.TestCase):
+ def test_buffer_sizes(self):
+ # First, run the tests with default and teeny buffer size.
+ for round, bs in (0, 0), (1, 30):
+ try:
+ t1 = writeTmp(1, ["Line %s of file 1\n" % (i+1) for i in range(15)])
+ t2 = writeTmp(2, ["Line %s of file 2\n" % (i+1) for i in range(10)])
+ t3 = writeTmp(3, ["Line %s of file 3\n" % (i+1) for i in range(5)])
+ t4 = writeTmp(4, ["Line %s of file 4\n" % (i+1) for i in range(1)])
+ self.buffer_size_test(t1, t2, t3, t4, bs, round)
+ finally:
+ remove_tempfiles(t1, t2, t3, t4)
+
+ def buffer_size_test(self, t1, t2, t3, t4, bs=0, round=0):
+ pat = re.compile(r'LINE (\d+) OF FILE (\d+)')
+
+ start = 1 + round*6
+ if verbose:
+ print('%s. Simple iteration (bs=%s)' % (start+0, bs))
+ fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
lines = list(fi)
- verify(len(lines) == 33)
- verify(lines[32] == 'Line 2 of stdin\n')
- verify(fi.filename() == '<stdin>')
+ fi.close()
+ self.assertEqual(len(lines), 31)
+ self.assertEqual(lines[4], 'Line 5 of file 1\n')
+ self.assertEqual(lines[30], 'Line 1 of file 4\n')
+ self.assertEqual(fi.lineno(), 31)
+ self.assertEqual(fi.filename(), t4)
+
+ if verbose:
+ print('%s. Status variables (bs=%s)' % (start+1, bs))
+ fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
+ s = "x"
+ while s and s != 'Line 6 of file 2\n':
+ s = fi.readline()
+ self.assertEqual(fi.filename(), t2)
+ self.assertEqual(fi.lineno(), 21)
+ self.assertEqual(fi.filelineno(), 6)
+ self.failIf(fi.isfirstline())
+ self.failIf(fi.isstdin())
+
+ if verbose:
+ print('%s. Nextfile (bs=%s)' % (start+2, bs))
+ fi.nextfile()
+ self.assertEqual(fi.readline(), 'Line 1 of file 3\n')
+ self.assertEqual(fi.lineno(), 22)
+ fi.close()
+
+ if verbose:
+ print('%s. Stdin (bs=%s)' % (start+3, bs))
+ fi = FileInput(files=(t1, t2, t3, t4, '-'), bufsize=bs)
+ savestdin = sys.stdin
+ try:
+ sys.stdin = StringIO("Line 1 of stdin\nLine 2 of stdin\n")
+ lines = list(fi)
+ self.assertEqual(len(lines), 33)
+ self.assertEqual(lines[32], 'Line 2 of stdin\n')
+ self.assertEqual(fi.filename(), '<stdin>')
+ fi.nextfile()
+ finally:
+ sys.stdin = savestdin
+
+ if verbose:
+ print('%s. Boundary conditions (bs=%s)' % (start+4, bs))
+ fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
+ self.assertEqual(fi.lineno(), 0)
+ self.assertEqual(fi.filename(), None)
fi.nextfile()
- finally:
- sys.stdin = savestdin
-
- if verbose:
- print('%s. Boundary conditions (bs=%s)' % (start+4, bs))
- fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
- verify(fi.lineno() == 0)
- verify(fi.filename() == None)
- fi.nextfile()
- verify(fi.lineno() == 0)
- verify(fi.filename() == None)
-
- if verbose:
- print('%s. Inplace (bs=%s)' % (start+5, bs))
- savestdout = sys.stdout
- try:
- fi = FileInput(files=(t1, t2, t3, t4), inplace=1, bufsize=bs)
+ self.assertEqual(fi.lineno(), 0)
+ self.assertEqual(fi.filename(), None)
+
+ if verbose:
+ print('%s. Inplace (bs=%s)' % (start+5, bs))
+ savestdout = sys.stdout
+ try:
+ fi = FileInput(files=(t1, t2, t3, t4), inplace=1, bufsize=bs)
+ for line in fi:
+ line = line[:-1].upper()
+ print(line)
+ fi.close()
+ finally:
+ sys.stdout = savestdout
+
+ fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
for line in fi:
- line = line[:-1].upper()
- print(line)
+ self.assertEqual(line[-1], '\n')
+ m = pat.match(line[:-1])
+ self.assertNotEqual(m, None)
+ self.assertEqual(int(m.group(1)), fi.filelineno())
fi.close()
- finally:
- sys.stdout = savestdout
-
- fi = FileInput(files=(t1, t2, t3, t4), bufsize=bs)
- for line in fi:
- verify(line[-1] == '\n')
- m = pat.match(line[:-1])
- verify(m != None)
- verify(int(m.group(1)) == fi.filelineno())
- fi.close()
-
-
-def writeFiles():
- global t1, t2, t3, t4
- t1 = writeTmp(1, ["Line %s of file 1\n" % (i+1) for i in range(15)])
- t2 = writeTmp(2, ["Line %s of file 2\n" % (i+1) for i in range(10)])
- t3 = writeTmp(3, ["Line %s of file 3\n" % (i+1) for i in range(5)])
- t4 = writeTmp(4, ["Line %s of file 4\n" % (i+1) for i in range(1)])
-
-# First, run the tests with default and teeny buffer size.
-for round, bs in (0, 0), (1, 30):
- try:
- writeFiles()
- runTests(t1, t2, t3, t4, bs, round)
- finally:
- remove_tempfiles(t1, t2, t3, t4)
-
-# Next, check for proper behavior with 0-byte files.
-if verbose:
- print("13. 0-byte files")
-try:
- t1 = writeTmp(1, [""])
- t2 = writeTmp(2, [""])
- t3 = writeTmp(3, ["The only line there is.\n"])
- t4 = writeTmp(4, [""])
- fi = FileInput(files=(t1, t2, t3, t4))
- line = fi.readline()
- verify(line == 'The only line there is.\n')
- verify(fi.lineno() == 1)
- verify(fi.filelineno() == 1)
- verify(fi.filename() == t3)
- line = fi.readline()
- verify(not line)
- verify(fi.lineno() == 1)
- verify(fi.filelineno() == 0)
- verify(fi.filename() == t4)
- fi.close()
-finally:
- remove_tempfiles(t1, t2, t3, t4)
-
-if verbose:
- print("14. Files that don't end with newline")
-try:
- t1 = writeTmp(1, ["A\nB\nC"])
- t2 = writeTmp(2, ["D\nE\nF"])
- fi = FileInput(files=(t1, t2))
- lines = list(fi)
- verify(lines == ["A\n", "B\n", "C", "D\n", "E\n", "F"])
- verify(fi.filelineno() == 3)
- verify(fi.lineno() == 6)
-finally:
- remove_tempfiles(t1, t2)
-
-if verbose:
- print("15. Unicode filenames")
-try:
- t1 = writeTmp(1, ["A\nB"])
- encoding = sys.getfilesystemencoding()
- if encoding is None:
- encoding = 'ascii'
- fi = FileInput(files=unicode(t1, encoding))
- lines = list(fi)
- verify(lines == ["A\n", "B"])
-finally:
- remove_tempfiles(t1)
-
-if verbose:
- print("16. fileno()")
-try:
- t1 = writeTmp(1, ["A\nB"])
- t2 = writeTmp(2, ["C\nD"])
- fi = FileInput(files=(t1, t2))
- verify(fi.fileno() == -1)
- line = next(fi)
- verify(fi.fileno() != -1)
- fi.nextfile()
- verify(fi.fileno() == -1)
- line = list(fi)
- verify(fi.fileno() == -1)
-finally:
- remove_tempfiles(t1, t2)
-
-if verbose:
- print("17. Specify opening mode")
-try:
- # invalid mode, should raise ValueError
- fi = FileInput(mode="w")
- raise TestFailed("FileInput should reject invalid mode argument")
-except ValueError:
- pass
-try:
- # try opening in universal newline mode
- t1 = writeTmp(1, ["A\nB\r\nC\rD"], mode="wb")
- fi = FileInput(files=t1, mode="U")
- lines = list(fi)
- verify(lines == ["A\n", "B\n", "C\n", "D"])
-finally:
- remove_tempfiles(t1)
-
-if verbose:
- print("18. Test file opening hook")
-try:
- # cannot use openhook and inplace mode
- fi = FileInput(inplace=1, openhook=lambda f,m: None)
- raise TestFailed("FileInput should raise if both inplace "
- "and openhook arguments are given")
-except ValueError:
- pass
-try:
- fi = FileInput(openhook=1)
- raise TestFailed("FileInput should check openhook for being callable")
-except ValueError:
- pass
-try:
- t1 = writeTmp(1, ["A\nB"], mode="wb")
- fi = FileInput(files=t1, openhook=hook_encoded("rot13"))
- lines = list(fi)
- verify(lines == ["N\n", "O"])
-finally:
- remove_tempfiles(t1)
+
+class FileInputTests(unittest.TestCase):
+ def test_zero_byte_files(self):
+ try:
+ t1 = writeTmp(1, [""])
+ t2 = writeTmp(2, [""])
+ t3 = writeTmp(3, ["The only line there is.\n"])
+ t4 = writeTmp(4, [""])
+ fi = FileInput(files=(t1, t2, t3, t4))
+
+ line = fi.readline()
+ self.assertEqual(line, 'The only line there is.\n')
+ self.assertEqual(fi.lineno(), 1)
+ self.assertEqual(fi.filelineno(), 1)
+ self.assertEqual(fi.filename(), t3)
+
+ line = fi.readline()
+ self.failIf(line)
+ self.assertEqual(fi.lineno(), 1)
+ self.assertEqual(fi.filelineno(), 0)
+ self.assertEqual(fi.filename(), t4)
+ fi.close()
+ finally:
+ remove_tempfiles(t1, t2, t3, t4)
+
+ def test_files_that_dont_end_with_newline(self):
+ try:
+ t1 = writeTmp(1, ["A\nB\nC"])
+ t2 = writeTmp(2, ["D\nE\nF"])
+ fi = FileInput(files=(t1, t2))
+ lines = list(fi)
+ self.assertEqual(lines, ["A\n", "B\n", "C", "D\n", "E\n", "F"])
+ self.assertEqual(fi.filelineno(), 3)
+ self.assertEqual(fi.lineno(), 6)
+ finally:
+ remove_tempfiles(t1, t2)
+
+ def test_unicode_filenames(self):
+ try:
+ t1 = writeTmp(1, ["A\nB"])
+ encoding = sys.getfilesystemencoding()
+ if encoding is None:
+ encoding = 'ascii'
+ fi = FileInput(files=unicode(t1, encoding))
+ lines = list(fi)
+ self.assertEqual(lines, ["A\n", "B"])
+ finally:
+ remove_tempfiles(t1)
+
+ def test_fileno(self):
+ try:
+ t1 = writeTmp(1, ["A\nB"])
+ t2 = writeTmp(2, ["C\nD"])
+ fi = FileInput(files=(t1, t2))
+ self.assertEqual(fi.fileno(), -1)
+ line =next( fi)
+ self.assertNotEqual(fi.fileno(), -1)
+ fi.nextfile()
+ self.assertEqual(fi.fileno(), -1)
+ line = list(fi)
+ self.assertEqual(fi.fileno(), -1)
+ finally:
+ remove_tempfiles(t1, t2)
+
+ def test_opening_mode(self):
+ try:
+ # invalid mode, should raise ValueError
+ fi = FileInput(mode="w")
+ self.fail("FileInput should reject invalid mode argument")
+ except ValueError:
+ pass
+ try:
+ # try opening in universal newline mode
+ t1 = writeTmp(1, ["A\nB\r\nC\rD"], mode="wb")
+ fi = FileInput(files=t1, mode="U")
+ lines = list(fi)
+ self.assertEqual(lines, ["A\n", "B\n", "C\n", "D"])
+ finally:
+ remove_tempfiles(t1)
+
+ def test_file_opening_hook(self):
+ try:
+ # cannot use openhook and inplace mode
+ fi = FileInput(inplace=1, openhook=lambda f, m: None)
+ self.fail("FileInput should raise if both inplace "
+ "and openhook arguments are given")
+ except ValueError:
+ pass
+ try:
+ fi = FileInput(openhook=1)
+ self.fail("FileInput should check openhook for being callable")
+ except ValueError:
+ pass
+ try:
+ t1 = writeTmp(1, ["A\nB"], mode="wb")
+ fi = FileInput(files=t1, openhook=hook_encoded("rot13"))
+ lines = list(fi)
+ self.assertEqual(lines, ["N\n", "O"])
+ finally:
+ remove_tempfiles(t1)
+
+def test_main():
+ run_unittest(BufferSizesTests, FileInputTests)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_fileio.py b/Lib/test/test_fileio.py
index 6a78154..4d969f5 100644
--- a/Lib/test/test_fileio.py
+++ b/Lib/test/test_fileio.py
@@ -110,13 +110,13 @@ class OtherFileTests(unittest.TestCase):
self.assertEquals(f.writable(), True)
self.assertEquals(f.seekable(), True)
f.close()
-
+
f = _fileio._FileIO(TESTFN, "r")
self.assertEquals(f.readable(), True)
self.assertEquals(f.writable(), False)
self.assertEquals(f.seekable(), True)
f.close()
-
+
f = _fileio._FileIO(TESTFN, "a+")
self.assertEquals(f.readable(), True)
self.assertEquals(f.writable(), True)
diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py
new file mode 100644
index 0000000..9298bf4
--- /dev/null
+++ b/Lib/test/test_ftplib.py
@@ -0,0 +1,93 @@
+import socket
+import threading
+import ftplib
+import time
+
+from unittest import TestCase
+from test import test_support
+
+def server(evt):
+ serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ serv.settimeout(3)
+ serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ serv.bind(("", 9091))
+ serv.listen(5)
+ try:
+ conn, addr = serv.accept()
+ except socket.timeout:
+ pass
+ else:
+ conn.send("1 Hola mundo\n")
+ conn.close()
+ finally:
+ serv.close()
+ evt.set()
+
+class GeneralTests(TestCase):
+
+ def setUp(self):
+ ftplib.FTP.port = 9091
+ self.evt = threading.Event()
+ threading.Thread(target=server, args=(self.evt,)).start()
+ time.sleep(.1)
+
+ def tearDown(self):
+ self.evt.wait()
+
+ def testBasic(self):
+ # do nothing
+ ftplib.FTP()
+
+ # connects
+ ftp = ftplib.FTP("localhost")
+ ftp.sock.close()
+
+ def testTimeoutDefault(self):
+ # default
+ ftp = ftplib.FTP("localhost")
+ self.assertTrue(ftp.sock.gettimeout() is None)
+ ftp.sock.close()
+
+ def testTimeoutValue(self):
+ # a value
+ ftp = ftplib.FTP("localhost", timeout=30)
+ self.assertEqual(ftp.sock.gettimeout(), 30)
+ ftp.sock.close()
+
+ def testTimeoutConnect(self):
+ ftp = ftplib.FTP()
+ ftp.connect("localhost", timeout=30)
+ self.assertEqual(ftp.sock.gettimeout(), 30)
+ ftp.sock.close()
+
+ def testTimeoutDifferentOrder(self):
+ ftp = ftplib.FTP(timeout=30)
+ ftp.connect("localhost")
+ self.assertEqual(ftp.sock.gettimeout(), 30)
+ ftp.sock.close()
+
+ def testTimeoutDirectAccess(self):
+ ftp = ftplib.FTP()
+ ftp.timeout = 30
+ ftp.connect("localhost")
+ self.assertEqual(ftp.sock.gettimeout(), 30)
+ ftp.sock.close()
+
+ def testTimeoutNone(self):
+ # None, having other default
+ previous = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(30)
+ try:
+ ftp = ftplib.FTP("localhost", timeout=None)
+ finally:
+ socket.setdefaulttimeout(previous)
+ self.assertEqual(ftp.sock.gettimeout(), 30)
+ ftp.close()
+
+
+
+def test_main(verbose=None):
+ test_support.run_unittest(GeneralTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 2d5e33c..a2df21c 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -285,7 +285,7 @@ class TestReduce(unittest.TestCase):
self.sofar.append(n*n)
n += 1
return self.sofar[i]
-
+
self.assertEqual(self.func(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
self.assertEqual(
self.func(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
@@ -321,7 +321,7 @@ class TestReduce(unittest.TestCase):
return i
else:
raise IndexError
-
+
from operator import add
self.assertEqual(self.func(add, SequenceClass(5)), 10)
self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
@@ -333,7 +333,7 @@ class TestReduce(unittest.TestCase):
d = {"one": 1, "two": 2, "three": 3}
self.assertEqual(self.func(add, d), "".join(d.keys()))
-
+
def test_main(verbose=None):
diff --git a/Lib/test/test_gc.py b/Lib/test/test_gc.py
index 8068b35..10b02da 100644
--- a/Lib/test/test_gc.py
+++ b/Lib/test/test_gc.py
@@ -1,390 +1,11 @@
-from test.test_support import verify, verbose, TestFailed, vereq
+import unittest
+from test.test_support import verbose, run_unittest
import sys
import gc
import weakref
-def expect(actual, expected, name):
- if actual != expected:
- raise TestFailed, "test_%s: actual %r, expected %r" % (
- name, actual, expected)
-
-def expect_nonzero(actual, name):
- if actual == 0:
- raise TestFailed, "test_%s: unexpected zero" % name
-
-def run_test(name, thunk):
- if verbose:
- print("testing %s..." % name, end=' ')
- thunk()
- if verbose:
- print("ok")
-
-def test_list():
- l = []
- l.append(l)
- gc.collect()
- del l
- expect(gc.collect(), 1, "list")
-
-def test_dict():
- d = {}
- d[1] = d
- gc.collect()
- del d
- expect(gc.collect(), 1, "dict")
-
-def test_tuple():
- # since tuples are immutable we close the loop with a list
- l = []
- t = (l,)
- l.append(t)
- gc.collect()
- del t
- del l
- expect(gc.collect(), 2, "tuple")
-
-def test_class():
- class A:
- pass
- A.a = A
- gc.collect()
- del A
- expect_nonzero(gc.collect(), "class")
-
-def test_newstyleclass():
- class A(object):
- pass
- gc.collect()
- del A
- expect_nonzero(gc.collect(), "staticclass")
-
-def test_instance():
- class A:
- pass
- a = A()
- a.a = a
- gc.collect()
- del a
- expect_nonzero(gc.collect(), "instance")
-
-def test_newinstance():
- class A(object):
- pass
- a = A()
- a.a = a
- gc.collect()
- del a
- expect_nonzero(gc.collect(), "newinstance")
- class B(list):
- pass
- class C(B, A):
- pass
- a = C()
- a.a = a
- gc.collect()
- del a
- expect_nonzero(gc.collect(), "newinstance(2)")
- del B, C
- expect_nonzero(gc.collect(), "newinstance(3)")
- A.a = A()
- del A
- expect_nonzero(gc.collect(), "newinstance(4)")
- expect(gc.collect(), 0, "newinstance(5)")
-
-def test_method():
- # Tricky: self.__init__ is a bound method, it references the instance.
- class A:
- def __init__(self):
- self.init = self.__init__
- a = A()
- gc.collect()
- del a
- expect_nonzero(gc.collect(), "method")
-
-def test_finalizer():
- # A() is uncollectable if it is part of a cycle, make sure it shows up
- # in gc.garbage.
- class A:
- def __del__(self): pass
- class B:
- pass
- a = A()
- a.a = a
- id_a = id(a)
- b = B()
- b.b = b
- gc.collect()
- del a
- del b
- expect_nonzero(gc.collect(), "finalizer")
- for obj in gc.garbage:
- if id(obj) == id_a:
- del obj.a
- break
- else:
- raise TestFailed, "didn't find obj in garbage (finalizer)"
- gc.garbage.remove(obj)
-
-def test_finalizer_newclass():
- # A() is uncollectable if it is part of a cycle, make sure it shows up
- # in gc.garbage.
- class A(object):
- def __del__(self): pass
- class B(object):
- pass
- a = A()
- a.a = a
- id_a = id(a)
- b = B()
- b.b = b
- gc.collect()
- del a
- del b
- expect_nonzero(gc.collect(), "finalizer")
- for obj in gc.garbage:
- if id(obj) == id_a:
- del obj.a
- break
- else:
- raise TestFailed, "didn't find obj in garbage (finalizer)"
- gc.garbage.remove(obj)
-
-def test_function():
- # Tricky: f -> d -> f, code should call d.clear() after the exec to
- # break the cycle.
- d = {}
- exec("def f(): pass\n", d)
- gc.collect()
- del d
- expect(gc.collect(), 2, "function")
-
-def test_frame():
- def f():
- frame = sys._getframe()
- gc.collect()
- f()
- expect(gc.collect(), 1, "frame")
-
-
-def test_saveall():
- # Verify that cyclic garbage like lists show up in gc.garbage if the
- # SAVEALL option is enabled.
-
- # First make sure we don't save away other stuff that just happens to
- # be waiting for collection.
- gc.collect()
- vereq(gc.garbage, []) # if this fails, someone else created immortal trash
-
- L = []
- L.append(L)
- id_L = id(L)
-
- debug = gc.get_debug()
- gc.set_debug(debug | gc.DEBUG_SAVEALL)
- del L
- gc.collect()
- gc.set_debug(debug)
-
- vereq(len(gc.garbage), 1)
- obj = gc.garbage.pop()
- vereq(id(obj), id_L)
-
-def test_del():
- # __del__ methods can trigger collection, make this to happen
- thresholds = gc.get_threshold()
- gc.enable()
- gc.set_threshold(1)
-
- class A:
- def __del__(self):
- dir(self)
- a = A()
- del a
-
- gc.disable()
- gc.set_threshold(*thresholds)
-
-def test_del_newclass():
- # __del__ methods can trigger collection, make this to happen
- thresholds = gc.get_threshold()
- gc.enable()
- gc.set_threshold(1)
-
- class A(object):
- def __del__(self):
- dir(self)
- a = A()
- del a
-
- gc.disable()
- gc.set_threshold(*thresholds)
-
-def test_get_count():
- gc.collect()
- expect(gc.get_count(), (0, 0, 0), "get_count()")
- a = dict()
- expect(gc.get_count(), (1, 0, 0), "get_count()")
-
-def test_collect_generations():
- gc.collect()
- a = dict()
- gc.collect(0)
- expect(gc.get_count(), (0, 1, 0), "collect(0)")
- gc.collect(1)
- expect(gc.get_count(), (0, 0, 1), "collect(1)")
- gc.collect(2)
- expect(gc.get_count(), (0, 0, 0), "collect(1)")
-
-class Ouch:
- n = 0
- def __del__(self):
- Ouch.n = Ouch.n + 1
- if Ouch.n % 17 == 0:
- gc.collect()
-
-def test_trashcan():
- # "trashcan" is a hack to prevent stack overflow when deallocating
- # very deeply nested tuples etc. It works in part by abusing the
- # type pointer and refcount fields, and that can yield horrible
- # problems when gc tries to traverse the structures.
- # If this test fails (as it does in 2.0, 2.1 and 2.2), it will
- # most likely die via segfault.
-
- # Note: In 2.3 the possibility for compiling without cyclic gc was
- # removed, and that in turn allows the trashcan mechanism to work
- # via much simpler means (e.g., it never abuses the type pointer or
- # refcount fields anymore). Since it's much less likely to cause a
- # problem now, the various constants in this expensive (we force a lot
- # of full collections) test are cut back from the 2.2 version.
- gc.enable()
- N = 150
- for count in range(2):
- t = []
- for i in range(N):
- t = [t, Ouch()]
- u = []
- for i in range(N):
- u = [u, Ouch()]
- v = {}
- for i in range(N):
- v = {1: v, 2: Ouch()}
- gc.disable()
-
-class Boom:
- def __getattr__(self, someattribute):
- del self.attr
- raise AttributeError
-
-def test_boom():
- a = Boom()
- b = Boom()
- a.attr = b
- b.attr = a
-
- gc.collect()
- garbagelen = len(gc.garbage)
- del a, b
- # a<->b are in a trash cycle now. Collection will invoke Boom.__getattr__
- # (to see whether a and b have __del__ methods), and __getattr__ deletes
- # the internal "attr" attributes as a side effect. That causes the
- # trash cycle to get reclaimed via refcounts falling to 0, thus mutating
- # the trash graph as a side effect of merely asking whether __del__
- # exists. This used to (before 2.3b1) crash Python. Now __getattr__
- # isn't called.
- expect(gc.collect(), 4, "boom")
- expect(len(gc.garbage), garbagelen, "boom")
-
-class Boom2:
- def __init__(self):
- self.x = 0
-
- def __getattr__(self, someattribute):
- self.x += 1
- if self.x > 1:
- del self.attr
- raise AttributeError
-
-def test_boom2():
- a = Boom2()
- b = Boom2()
- a.attr = b
- b.attr = a
-
- gc.collect()
- garbagelen = len(gc.garbage)
- del a, b
- # Much like test_boom(), except that __getattr__ doesn't break the
- # cycle until the second time gc checks for __del__. As of 2.3b1,
- # there isn't a second time, so this simply cleans up the trash cycle.
- # We expect a, b, a.__dict__ and b.__dict__ (4 objects) to get reclaimed
- # this way.
- expect(gc.collect(), 4, "boom2")
- expect(len(gc.garbage), garbagelen, "boom2")
-
-# boom__new and boom2_new are exactly like boom and boom2, except use
-# new-style classes.
-
-class Boom_New(object):
- def __getattr__(self, someattribute):
- del self.attr
- raise AttributeError
-
-def test_boom_new():
- a = Boom_New()
- b = Boom_New()
- a.attr = b
- b.attr = a
-
- gc.collect()
- garbagelen = len(gc.garbage)
- del a, b
- expect(gc.collect(), 4, "boom_new")
- expect(len(gc.garbage), garbagelen, "boom_new")
-
-class Boom2_New(object):
- def __init__(self):
- self.x = 0
-
- def __getattr__(self, someattribute):
- self.x += 1
- if self.x > 1:
- del self.attr
- raise AttributeError
-
-def test_boom2_new():
- a = Boom2_New()
- b = Boom2_New()
- a.attr = b
- b.attr = a
-
- gc.collect()
- garbagelen = len(gc.garbage)
- del a, b
- expect(gc.collect(), 4, "boom2_new")
- expect(len(gc.garbage), garbagelen, "boom2_new")
-
-def test_get_referents():
- alist = [1, 3, 5]
- got = gc.get_referents(alist)
- got.sort()
- expect(got, alist, "get_referents")
-
- atuple = tuple(alist)
- got = gc.get_referents(atuple)
- got.sort()
- expect(got, alist, "get_referents")
-
- adict = {1: 3, 5: 7}
- expected = [1, 3, 5, 7]
- got = gc.get_referents(adict)
- got.sort()
- expect(got, expected, "get_referents")
-
- got = gc.get_referents([1, 2], {3: 4}, (0, 0, 0))
- got.sort()
- expect(got, [0, 0] + range(5), "get_referents")
-
- expect(gc.get_referents(1, 'a', 4j), [], "get_referents")
+### Support code
+###############################################################################
# Bug 1055820 has several tests of longstanding bugs involving weakrefs and
# cyclic gc.
@@ -410,217 +31,556 @@ class GC_Detector(object):
# gc collects it.
self.wr = weakref.ref(C1055820(666), it_happened)
-def test_bug1055820b():
- # Corresponds to temp2b.py in the bug report.
-
- ouch = []
- def callback(ignored):
- ouch[:] = [wr() for wr in WRs]
-
- Cs = [C1055820(i) for i in range(2)]
- WRs = [weakref.ref(c, callback) for c in Cs]
- c = None
-
- gc.collect()
- expect(len(ouch), 0, "bug1055820b")
- # Make the two instances trash, and collect again. The bug was that
- # the callback materialized a strong reference to an instance, but gc
- # cleared the instance's dict anyway.
- Cs = None
- gc.collect()
- expect(len(ouch), 2, "bug1055820b") # else the callbacks didn't run
- for x in ouch:
- # If the callback resurrected one of these guys, the instance
- # would be damaged, with an empty __dict__.
- expect(x, None, "bug1055820b")
-
-def test_bug1055820c():
- # Corresponds to temp2c.py in the bug report. This is pretty elaborate.
-
- c0 = C1055820(0)
- # Move c0 into generation 2.
- gc.collect()
-
- c1 = C1055820(1)
- c1.keep_c0_alive = c0
- del c0.loop # now only c1 keeps c0 alive
-
- c2 = C1055820(2)
- c2wr = weakref.ref(c2) # no callback!
-
- ouch = []
- def callback(ignored):
- ouch[:] = [c2wr()]
-
- # The callback gets associated with a wr on an object in generation 2.
- c0wr = weakref.ref(c0, callback)
-
- c0 = c1 = c2 = None
-
- # What we've set up: c0, c1, and c2 are all trash now. c0 is in
- # generation 2. The only thing keeping it alive is that c1 points to it.
- # c1 and c2 are in generation 0, and are in self-loops. There's a global
- # weakref to c2 (c2wr), but that weakref has no callback. There's also
- # a global weakref to c0 (c0wr), and that does have a callback, and that
- # callback references c2 via c2wr().
- #
- # c0 has a wr with callback, which references c2wr
- # ^
- # |
- # | Generation 2 above dots
- #. . . . . . . .|. . . . . . . . . . . . . . . . . . . . . . . .
- # | Generation 0 below dots
- # |
- # |
- # ^->c1 ^->c2 has a wr but no callback
- # | | | |
- # <--v <--v
- #
- # So this is the nightmare: when generation 0 gets collected, we see that
- # c2 has a callback-free weakref, and c1 doesn't even have a weakref.
- # Collecting generation 0 doesn't see c0 at all, and c0 is the only object
- # that has a weakref with a callback. gc clears c1 and c2. Clearing c1
- # has the side effect of dropping the refcount on c0 to 0, so c0 goes
- # away (despite that it's in an older generation) and c0's wr callback
- # triggers. That in turn materializes a reference to c2 via c2wr(), but
- # c2 gets cleared anyway by gc.
-
- # We want to let gc happen "naturally", to preserve the distinction
- # between generations.
- junk = []
- i = 0
- detector = GC_Detector()
- while not detector.gc_happened:
- i += 1
- if i > 10000:
- raise TestFailed("gc didn't happen after 10000 iterations")
- expect(len(ouch), 0, "bug1055820c")
- junk.append([]) # this will eventually trigger gc
-
- expect(len(ouch), 1, "bug1055820c") # else the callback wasn't invoked
- for x in ouch:
- # If the callback resurrected c2, the instance would be damaged,
- # with an empty __dict__.
- expect(x, None, "bug1055820c")
-
-def test_bug1055820d():
- # Corresponds to temp2d.py in the bug report. This is very much like
- # test_bug1055820c, but uses a __del__ method instead of a weakref
- # callback to sneak in a resurrection of cyclic trash.
-
- ouch = []
- class D(C1055820):
- def __del__(self):
- ouch[:] = [c2wr()]
- d0 = D(0)
- # Move all the above into generation 2.
- gc.collect()
-
- c1 = C1055820(1)
- c1.keep_d0_alive = d0
- del d0.loop # now only c1 keeps d0 alive
-
- c2 = C1055820(2)
- c2wr = weakref.ref(c2) # no callback!
-
- d0 = c1 = c2 = None
-
- # What we've set up: d0, c1, and c2 are all trash now. d0 is in
- # generation 2. The only thing keeping it alive is that c1 points to it.
- # c1 and c2 are in generation 0, and are in self-loops. There's a global
- # weakref to c2 (c2wr), but that weakref has no callback. There are no
- # other weakrefs.
- #
- # d0 has a __del__ method that references c2wr
- # ^
- # |
- # | Generation 2 above dots
- #. . . . . . . .|. . . . . . . . . . . . . . . . . . . . . . . .
- # | Generation 0 below dots
- # |
- # |
- # ^->c1 ^->c2 has a wr but no callback
- # | | | |
- # <--v <--v
- #
- # So this is the nightmare: when generation 0 gets collected, we see that
- # c2 has a callback-free weakref, and c1 doesn't even have a weakref.
- # Collecting generation 0 doesn't see d0 at all. gc clears c1 and c2.
- # Clearing c1 has the side effect of dropping the refcount on d0 to 0, so
- # d0 goes away (despite that it's in an older generation) and d0's __del__
- # triggers. That in turn materializes a reference to c2 via c2wr(), but
- # c2 gets cleared anyway by gc.
-
- # We want to let gc happen "naturally", to preserve the distinction
- # between generations.
- detector = GC_Detector()
- junk = []
- i = 0
- while not detector.gc_happened:
- i += 1
- if i > 10000:
- raise TestFailed("gc didn't happen after 10000 iterations")
- expect(len(ouch), 0, "bug1055820d")
- junk.append([]) # this will eventually trigger gc
-
- expect(len(ouch), 1, "bug1055820d") # else __del__ wasn't invoked
- for x in ouch:
- # If __del__ resurrected c2, the instance would be damaged, with an
- # empty __dict__.
- expect(x, None, "bug1055820d")
-
-
-def test_all():
- gc.collect() # Delete 2nd generation garbage
- run_test("lists", test_list)
- run_test("dicts", test_dict)
- run_test("tuples", test_tuple)
- run_test("classes", test_class)
- run_test("new style classes", test_newstyleclass)
- run_test("instances", test_instance)
- run_test("new instances", test_newinstance)
- run_test("methods", test_method)
- run_test("functions", test_function)
- run_test("frames", test_frame)
- run_test("finalizers", test_finalizer)
- run_test("finalizers (new class)", test_finalizer_newclass)
- run_test("__del__", test_del)
- run_test("__del__ (new class)", test_del_newclass)
- run_test("get_count()", test_get_count)
- run_test("collect(n)", test_collect_generations)
- run_test("saveall", test_saveall)
- run_test("trashcan", test_trashcan)
- run_test("boom", test_boom)
- run_test("boom2", test_boom2)
- run_test("boom_new", test_boom_new)
- run_test("boom2_new", test_boom2_new)
- run_test("get_referents", test_get_referents)
- run_test("bug1055820b", test_bug1055820b)
-
- gc.enable()
- try:
- run_test("bug1055820c", test_bug1055820c)
- finally:
+### Tests
+###############################################################################
+
+class GCTests(unittest.TestCase):
+ def test_list(self):
+ l = []
+ l.append(l)
+ gc.collect()
+ del l
+ self.assertEqual(gc.collect(), 1)
+
+ def test_dict(self):
+ d = {}
+ d[1] = d
+ gc.collect()
+ del d
+ self.assertEqual(gc.collect(), 1)
+
+ def test_tuple(self):
+ # since tuples are immutable we close the loop with a list
+ l = []
+ t = (l,)
+ l.append(t)
+ gc.collect()
+ del t
+ del l
+ self.assertEqual(gc.collect(), 2)
+
+ def test_class(self):
+ class A:
+ pass
+ A.a = A
+ gc.collect()
+ del A
+ self.assertNotEqual(gc.collect(), 0)
+
+ def test_newstyleclass(self):
+ class A(object):
+ pass
+ gc.collect()
+ del A
+ self.assertNotEqual(gc.collect(), 0)
+
+ def test_instance(self):
+ class A:
+ pass
+ a = A()
+ a.a = a
+ gc.collect()
+ del a
+ self.assertNotEqual(gc.collect(), 0)
+
+ def test_newinstance(self):
+ class A(object):
+ pass
+ a = A()
+ a.a = a
+ gc.collect()
+ del a
+ self.assertNotEqual(gc.collect(), 0)
+ class B(list):
+ pass
+ class C(B, A):
+ pass
+ a = C()
+ a.a = a
+ gc.collect()
+ del a
+ self.assertNotEqual(gc.collect(), 0)
+ del B, C
+ self.assertNotEqual(gc.collect(), 0)
+ A.a = A()
+ del A
+ self.assertNotEqual(gc.collect(), 0)
+ self.assertEqual(gc.collect(), 0)
+
+ def test_method(self):
+ # Tricky: self.__init__ is a bound method, it references the instance.
+ class A:
+ def __init__(self):
+ self.init = self.__init__
+ a = A()
+ gc.collect()
+ del a
+ self.assertNotEqual(gc.collect(), 0)
+
+ def test_finalizer(self):
+ # A() is uncollectable if it is part of a cycle, make sure it shows up
+ # in gc.garbage.
+ class A:
+ def __del__(self): pass
+ class B:
+ pass
+ a = A()
+ a.a = a
+ id_a = id(a)
+ b = B()
+ b.b = b
+ gc.collect()
+ del a
+ del b
+ self.assertNotEqual(gc.collect(), 0)
+ for obj in gc.garbage:
+ if id(obj) == id_a:
+ del obj.a
+ break
+ else:
+ self.fail("didn't find obj in garbage (finalizer)")
+ gc.garbage.remove(obj)
+
+ def test_finalizer_newclass(self):
+ # A() is uncollectable if it is part of a cycle, make sure it shows up
+ # in gc.garbage.
+ class A(object):
+ def __del__(self): pass
+ class B(object):
+ pass
+ a = A()
+ a.a = a
+ id_a = id(a)
+ b = B()
+ b.b = b
+ gc.collect()
+ del a
+ del b
+ self.assertNotEqual(gc.collect(), 0)
+ for obj in gc.garbage:
+ if id(obj) == id_a:
+ del obj.a
+ break
+ else:
+ self.fail("didn't find obj in garbage (finalizer)")
+ gc.garbage.remove(obj)
+
+ def test_function(self):
+ # Tricky: f -> d -> f, code should call d.clear() after the exec to
+ # break the cycle.
+ d = {}
+ exec("def f(): pass\n", d)
+ gc.collect()
+ del d
+ self.assertEqual(gc.collect(), 2)
+
+ def test_frame(self):
+ def f():
+ frame = sys._getframe()
+ gc.collect()
+ f()
+ self.assertEqual(gc.collect(), 1)
+
+ def test_saveall(self):
+ # Verify that cyclic garbage like lists show up in gc.garbage if the
+ # SAVEALL option is enabled.
+
+ # First make sure we don't save away other stuff that just happens to
+ # be waiting for collection.
+ gc.collect()
+ # if this fails, someone else created immortal trash
+ self.assertEqual(gc.garbage, [])
+
+ L = []
+ L.append(L)
+ id_L = id(L)
+
+ debug = gc.get_debug()
+ gc.set_debug(debug | gc.DEBUG_SAVEALL)
+ del L
+ gc.collect()
+ gc.set_debug(debug)
+
+ self.assertEqual(len(gc.garbage), 1)
+ obj = gc.garbage.pop()
+ self.assertEqual(id(obj), id_L)
+
+ def test_del(self):
+ # __del__ methods can trigger collection, make this to happen
+ thresholds = gc.get_threshold()
+ gc.enable()
+ gc.set_threshold(1)
+
+ class A:
+ def __del__(self):
+ dir(self)
+ a = A()
+ del a
+
gc.disable()
+ gc.set_threshold(*thresholds)
- gc.enable()
- try:
- run_test("bug1055820d", test_bug1055820d)
- finally:
+ def test_del_newclass(self):
+ # __del__ methods can trigger collection, make this to happen
+ thresholds = gc.get_threshold()
+ gc.enable()
+ gc.set_threshold(1)
+
+ class A(object):
+ def __del__(self):
+ dir(self)
+ a = A()
+ del a
+
+ gc.disable()
+ gc.set_threshold(*thresholds)
+
+ def test_get_count(self):
+ gc.collect()
+ self.assertEqual(gc.get_count(), (0, 0, 0))
+ a = dict()
+ self.assertEqual(gc.get_count(), (1, 0, 0))
+
+ def test_collect_generations(self):
+ gc.collect()
+ a = dict()
+ gc.collect(0)
+ self.assertEqual(gc.get_count(), (0, 1, 0))
+ gc.collect(1)
+ self.assertEqual(gc.get_count(), (0, 0, 1))
+ gc.collect(2)
+ self.assertEqual(gc.get_count(), (0, 0, 0))
+
+ def test_trashcan(self):
+ class Ouch:
+ n = 0
+ def __del__(self):
+ Ouch.n = Ouch.n + 1
+ if Ouch.n % 17 == 0:
+ gc.collect()
+
+ # "trashcan" is a hack to prevent stack overflow when deallocating
+ # very deeply nested tuples etc. It works in part by abusing the
+ # type pointer and refcount fields, and that can yield horrible
+ # problems when gc tries to traverse the structures.
+ # If this test fails (as it does in 2.0, 2.1 and 2.2), it will
+ # most likely die via segfault.
+
+ # Note: In 2.3 the possibility for compiling without cyclic gc was
+ # removed, and that in turn allows the trashcan mechanism to work
+ # via much simpler means (e.g., it never abuses the type pointer or
+ # refcount fields anymore). Since it's much less likely to cause a
+ # problem now, the various constants in this expensive (we force a lot
+ # of full collections) test are cut back from the 2.2 version.
+ gc.enable()
+ N = 150
+ for count in range(2):
+ t = []
+ for i in range(N):
+ t = [t, Ouch()]
+ u = []
+ for i in range(N):
+ u = [u, Ouch()]
+ v = {}
+ for i in range(N):
+ v = {1: v, 2: Ouch()}
gc.disable()
-def test():
- if verbose:
- print("disabling automatic collection")
+ def test_boom(self):
+ class Boom:
+ def __getattr__(self, someattribute):
+ del self.attr
+ raise AttributeError
+
+ a = Boom()
+ b = Boom()
+ a.attr = b
+ b.attr = a
+
+ gc.collect()
+ garbagelen = len(gc.garbage)
+ del a, b
+ # a<->b are in a trash cycle now. Collection will invoke
+ # Boom.__getattr__ (to see whether a and b have __del__ methods), and
+ # __getattr__ deletes the internal "attr" attributes as a side effect.
+ # That causes the trash cycle to get reclaimed via refcounts falling to
+ # 0, thus mutating the trash graph as a side effect of merely asking
+ # whether __del__ exists. This used to (before 2.3b1) crash Python.
+ # Now __getattr__ isn't called.
+ self.assertEqual(gc.collect(), 4)
+ self.assertEqual(len(gc.garbage), garbagelen)
+
+ def test_boom2(self):
+ class Boom2:
+ def __init__(self):
+ self.x = 0
+
+ def __getattr__(self, someattribute):
+ self.x += 1
+ if self.x > 1:
+ del self.attr
+ raise AttributeError
+
+ a = Boom2()
+ b = Boom2()
+ a.attr = b
+ b.attr = a
+
+ gc.collect()
+ garbagelen = len(gc.garbage)
+ del a, b
+ # Much like test_boom(), except that __getattr__ doesn't break the
+ # cycle until the second time gc checks for __del__. As of 2.3b1,
+ # there isn't a second time, so this simply cleans up the trash cycle.
+ # We expect a, b, a.__dict__ and b.__dict__ (4 objects) to get
+ # reclaimed this way.
+ self.assertEqual(gc.collect(), 4)
+ self.assertEqual(len(gc.garbage), garbagelen)
+
+ def test_boom_new(self):
+ # boom__new and boom2_new are exactly like boom and boom2, except use
+ # new-style classes.
+
+ class Boom_New(object):
+ def __getattr__(self, someattribute):
+ del self.attr
+ raise AttributeError
+
+ a = Boom_New()
+ b = Boom_New()
+ a.attr = b
+ b.attr = a
+
+ gc.collect()
+ garbagelen = len(gc.garbage)
+ del a, b
+ self.assertEqual(gc.collect(), 4)
+ self.assertEqual(len(gc.garbage), garbagelen)
+
+ def test_boom2_new(self):
+ class Boom2_New(object):
+ def __init__(self):
+ self.x = 0
+
+ def __getattr__(self, someattribute):
+ self.x += 1
+ if self.x > 1:
+ del self.attr
+ raise AttributeError
+
+ a = Boom2_New()
+ b = Boom2_New()
+ a.attr = b
+ b.attr = a
+
+ gc.collect()
+ garbagelen = len(gc.garbage)
+ del a, b
+ self.assertEqual(gc.collect(), 4)
+ self.assertEqual(len(gc.garbage), garbagelen)
+
+ def test_get_referents(self):
+ alist = [1, 3, 5]
+ got = gc.get_referents(alist)
+ got.sort()
+ self.assertEqual(got, alist)
+
+ atuple = tuple(alist)
+ got = gc.get_referents(atuple)
+ got.sort()
+ self.assertEqual(got, alist)
+
+ adict = {1: 3, 5: 7}
+ expected = [1, 3, 5, 7]
+ got = gc.get_referents(adict)
+ got.sort()
+ self.assertEqual(got, expected)
+
+ got = gc.get_referents([1, 2], {3: 4}, (0, 0, 0))
+ got.sort()
+ self.assertEqual(got, [0, 0] + range(5))
+
+ self.assertEqual(gc.get_referents(1, 'a', 4j), [])
+
+ def test_bug1055820b(self):
+ # Corresponds to temp2b.py in the bug report.
+
+ ouch = []
+ def callback(ignored):
+ ouch[:] = [wr() for wr in WRs]
+
+ Cs = [C1055820(i) for i in range(2)]
+ WRs = [weakref.ref(c, callback) for c in Cs]
+ c = None
+
+ gc.collect()
+ self.assertEqual(len(ouch), 0)
+ # Make the two instances trash, and collect again. The bug was that
+ # the callback materialized a strong reference to an instance, but gc
+ # cleared the instance's dict anyway.
+ Cs = None
+ gc.collect()
+ self.assertEqual(len(ouch), 2) # else the callbacks didn't run
+ for x in ouch:
+ # If the callback resurrected one of these guys, the instance
+ # would be damaged, with an empty __dict__.
+ self.assertEqual(x, None)
+
+class GCTogglingTests(unittest.TestCase):
+ def setUp(self):
+ gc.enable()
+
+ def tearDown(self):
+ gc.disable()
+
+ def test_bug1055820c(self):
+ # Corresponds to temp2c.py in the bug report. This is pretty
+ # elaborate.
+
+ c0 = C1055820(0)
+ # Move c0 into generation 2.
+ gc.collect()
+
+ c1 = C1055820(1)
+ c1.keep_c0_alive = c0
+ del c0.loop # now only c1 keeps c0 alive
+
+ c2 = C1055820(2)
+ c2wr = weakref.ref(c2) # no callback!
+
+ ouch = []
+ def callback(ignored):
+ ouch[:] = [c2wr()]
+
+ # The callback gets associated with a wr on an object in generation 2.
+ c0wr = weakref.ref(c0, callback)
+
+ c0 = c1 = c2 = None
+
+ # What we've set up: c0, c1, and c2 are all trash now. c0 is in
+ # generation 2. The only thing keeping it alive is that c1 points to
+ # it. c1 and c2 are in generation 0, and are in self-loops. There's a
+ # global weakref to c2 (c2wr), but that weakref has no callback.
+ # There's also a global weakref to c0 (c0wr), and that does have a
+ # callback, and that callback references c2 via c2wr().
+ #
+ # c0 has a wr with callback, which references c2wr
+ # ^
+ # |
+ # | Generation 2 above dots
+ #. . . . . . . .|. . . . . . . . . . . . . . . . . . . . . . . .
+ # | Generation 0 below dots
+ # |
+ # |
+ # ^->c1 ^->c2 has a wr but no callback
+ # | | | |
+ # <--v <--v
+ #
+ # So this is the nightmare: when generation 0 gets collected, we see
+ # that c2 has a callback-free weakref, and c1 doesn't even have a
+ # weakref. Collecting generation 0 doesn't see c0 at all, and c0 is
+ # the only object that has a weakref with a callback. gc clears c1
+ # and c2. Clearing c1 has the side effect of dropping the refcount on
+ # c0 to 0, so c0 goes away (despite that it's in an older generation)
+ # and c0's wr callback triggers. That in turn materializes a reference
+ # to c2 via c2wr(), but c2 gets cleared anyway by gc.
+
+ # We want to let gc happen "naturally", to preserve the distinction
+ # between generations.
+ junk = []
+ i = 0
+ detector = GC_Detector()
+ while not detector.gc_happened:
+ i += 1
+ if i > 10000:
+ self.fail("gc didn't happen after 10000 iterations")
+ self.assertEqual(len(ouch), 0)
+ junk.append([]) # this will eventually trigger gc
+
+ self.assertEqual(len(ouch), 1) # else the callback wasn't invoked
+ for x in ouch:
+ # If the callback resurrected c2, the instance would be damaged,
+ # with an empty __dict__.
+ self.assertEqual(x, None)
+
+ def test_bug1055820d(self):
+ # Corresponds to temp2d.py in the bug report. This is very much like
+ # test_bug1055820c, but uses a __del__ method instead of a weakref
+ # callback to sneak in a resurrection of cyclic trash.
+
+ ouch = []
+ class D(C1055820):
+ def __del__(self):
+ ouch[:] = [c2wr()]
+
+ d0 = D(0)
+ # Move all the above into generation 2.
+ gc.collect()
+
+ c1 = C1055820(1)
+ c1.keep_d0_alive = d0
+ del d0.loop # now only c1 keeps d0 alive
+
+ c2 = C1055820(2)
+ c2wr = weakref.ref(c2) # no callback!
+
+ d0 = c1 = c2 = None
+
+ # What we've set up: d0, c1, and c2 are all trash now. d0 is in
+ # generation 2. The only thing keeping it alive is that c1 points to
+ # it. c1 and c2 are in generation 0, and are in self-loops. There's
+ # a global weakref to c2 (c2wr), but that weakref has no callback.
+ # There are no other weakrefs.
+ #
+ # d0 has a __del__ method that references c2wr
+ # ^
+ # |
+ # | Generation 2 above dots
+ #. . . . . . . .|. . . . . . . . . . . . . . . . . . . . . . . .
+ # | Generation 0 below dots
+ # |
+ # |
+ # ^->c1 ^->c2 has a wr but no callback
+ # | | | |
+ # <--v <--v
+ #
+ # So this is the nightmare: when generation 0 gets collected, we see
+ # that c2 has a callback-free weakref, and c1 doesn't even have a
+ # weakref. Collecting generation 0 doesn't see d0 at all. gc clears
+ # c1 and c2. Clearing c1 has the side effect of dropping the refcount
+ # on d0 to 0, so d0 goes away (despite that it's in an older
+ # generation) and d0's __del__ triggers. That in turn materializes
+ # a reference to c2 via c2wr(), but c2 gets cleared anyway by gc.
+
+ # We want to let gc happen "naturally", to preserve the distinction
+ # between generations.
+ detector = GC_Detector()
+ junk = []
+ i = 0
+ while not detector.gc_happened:
+ i += 1
+ if i > 10000:
+ self.fail("gc didn't happen after 10000 iterations")
+ self.assertEqual(len(ouch), 0)
+ junk.append([]) # this will eventually trigger gc
+
+ self.assertEqual(len(ouch), 1) # else __del__ wasn't invoked
+ for x in ouch:
+ # If __del__ resurrected c2, the instance would be damaged, with an
+ # empty __dict__.
+ self.assertEqual(x, None)
+
+def test_main():
enabled = gc.isenabled()
gc.disable()
- verify(not gc.isenabled())
+ assert not gc.isenabled()
debug = gc.get_debug()
gc.set_debug(debug & ~gc.DEBUG_LEAK) # this test is supposed to leak
try:
- test_all()
+ gc.collect() # Delete 2nd generation garbage
+ run_unittest(GCTests, GCTogglingTests)
finally:
gc.set_debug(debug)
# test gc.enable() even if GC is disabled by default
@@ -628,9 +588,9 @@ def test():
print("restoring automatic collection")
# make sure to always test gc.enable()
gc.enable()
- verify(gc.isenabled())
+ assert gc.isenabled()
if not enabled:
gc.disable()
-
-test()
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_getopt.py b/Lib/test/test_getopt.py
index 10bd256..a3a9940 100644
--- a/Lib/test/test_getopt.py
+++ b/Lib/test/test_getopt.py
@@ -1,180 +1,179 @@
# test_getopt.py
# David Goodger <dgoodger@bigfoot.com> 2000-08-19
+from test.test_support import verbose, run_doctest, run_unittest
+import unittest
+
import getopt
-from getopt import GetoptError
-from test.test_support import verify, verbose, run_doctest
import os
-def expectException(teststr, expected, failure=AssertionError):
- """Executes a statement passed in teststr, and raises an exception
- (failure) if the expected exception is *not* raised."""
- try:
- exec(teststr)
- except expected:
- pass
- else:
- raise failure
-
-old_posixly_correct = os.environ.get("POSIXLY_CORRECT")
-if old_posixly_correct is not None:
- del os.environ["POSIXLY_CORRECT"]
-
-if verbose:
- print('Running tests on getopt.short_has_arg')
-verify(getopt.short_has_arg('a', 'a:'))
-verify(not getopt.short_has_arg('a', 'a'))
-expectException("tmp = getopt.short_has_arg('a', 'b')", GetoptError)
-expectException("tmp = getopt.short_has_arg('a', '')", GetoptError)
-
-if verbose:
- print('Running tests on getopt.long_has_args')
-has_arg, option = getopt.long_has_args('abc', ['abc='])
-verify(has_arg)
-verify(option == 'abc')
-has_arg, option = getopt.long_has_args('abc', ['abc'])
-verify(not has_arg)
-verify(option == 'abc')
-has_arg, option = getopt.long_has_args('abc', ['abcd'])
-verify(not has_arg)
-verify(option == 'abcd')
-expectException("has_arg, option = getopt.long_has_args('abc', ['def'])",
- GetoptError)
-expectException("has_arg, option = getopt.long_has_args('abc', [])",
- GetoptError)
-expectException("has_arg, option = " + \
- "getopt.long_has_args('abc', ['abcd','abcde'])",
- GetoptError)
-
-if verbose:
- print('Running tests on getopt.do_shorts')
-opts, args = getopt.do_shorts([], 'a', 'a', [])
-verify(opts == [('-a', '')])
-verify(args == [])
-opts, args = getopt.do_shorts([], 'a1', 'a:', [])
-verify(opts == [('-a', '1')])
-verify(args == [])
-#opts, args = getopt.do_shorts([], 'a=1', 'a:', [])
-#verify(opts == [('-a', '1')])
-#verify(args == [])
-opts, args = getopt.do_shorts([], 'a', 'a:', ['1'])
-verify(opts == [('-a', '1')])
-verify(args == [])
-opts, args = getopt.do_shorts([], 'a', 'a:', ['1', '2'])
-verify(opts == [('-a', '1')])
-verify(args == ['2'])
-expectException("opts, args = getopt.do_shorts([], 'a1', 'a', [])",
- GetoptError)
-expectException("opts, args = getopt.do_shorts([], 'a', 'a:', [])",
- GetoptError)
-
-if verbose:
- print('Running tests on getopt.do_longs')
-opts, args = getopt.do_longs([], 'abc', ['abc'], [])
-verify(opts == [('--abc', '')])
-verify(args == [])
-opts, args = getopt.do_longs([], 'abc=1', ['abc='], [])
-verify(opts == [('--abc', '1')])
-verify(args == [])
-opts, args = getopt.do_longs([], 'abc=1', ['abcd='], [])
-verify(opts == [('--abcd', '1')])
-verify(args == [])
-opts, args = getopt.do_longs([], 'abc', ['ab', 'abc', 'abcd'], [])
-verify(opts == [('--abc', '')])
-verify(args == [])
-# Much like the preceding, except with a non-alpha character ("-") in
-# option name that precedes "="; failed in
-# http://sourceforge.net/bugs/?func=detailbug&bug_id=126863&group_id=5470
-opts, args = getopt.do_longs([], 'foo=42', ['foo-bar', 'foo=',], [])
-verify(opts == [('--foo', '42')])
-verify(args == [])
-expectException("opts, args = getopt.do_longs([], 'abc=1', ['abc'], [])",
- GetoptError)
-expectException("opts, args = getopt.do_longs([], 'abc', ['abc='], [])",
- GetoptError)
-
-# note: the empty string between '-a' and '--beta' is significant:
-# it simulates an empty string option argument ('-a ""') on the command line.
-cmdline = ['-a', '1', '-b', '--alpha=2', '--beta', '-a', '3', '-a', '',
- '--beta', 'arg1', 'arg2']
-
-if verbose:
- print('Running tests on getopt.getopt')
-opts, args = getopt.getopt(cmdline, 'a:b', ['alpha=', 'beta'])
-verify(opts == [('-a', '1'), ('-b', ''), ('--alpha', '2'), ('--beta', ''),
- ('-a', '3'), ('-a', ''), ('--beta', '')] )
-# Note ambiguity of ('-b', '') and ('-a', '') above. This must be
-# accounted for in the code that calls getopt().
-verify(args == ['arg1', 'arg2'])
-
-expectException(
- "opts, args = getopt.getopt(cmdline, 'a:b', ['alpha', 'beta'])",
- GetoptError)
-
-# Test handling of GNU style scanning mode.
-if verbose:
- print('Running tests on getopt.gnu_getopt')
-cmdline = ['-a', 'arg1', '-b', '1', '--alpha', '--beta=2']
-# GNU style
-opts, args = getopt.gnu_getopt(cmdline, 'ab:', ['alpha', 'beta='])
-verify(opts == [('-a', ''), ('-b', '1'), ('--alpha', ''), ('--beta', '2')])
-verify(args == ['arg1'])
-# Posix style via +
-opts, args = getopt.gnu_getopt(cmdline, '+ab:', ['alpha', 'beta='])
-verify(opts == [('-a', '')])
-verify(args == ['arg1', '-b', '1', '--alpha', '--beta=2'])
-# Posix style via POSIXLY_CORRECT
-os.environ["POSIXLY_CORRECT"] = "1"
-opts, args = getopt.gnu_getopt(cmdline, 'ab:', ['alpha', 'beta='])
-verify(opts == [('-a', '')])
-verify(args == ['arg1', '-b', '1', '--alpha', '--beta=2'])
-
-
-if old_posixly_correct is None:
- del os.environ["POSIXLY_CORRECT"]
-else:
- os.environ["POSIXLY_CORRECT"] = old_posixly_correct
-
-#------------------------------------------------------------------------------
-
-libreftest = """
-Examples from the Library Reference: Doc/lib/libgetopt.tex
-
-An example using only Unix style options:
-
-
->>> import getopt
->>> args = '-a -b -cfoo -d bar a1 a2'.split()
->>> args
-['-a', '-b', '-cfoo', '-d', 'bar', 'a1', 'a2']
->>> optlist, args = getopt.getopt(args, 'abc:d:')
->>> optlist
-[('-a', ''), ('-b', ''), ('-c', 'foo'), ('-d', 'bar')]
->>> args
-['a1', 'a2']
-
-Using long option names is equally easy:
-
-
->>> s = '--condition=foo --testing --output-file abc.def -x a1 a2'
->>> args = s.split()
->>> args
-['--condition=foo', '--testing', '--output-file', 'abc.def', '-x', 'a1', 'a2']
->>> optlist, args = getopt.getopt(args, 'x', [
-... 'condition=', 'output-file=', 'testing'])
->>> optlist
-[('--condition', 'foo'), ('--testing', ''), ('--output-file', 'abc.def'), ('-x', '')]
->>> args
-['a1', 'a2']
-
-"""
-
-__test__ = {'libreftest' : libreftest}
-
-import sys
-run_doctest(sys.modules[__name__], verbose)
-
-#------------------------------------------------------------------------------
-
-if verbose:
- print("Module getopt: tests completed successfully.")
+sentinel = object()
+
+class GetoptTests(unittest.TestCase):
+ def setUp(self):
+ self.old_posixly_correct = os.environ.get("POSIXLY_CORRECT", sentinel)
+ if self.old_posixly_correct is not sentinel:
+ del os.environ["POSIXLY_CORRECT"]
+
+ def tearDown(self):
+ if self.old_posixly_correct is sentinel:
+ os.environ.pop("POSIXLY_CORRECT", None)
+ else:
+ os.environ["POSIXLY_CORRECT"] = self.old_posixly_correct
+
+ def assertError(self, *args, **kwargs):
+ self.assertRaises(getopt.GetoptError, *args, **kwargs)
+
+ def test_short_has_arg(self):
+ self.failUnless(getopt.short_has_arg('a', 'a:'))
+ self.failIf(getopt.short_has_arg('a', 'a'))
+ self.assertError(getopt.short_has_arg, 'a', 'b')
+
+ def test_long_has_args(self):
+ has_arg, option = getopt.long_has_args('abc', ['abc='])
+ self.failUnless(has_arg)
+ self.assertEqual(option, 'abc')
+
+ has_arg, option = getopt.long_has_args('abc', ['abc'])
+ self.failIf(has_arg)
+ self.assertEqual(option, 'abc')
+
+ has_arg, option = getopt.long_has_args('abc', ['abcd'])
+ self.failIf(has_arg)
+ self.assertEqual(option, 'abcd')
+
+ self.assertError(getopt.long_has_args, 'abc', ['def'])
+ self.assertError(getopt.long_has_args, 'abc', [])
+ self.assertError(getopt.long_has_args, 'abc', ['abcd','abcde'])
+
+ def test_do_shorts(self):
+ opts, args = getopt.do_shorts([], 'a', 'a', [])
+ self.assertEqual(opts, [('-a', '')])
+ self.assertEqual(args, [])
+
+ opts, args = getopt.do_shorts([], 'a1', 'a:', [])
+ self.assertEqual(opts, [('-a', '1')])
+ self.assertEqual(args, [])
+
+ #opts, args = getopt.do_shorts([], 'a=1', 'a:', [])
+ #self.assertEqual(opts, [('-a', '1')])
+ #self.assertEqual(args, [])
+
+ opts, args = getopt.do_shorts([], 'a', 'a:', ['1'])
+ self.assertEqual(opts, [('-a', '1')])
+ self.assertEqual(args, [])
+
+ opts, args = getopt.do_shorts([], 'a', 'a:', ['1', '2'])
+ self.assertEqual(opts, [('-a', '1')])
+ self.assertEqual(args, ['2'])
+
+ self.assertError(getopt.do_shorts, [], 'a1', 'a', [])
+ self.assertError(getopt.do_shorts, [], 'a', 'a:', [])
+
+ def test_do_longs(self):
+ opts, args = getopt.do_longs([], 'abc', ['abc'], [])
+ self.assertEqual(opts, [('--abc', '')])
+ self.assertEqual(args, [])
+
+ opts, args = getopt.do_longs([], 'abc=1', ['abc='], [])
+ self.assertEqual(opts, [('--abc', '1')])
+ self.assertEqual(args, [])
+
+ opts, args = getopt.do_longs([], 'abc=1', ['abcd='], [])
+ self.assertEqual(opts, [('--abcd', '1')])
+ self.assertEqual(args, [])
+
+ opts, args = getopt.do_longs([], 'abc', ['ab', 'abc', 'abcd'], [])
+ self.assertEqual(opts, [('--abc', '')])
+ self.assertEqual(args, [])
+
+ # Much like the preceding, except with a non-alpha character ("-") in
+ # option name that precedes "="; failed in
+ # http://python.org/sf/126863
+ opts, args = getopt.do_longs([], 'foo=42', ['foo-bar', 'foo=',], [])
+ self.assertEqual(opts, [('--foo', '42')])
+ self.assertEqual(args, [])
+
+ self.assertError(getopt.do_longs, [], 'abc=1', ['abc'], [])
+ self.assertError(getopt.do_longs, [], 'abc', ['abc='], [])
+
+ def test_getopt(self):
+ # note: the empty string between '-a' and '--beta' is significant:
+ # it simulates an empty string option argument ('-a ""') on the
+ # command line.
+ cmdline = ['-a', '1', '-b', '--alpha=2', '--beta', '-a', '3', '-a',
+ '', '--beta', 'arg1', 'arg2']
+
+ opts, args = getopt.getopt(cmdline, 'a:b', ['alpha=', 'beta'])
+ self.assertEqual(opts, [('-a', '1'), ('-b', ''),
+ ('--alpha', '2'), ('--beta', ''),
+ ('-a', '3'), ('-a', ''), ('--beta', '')])
+ # Note ambiguity of ('-b', '') and ('-a', '') above. This must be
+ # accounted for in the code that calls getopt().
+ self.assertEqual(args, ['arg1', 'arg2'])
+
+ self.assertError(getopt.getopt, cmdline, 'a:b', ['alpha', 'beta'])
+
+ def test_gnu_getopt(self):
+ # Test handling of GNU style scanning mode.
+ cmdline = ['-a', 'arg1', '-b', '1', '--alpha', '--beta=2']
+
+ # GNU style
+ opts, args = getopt.gnu_getopt(cmdline, 'ab:', ['alpha', 'beta='])
+ self.assertEqual(args, ['arg1'])
+ self.assertEqual(opts, [('-a', ''), ('-b', '1'),
+ ('--alpha', ''), ('--beta', '2')])
+
+ # Posix style via +
+ opts, args = getopt.gnu_getopt(cmdline, '+ab:', ['alpha', 'beta='])
+ self.assertEqual(opts, [('-a', '')])
+ self.assertEqual(args, ['arg1', '-b', '1', '--alpha', '--beta=2'])
+
+ # Posix style via POSIXLY_CORRECT
+ os.environ["POSIXLY_CORRECT"] = "1"
+ opts, args = getopt.gnu_getopt(cmdline, 'ab:', ['alpha', 'beta='])
+ self.assertEqual(opts, [('-a', '')])
+ self.assertEqual(args, ['arg1', '-b', '1', '--alpha', '--beta=2'])
+
+ def test_libref_examples(self):
+ s = """
+ Examples from the Library Reference: Doc/lib/libgetopt.tex
+
+ An example using only Unix style options:
+
+
+ >>> import getopt
+ >>> args = '-a -b -cfoo -d bar a1 a2'.split()
+ >>> args
+ ['-a', '-b', '-cfoo', '-d', 'bar', 'a1', 'a2']
+ >>> optlist, args = getopt.getopt(args, 'abc:d:')
+ >>> optlist
+ [('-a', ''), ('-b', ''), ('-c', 'foo'), ('-d', 'bar')]
+ >>> args
+ ['a1', 'a2']
+
+ Using long option names is equally easy:
+
+
+ >>> s = '--condition=foo --testing --output-file abc.def -x a1 a2'
+ >>> args = s.split()
+ >>> args
+ ['--condition=foo', '--testing', '--output-file', 'abc.def', '-x', 'a1', 'a2']
+ >>> optlist, args = getopt.getopt(args, 'x', [
+ ... 'condition=', 'output-file=', 'testing'])
+ >>> optlist
+ [('--condition', 'foo'), ('--testing', ''), ('--output-file', 'abc.def'), ('-x', '')]
+ >>> args
+ ['a1', 'a2']
+ """
+
+ import new
+ m = new.module("libreftest", s)
+ run_doctest(m, verbose)
+
+
+def test_main():
+ run_unittest(GetoptTests)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_gettext.py b/Lib/test/test_gettext.py
index 76253de..ab6bc9a 100644
--- a/Lib/test/test_gettext.py
+++ b/Lib/test/test_gettext.py
@@ -4,7 +4,7 @@ import shutil
import gettext
import unittest
-from test.test_support import run_suite
+from test import test_support
# TODO:
@@ -336,19 +336,8 @@ class WeirdMetadataTest(GettextBaseTest):
'John Doe <jdoe@example.com>\nJane Foobar <jfoobar@example.com>')
-def suite():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(GettextTestCase1))
- suite.addTest(unittest.makeSuite(GettextTestCase2))
- suite.addTest(unittest.makeSuite(PluralFormsTestCase))
- suite.addTest(unittest.makeSuite(UnicodeTranslationsTest))
- suite.addTest(unittest.makeSuite(WeirdMetadataTest))
- return suite
-
-
def test_main():
- run_suite(suite())
-
+ test_support.run_unittest(__name__)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_glob.py b/Lib/test/test_glob.py
index 5ce09f9..f1993ab 100644
--- a/Lib/test/test_glob.py
+++ b/Lib/test/test_glob.py
@@ -52,6 +52,16 @@ class GlobTests(unittest.TestCase):
eq(self.glob('aab'), [self.norm('aab')])
eq(self.glob('zymurgy'), [])
+ # test return types are unicode, but only if os.listdir
+ # returns unicode filenames
+ uniset = set([unicode])
+ tmp = os.listdir(u'.')
+ if set(type(x) for x in tmp) == uniset:
+ u1 = glob.glob(u'*')
+ u2 = glob.glob(u'./*')
+ self.assertEquals(set(type(r) for r in u1), uniset)
+ self.assertEquals(set(type(r) for r in u2), uniset)
+
def test_glob_one_directory(self):
eq = self.assertSequencesEqual_noorder
eq(self.glob('a*'), map(self.norm, ['a', 'aab', 'aaa']))
diff --git a/Lib/test/test_grammar.py b/Lib/test/test_grammar.py
index 85be05b..96cf824 100644
--- a/Lib/test/test_grammar.py
+++ b/Lib/test/test_grammar.py
@@ -364,7 +364,7 @@ class GrammarTests(unittest.TestCase):
x = 1; pass; del x;
foo()
- ### small_stmt: expr_stmt | pass_stmt | del_stmt | flow_stmt | import_stmt | global_stmt | access_stmt
+ ### small_stmt: expr_stmt | pass_stmt | del_stmt | flow_stmt | import_stmt | global_stmt | access_stmt
# Tested below
def testExprStmt(self):
diff --git a/Lib/test/test_htmlparser.py b/Lib/test/test_htmlparser.py
index 54b90cd..229bbed 100755
--- a/Lib/test/test_htmlparser.py
+++ b/Lib/test/test_htmlparser.py
@@ -309,6 +309,11 @@ DOCTYPE html [
("endtag", "script"),
])
+ def test_entityrefs_in_attributes(self):
+ self._run_check("<html foo='&euro;&amp;&#97;&#x61;&unsupported;'>", [
+ ("starttag", "html", [("foo", u"\u20AC&aa&unsupported;")])
+ ])
+
def test_main():
test_support.run_unittest(HTMLParserTestCase)
diff --git a/Lib/test/test_httplib.py b/Lib/test/test_httplib.py
index 90a4e55..035f0b9 100644
--- a/Lib/test/test_httplib.py
+++ b/Lib/test/test_httplib.py
@@ -1,6 +1,7 @@
import httplib
import StringIO
import sys
+import socket
from unittest import TestCase
@@ -149,8 +150,52 @@ class OfflineTest(TestCase):
def test_responses(self):
self.assertEquals(httplib.responses[httplib.NOT_FOUND], "Not Found")
+PORT = 50003
+HOST = "localhost"
+
+class TimeoutTest(TestCase):
+
+ def setUp(self):
+ self.serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ global PORT
+ PORT = test_support.bind_port(self.serv, HOST, PORT)
+ self.serv.listen(5)
+
+ def tearDown(self):
+ self.serv.close()
+ self.serv = None
+
+ def testTimeoutAttribute(self):
+ '''This will prove that the timeout gets through
+ HTTPConnection and into the socket.
+ '''
+ # default
+ httpConn = httplib.HTTPConnection(HOST, PORT)
+ httpConn.connect()
+ self.assertTrue(httpConn.sock.gettimeout() is None)
+ httpConn.close()
+
+ # a value
+ httpConn = httplib.HTTPConnection(HOST, PORT, timeout=30)
+ httpConn.connect()
+ self.assertEqual(httpConn.sock.gettimeout(), 30)
+ httpConn.close()
+
+ # None, having other default
+ previous = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(30)
+ try:
+ httpConn = httplib.HTTPConnection(HOST, PORT, timeout=None)
+ httpConn.connect()
+ finally:
+ socket.setdefaulttimeout(previous)
+ self.assertEqual(httpConn.sock.gettimeout(), 30)
+ httpConn.close()
+
+
def test_main(verbose=None):
- test_support.run_unittest(HeaderTests, OfflineTest, BasicTest)
+ test_support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_import.py b/Lib/test/test_import.py
index a8f912f..87907c8 100644
--- a/Lib/test/test_import.py
+++ b/Lib/test/test_import.py
@@ -193,6 +193,16 @@ class ImportTest(unittest.TestCase):
if TESTFN in sys.modules:
del sys.modules[TESTFN]
+ def test_infinite_reload(self):
+ # Bug #742342 reports that Python segfaults (infinite recursion in C)
+ # when faced with self-recursive reload()ing.
+
+ sys.path.insert(0, os.path.dirname(__file__))
+ try:
+ import infinite_reload
+ finally:
+ sys.path.pop(0)
+
def test_import_name_binding(self):
# import x.y.z binds x in the current namespace
import test as x
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index d74b9d5..98c79c7 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -215,20 +215,20 @@ class TestBasicOps(unittest.TestCase):
self.assertEqual(list(izip_longest(*args, **{})), target)
target = [tuple((e is None and 'X' or e) for e in t) for t in target] # Replace None fills with 'X'
self.assertEqual(list(izip_longest(*args, **dict(fillvalue='X'))), target)
-
+
self.assertEqual(take(3,izip_longest('abcdef', count())), list(zip('abcdef', range(3)))) # take 3 from infinite input
self.assertEqual(list(izip_longest()), list(zip()))
self.assertEqual(list(izip_longest([])), list(zip([])))
self.assertEqual(list(izip_longest('abcdef')), list(zip('abcdef')))
-
+
self.assertEqual(list(izip_longest('abc', 'defg', **{})), map(None, 'abc', 'defg')) # empty keyword dict
self.assertRaises(TypeError, izip_longest, 3)
self.assertRaises(TypeError, izip_longest, range(3), 3)
for stmt in [
"izip_longest('abc', fv=1)",
- "izip_longest('abc', fillvalue=1, bogus_keyword=None)",
+ "izip_longest('abc', fillvalue=1, bogus_keyword=None)",
]:
try:
eval(stmt, globals(), locals())
@@ -236,7 +236,7 @@ class TestBasicOps(unittest.TestCase):
pass
else:
self.fail('Did not raise Type in: ' + stmt)
-
+
# Check tuple re-use (implementation detail)
self.assertEqual([tuple(list(pair)) for pair in izip_longest('abc', 'def')],
list(zip('abc', 'def')))
@@ -818,7 +818,7 @@ libreftest = """ Doctest for examples in the library reference: libitertools.tex
>>> amounts = [120.15, 764.05, 823.14]
>>> for checknum, amount in izip(count(1200), amounts):
... print('Check %d is for $%.2f' % (checknum, amount))
-...
+...
Check 1200 is for $120.15
Check 1201 is for $764.05
Check 1202 is for $823.14
@@ -826,7 +826,7 @@ Check 1202 is for $823.14
>>> import operator
>>> for cube in imap(operator.pow, xrange(1,4), repeat(3)):
... print(cube)
-...
+...
1
8
27
@@ -834,7 +834,7 @@ Check 1202 is for $823.14
>>> reportlines = ['EuroPython', 'Roster', '', 'alex', '', 'laura', '', 'martin', '', 'walter', '', 'samuele']
>>> for name in islice(reportlines, 3, None, 2):
... print(name.title())
-...
+...
Alex
Laura
Martin
@@ -846,7 +846,7 @@ Samuele
>>> di = sorted(sorted(d.items()), key=itemgetter(1))
>>> for k, g in groupby(di, itemgetter(1)):
... print(k, map(itemgetter(0), g))
-...
+...
1 ['a', 'c', 'e']
2 ['b', 'd', 'f']
3 ['g']
@@ -857,7 +857,7 @@ Samuele
>>> data = [ 1, 4,5,6, 10, 15,16,17,18, 22, 25,26,27,28]
>>> for k, g in groupby(enumerate(data), lambda (i,x):i-x):
... print(map(operator.itemgetter(1), g))
-...
+...
[1]
[4, 5, 6]
[10]
diff --git a/Lib/test/test_keywordonlyarg.py b/Lib/test/test_keywordonlyarg.py
index 2e1f8bd..fc67c98 100644
--- a/Lib/test/test_keywordonlyarg.py
+++ b/Lib/test/test_keywordonlyarg.py
@@ -71,7 +71,7 @@ class KeywordOnlyArgTestCase(unittest.TestCase):
fundef3 += "i%d, "%i
fundef3 += "lastarg):\n pass\n"
compile(fundef3, "<test>", "single")
-
+
def testSyntaxErrorForFunctionCall(self):
self.assertRaisesSyntaxError("f(p, k=1, p2)")
self.assertRaisesSyntaxError("f(p, *(1,2), k1=100)")
diff --git a/Lib/test/test_locale.py b/Lib/test/test_locale.py
index ba7d653..eba2cfd 100644
--- a/Lib/test/test_locale.py
+++ b/Lib/test/test_locale.py
@@ -7,7 +7,7 @@ if sys.platform == 'darwin':
oldlocale = locale.setlocale(locale.LC_NUMERIC)
if sys.platform.startswith("win"):
- tlocs = ("en",)
+ tlocs = ("En", "English")
else:
tlocs = ("en_US.UTF-8", "en_US.US-ASCII", "en_US")
diff --git a/Lib/test/test_logging.py b/Lib/test/test_logging.py
index 843440a..e8e4a8d 100644
--- a/Lib/test/test_logging.py
+++ b/Lib/test/test_logging.py
@@ -554,6 +554,8 @@ def test5():
except KeyError:
logging.exception("just testing")
os.remove(fn)
+ hdlr = logging.getLogger().handlers[0]
+ logging.getLogger().handlers.remove(hdlr)
finally:
logging._acquireLock()
try:
diff --git a/Lib/test/test_long_future.py b/Lib/test/test_long_future.py
index fc01001..36840b4 100644
--- a/Lib/test/test_long_future.py
+++ b/Lib/test/test_long_future.py
@@ -3,53 +3,53 @@ from __future__ import division
# test_long.py instead. In the meantime, it's too obscure to try to
# trick just part of test_long into using future division.
-from test.test_support import TestFailed, verify, verbose
-
-def test_true_division():
- if verbose:
- print("long true division")
- huge = 1 << 40000
- mhuge = -huge
- verify(huge / huge == 1.0)
- verify(mhuge / mhuge == 1.0)
- verify(huge / mhuge == -1.0)
- verify(mhuge / huge == -1.0)
- verify(1 / huge == 0.0)
- verify(1 / huge == 0.0)
- verify(1 / mhuge == 0.0)
- verify(1 / mhuge == 0.0)
- verify((666 * huge + (huge >> 1)) / huge == 666.5)
- verify((666 * mhuge + (mhuge >> 1)) / mhuge == 666.5)
- verify((666 * huge + (huge >> 1)) / mhuge == -666.5)
- verify((666 * mhuge + (mhuge >> 1)) / huge == -666.5)
- verify(huge / (huge << 1) == 0.5)
- verify((1000000 * huge) / huge == 1000000)
-
- namespace = {'huge': huge, 'mhuge': mhuge}
-
- for overflow in ["float(huge)", "float(mhuge)",
- "huge / 1", "huge / 2", "huge / -1", "huge / -2",
- "mhuge / 100", "mhuge / 100"]:
- try:
- eval(overflow, namespace)
- except OverflowError:
- pass
- else:
- raise TestFailed("expected OverflowError from %r" % overflow)
-
- for underflow in ["1 / huge", "2 / huge", "-1 / huge", "-2 / huge",
- "100 / mhuge", "100 / mhuge"]:
- result = eval(underflow, namespace)
- if result != 0.0:
- raise TestFailed("expected underflow to 0 from %r" % underflow)
-
- for zero in ["huge / 0", "huge / 0",
- "mhuge / 0", "mhuge / 0"]:
- try:
- eval(zero, namespace)
- except ZeroDivisionError:
- pass
- else:
- raise TestFailed("expected ZeroDivisionError from %r" % zero)
-
-test_true_division()
+import unittest
+from test.test_support import run_unittest
+
+class TrueDivisionTests(unittest.TestCase):
+ def test(self):
+ huge = 1 << 40000
+ mhuge = -huge
+ self.assertEqual(huge / huge, 1.0)
+ self.assertEqual(mhuge / mhuge, 1.0)
+ self.assertEqual(huge / mhuge, -1.0)
+ self.assertEqual(mhuge / huge, -1.0)
+ self.assertEqual(1 / huge, 0.0)
+ self.assertEqual(1 / huge, 0.0)
+ self.assertEqual(1 / mhuge, 0.0)
+ self.assertEqual(1 / mhuge, 0.0)
+ self.assertEqual((666 * huge + (huge >> 1)) / huge, 666.5)
+ self.assertEqual((666 * mhuge + (mhuge >> 1)) / mhuge, 666.5)
+ self.assertEqual((666 * huge + (huge >> 1)) / mhuge, -666.5)
+ self.assertEqual((666 * mhuge + (mhuge >> 1)) / huge, -666.5)
+ self.assertEqual(huge / (huge << 1), 0.5)
+ self.assertEqual((1000000 * huge) / huge, 1000000)
+
+ namespace = {'huge': huge, 'mhuge': mhuge}
+
+ for overflow in ["float(huge)", "float(mhuge)",
+ "huge / 1", "huge / 2", "huge / -1", "huge / -2",
+ "mhuge / 100", "mhuge / 200"]:
+ # XXX(cwinter) this test doesn't pass when converted to
+ # use assertRaises.
+ try:
+ eval(overflow, namespace)
+ self.fail("expected OverflowError from %r" % overflow)
+ except OverflowError:
+ pass
+
+ for underflow in ["1 / huge", "2 / huge", "-1 / huge", "-2 / huge",
+ "100 / mhuge", "200 / mhuge"]:
+ result = eval(underflow, namespace)
+ self.assertEqual(result, 0.0,
+ "expected underflow to 0 from %r" % underflow)
+
+ for zero in ["huge / 0", "mhuge / 0"]:
+ self.assertRaises(ZeroDivisionError, eval, zero, namespace)
+
+
+def test_main():
+ run_unittest(TrueDivisionTests)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_macpath.py b/Lib/test/test_macpath.py
index 3a3cf04..2449b0a 100644
--- a/Lib/test/test_macpath.py
+++ b/Lib/test/test_macpath.py
@@ -48,7 +48,7 @@ class MacPathTestCase(unittest.TestCase):
splitext = macpath.splitext
self.assertEquals(splitext(":foo.ext"), (':foo', '.ext'))
self.assertEquals(splitext("foo:foo.ext"), ('foo:foo', '.ext'))
- self.assertEquals(splitext(".ext"), ('', '.ext'))
+ self.assertEquals(splitext(".ext"), ('.ext', ''))
self.assertEquals(splitext("foo.ext:foo"), ('foo.ext:foo', ''))
self.assertEquals(splitext(":foo.ext:"), (':foo.ext:', ''))
self.assertEquals(splitext(""), ('', ''))
diff --git a/Lib/test/test_mailbox.py b/Lib/test/test_mailbox.py
index 8650cef..1972ca3 100644
--- a/Lib/test/test_mailbox.py
+++ b/Lib/test/test_mailbox.py
@@ -54,6 +54,7 @@ class TestMailbox(TestBase):
def setUp(self):
self._path = test_support.TESTFN
+ self._delete_recursively(self._path)
self._box = self._factory(self._path)
def tearDown(self):
@@ -686,7 +687,7 @@ class _TestMboxMMDF(TestMailbox):
self._box.close()
self._delete_recursively(self._path)
for lock_remnant in glob.glob(self._path + '.*'):
- os.remove(lock_remnant)
+ test_support.unlink(lock_remnant)
def test_add_from_string(self):
# Add a string starting with 'From ' to the mailbox
@@ -909,7 +910,7 @@ class TestBabyl(TestMailbox):
self._box.close()
self._delete_recursively(self._path)
for lock_remnant in glob.glob(self._path + '.*'):
- os.remove(lock_remnant)
+ test_support.unlink(lock_remnant)
def test_labels(self):
# Get labels from the mailbox
diff --git a/Lib/test/test_metaclass.py b/Lib/test/test_metaclass.py
index df81079..9126cf6 100644
--- a/Lib/test/test_metaclass.py
+++ b/Lib/test/test_metaclass.py
@@ -63,6 +63,8 @@ Use a metaclass with a __prepare__ static method.
... def __new__(cls, name, bases, namespace, **kwds):
... print("New called:", kwds)
... return type.__new__(cls, name, bases, namespace)
+ ... def __init__(cls, *args, **kwds):
+ ... pass
...
>>> class C(metaclass=M):
... def meth(self): print("Hello")
diff --git a/Lib/test/test_minidom.py b/Lib/test/test_minidom.py
index 6c4dd94..5f95365 100644
--- a/Lib/test/test_minidom.py
+++ b/Lib/test/test_minidom.py
@@ -5,7 +5,8 @@ import sys
import pickle
import traceback
from StringIO import StringIO
-from test.test_support import verbose
+from test.test_support import verbose, run_unittest, TestSkipped
+import unittest
import xml.dom
import xml.dom.minidom
@@ -22,680 +23,9 @@ else:
tstfile = os.path.join(os.path.dirname(base), "test"+os.extsep+"xml")
del base
-def confirm(test, testname = "Test"):
- if not test:
- print("Failed " + testname)
- raise Exception
-
-def testParseFromFile():
- dom = parse(StringIO(open(tstfile).read()))
- dom.unlink()
- confirm(isinstance(dom,Document))
-
-def testGetElementsByTagName():
- dom = parse(tstfile)
- confirm(dom.getElementsByTagName("LI") == \
- dom.documentElement.getElementsByTagName("LI"))
- dom.unlink()
-
-def testInsertBefore():
- dom = parseString("<doc><foo/></doc>")
- root = dom.documentElement
- elem = root.childNodes[0]
- nelem = dom.createElement("element")
- root.insertBefore(nelem, elem)
- confirm(len(root.childNodes) == 2
- and root.childNodes.length == 2
- and root.childNodes[0] is nelem
- and root.childNodes.item(0) is nelem
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.firstChild is nelem
- and root.lastChild is elem
- and root.toxml() == "<doc><element/><foo/></doc>"
- , "testInsertBefore -- node properly placed in tree")
- nelem = dom.createElement("element")
- root.insertBefore(nelem, None)
- confirm(len(root.childNodes) == 3
- and root.childNodes.length == 3
- and root.childNodes[1] is elem
- and root.childNodes.item(1) is elem
- and root.childNodes[2] is nelem
- and root.childNodes.item(2) is nelem
- and root.lastChild is nelem
- and nelem.previousSibling is elem
- and root.toxml() == "<doc><element/><foo/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
- nelem2 = dom.createElement("bar")
- root.insertBefore(nelem2, nelem)
- confirm(len(root.childNodes) == 4
- and root.childNodes.length == 4
- and root.childNodes[2] is nelem2
- and root.childNodes.item(2) is nelem2
- and root.childNodes[3] is nelem
- and root.childNodes.item(3) is nelem
- and nelem2.nextSibling is nelem
- and nelem.previousSibling is nelem2
- and root.toxml() == "<doc><element/><foo/><bar/><element/></doc>"
- , "testInsertBefore -- node properly placed in tree")
- dom.unlink()
-
-def _create_fragment_test_nodes():
- dom = parseString("<doc/>")
- orig = dom.createTextNode("original")
- c1 = dom.createTextNode("foo")
- c2 = dom.createTextNode("bar")
- c3 = dom.createTextNode("bat")
- dom.documentElement.appendChild(orig)
- frag = dom.createDocumentFragment()
- frag.appendChild(c1)
- frag.appendChild(c2)
- frag.appendChild(c3)
- return dom, orig, c1, c2, c3, frag
-
-def testInsertBeforeFragment():
- dom, orig, c1, c2, c3, frag = _create_fragment_test_nodes()
- dom.documentElement.insertBefore(frag, None)
- confirm(tuple(dom.documentElement.childNodes) == (orig, c1, c2, c3),
- "insertBefore(<fragment>, None)")
- frag.unlink()
- dom.unlink()
- #
- dom, orig, c1, c2, c3, frag = _create_fragment_test_nodes()
- dom.documentElement.insertBefore(frag, orig)
- confirm(tuple(dom.documentElement.childNodes) == (c1, c2, c3, orig),
- "insertBefore(<fragment>, orig)")
- frag.unlink()
- dom.unlink()
-
-def testAppendChild():
- dom = parse(tstfile)
- dom.documentElement.appendChild(dom.createComment(u"Hello"))
- confirm(dom.documentElement.childNodes[-1].nodeName == "#comment")
- confirm(dom.documentElement.childNodes[-1].data == "Hello")
- dom.unlink()
-
-def testAppendChildFragment():
- dom, orig, c1, c2, c3, frag = _create_fragment_test_nodes()
- dom.documentElement.appendChild(frag)
- confirm(tuple(dom.documentElement.childNodes) == (orig, c1, c2, c3),
- "appendChild(<fragment>)")
- frag.unlink()
- dom.unlink()
-
-def testReplaceChildFragment():
- dom, orig, c1, c2, c3, frag = _create_fragment_test_nodes()
- dom.documentElement.replaceChild(frag, orig)
- orig.unlink()
- confirm(tuple(dom.documentElement.childNodes) == (c1, c2, c3),
- "replaceChild(<fragment>)")
- frag.unlink()
- dom.unlink()
-
-def testLegalChildren():
- dom = Document()
- elem = dom.createElement('element')
- text = dom.createTextNode('text')
-
- try: dom.appendChild(text)
- except xml.dom.HierarchyRequestErr: pass
- else:
- print("dom.appendChild didn't raise HierarchyRequestErr")
-
- dom.appendChild(elem)
- try: dom.insertBefore(text, elem)
- except xml.dom.HierarchyRequestErr: pass
- else:
- print("dom.appendChild didn't raise HierarchyRequestErr")
-
- try: dom.replaceChild(text, elem)
- except xml.dom.HierarchyRequestErr: pass
- else:
- print("dom.appendChild didn't raise HierarchyRequestErr")
-
- nodemap = elem.attributes
- try: nodemap.setNamedItem(text)
- except xml.dom.HierarchyRequestErr: pass
- else:
- print("NamedNodeMap.setNamedItem didn't raise HierarchyRequestErr")
-
- try: nodemap.setNamedItemNS(text)
- except xml.dom.HierarchyRequestErr: pass
- else:
- print("NamedNodeMap.setNamedItemNS didn't raise HierarchyRequestErr")
-
- elem.appendChild(text)
- dom.unlink()
-
-def testNamedNodeMapSetItem():
- dom = Document()
- elem = dom.createElement('element')
- attrs = elem.attributes
- attrs["foo"] = "bar"
- a = attrs.item(0)
- confirm(a.ownerDocument is dom,
- "NamedNodeMap.__setitem__() sets ownerDocument")
- confirm(a.ownerElement is elem,
- "NamedNodeMap.__setitem__() sets ownerElement")
- confirm(a.value == "bar",
- "NamedNodeMap.__setitem__() sets value")
- confirm(a.nodeValue == "bar",
- "NamedNodeMap.__setitem__() sets nodeValue")
- elem.unlink()
- dom.unlink()
-
-def testNonZero():
- dom = parse(tstfile)
- confirm(dom)# should not be zero
- dom.appendChild(dom.createComment("foo"))
- confirm(not dom.childNodes[-1].childNodes)
- dom.unlink()
-
-def testUnlink():
- dom = parse(tstfile)
- dom.unlink()
-
-def testElement():
- dom = Document()
- dom.appendChild(dom.createElement("abc"))
- confirm(dom.documentElement)
- dom.unlink()
-
-def testAAA():
- dom = parseString("<abc/>")
- el = dom.documentElement
- el.setAttribute("spam", "jam2")
- confirm(el.toxml() == '<abc spam="jam2"/>', "testAAA")
- a = el.getAttributeNode("spam")
- confirm(a.ownerDocument is dom,
- "setAttribute() sets ownerDocument")
- confirm(a.ownerElement is dom.documentElement,
- "setAttribute() sets ownerElement")
- dom.unlink()
-
-def testAAB():
- dom = parseString("<abc/>")
- el = dom.documentElement
- el.setAttribute("spam", "jam")
- el.setAttribute("spam", "jam2")
- confirm(el.toxml() == '<abc spam="jam2"/>', "testAAB")
- dom.unlink()
-
-def testAddAttr():
- dom = Document()
- child = dom.appendChild(dom.createElement("abc"))
-
- child.setAttribute("def", "ghi")
- confirm(child.getAttribute("def") == "ghi")
- confirm(child.attributes["def"].value == "ghi")
-
- child.setAttribute("jkl", "mno")
- confirm(child.getAttribute("jkl") == "mno")
- confirm(child.attributes["jkl"].value == "mno")
-
- confirm(len(child.attributes) == 2)
-
- child.setAttribute("def", "newval")
- confirm(child.getAttribute("def") == "newval")
- confirm(child.attributes["def"].value == "newval")
-
- confirm(len(child.attributes) == 2)
- dom.unlink()
-
-def testDeleteAttr():
- dom = Document()
- child = dom.appendChild(dom.createElement("abc"))
-
- confirm(len(child.attributes) == 0)
- child.setAttribute("def", "ghi")
- confirm(len(child.attributes) == 1)
- del child.attributes["def"]
- confirm(len(child.attributes) == 0)
- dom.unlink()
-
-def testRemoveAttr():
- dom = Document()
- child = dom.appendChild(dom.createElement("abc"))
-
- child.setAttribute("def", "ghi")
- confirm(len(child.attributes) == 1)
- child.removeAttribute("def")
- confirm(len(child.attributes) == 0)
-
- dom.unlink()
-
-def testRemoveAttrNS():
- dom = Document()
- child = dom.appendChild(
- dom.createElementNS("http://www.python.org", "python:abc"))
- child.setAttributeNS("http://www.w3.org", "xmlns:python",
- "http://www.python.org")
- child.setAttributeNS("http://www.python.org", "python:abcattr", "foo")
- confirm(len(child.attributes) == 2)
- child.removeAttributeNS("http://www.python.org", "abcattr")
- confirm(len(child.attributes) == 1)
-
- dom.unlink()
-
-def testRemoveAttributeNode():
- dom = Document()
- child = dom.appendChild(dom.createElement("foo"))
- child.setAttribute("spam", "jam")
- confirm(len(child.attributes) == 1)
- node = child.getAttributeNode("spam")
- child.removeAttributeNode(node)
- confirm(len(child.attributes) == 0
- and child.getAttributeNode("spam") is None)
-
- dom.unlink()
-
-def testChangeAttr():
- dom = parseString("<abc/>")
- el = dom.documentElement
- el.setAttribute("spam", "jam")
- confirm(len(el.attributes) == 1)
- el.setAttribute("spam", "bam")
- # Set this attribute to be an ID and make sure that doesn't change
- # when changing the value:
- el.setIdAttribute("spam")
- confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "bam"
- and el.attributes["spam"].nodeValue == "bam"
- and el.getAttribute("spam") == "bam"
- and el.getAttributeNode("spam").isId)
- el.attributes["spam"] = "ham"
- confirm(len(el.attributes) == 1
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam"].isId)
- el.setAttribute("spam2", "bam")
- confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam"
- and el.attributes["spam2"].nodeValue == "bam"
- and el.getAttribute("spam2") == "bam")
- el.attributes["spam2"] = "bam2"
- confirm(len(el.attributes) == 2
- and el.attributes["spam"].value == "ham"
- and el.attributes["spam"].nodeValue == "ham"
- and el.getAttribute("spam") == "ham"
- and el.attributes["spam2"].value == "bam2"
- and el.attributes["spam2"].nodeValue == "bam2"
- and el.getAttribute("spam2") == "bam2")
- dom.unlink()
-
-def testGetAttrList():
- pass
-
-def testGetAttrValues(): pass
-
-def testGetAttrLength(): pass
-
-def testGetAttribute(): pass
-
-def testGetAttributeNS(): pass
-
-def testGetAttributeNode(): pass
-
-def testGetElementsByTagNameNS():
- d="""<foo xmlns:minidom='http://pyxml.sf.net/minidom'>
- <minidom:myelem/>
- </foo>"""
- dom = parseString(d)
- elems = dom.getElementsByTagNameNS("http://pyxml.sf.net/minidom", "myelem")
- confirm(len(elems) == 1
- and elems[0].namespaceURI == "http://pyxml.sf.net/minidom"
- and elems[0].localName == "myelem"
- and elems[0].prefix == "minidom"
- and elems[0].tagName == "minidom:myelem"
- and elems[0].nodeName == "minidom:myelem")
- dom.unlink()
-
-def get_empty_nodelist_from_elements_by_tagName_ns_helper(doc, nsuri, lname):
- nodelist = doc.getElementsByTagNameNS(nsuri, lname)
- confirm(len(nodelist) == 0)
-
-def testGetEmptyNodeListFromElementsByTagNameNS():
- doc = parseString('<doc/>')
- get_empty_nodelist_from_elements_by_tagName_ns_helper(
- doc, 'http://xml.python.org/namespaces/a', 'localname')
- get_empty_nodelist_from_elements_by_tagName_ns_helper(
- doc, '*', 'splat')
- get_empty_nodelist_from_elements_by_tagName_ns_helper(
- doc, 'http://xml.python.org/namespaces/a', '*')
-
- doc = parseString('<doc xmlns="http://xml.python.org/splat"><e/></doc>')
- get_empty_nodelist_from_elements_by_tagName_ns_helper(
- doc, "http://xml.python.org/splat", "not-there")
- get_empty_nodelist_from_elements_by_tagName_ns_helper(
- doc, "*", "not-there")
- get_empty_nodelist_from_elements_by_tagName_ns_helper(
- doc, "http://somewhere.else.net/not-there", "e")
-
-def testElementReprAndStr():
- dom = Document()
- el = dom.appendChild(dom.createElement("abc"))
- string1 = repr(el)
- string2 = str(el)
- confirm(string1 == string2)
- dom.unlink()
-
-# commented out until Fredrick's fix is checked in
-def _testElementReprAndStrUnicode():
- dom = Document()
- el = dom.appendChild(dom.createElement(u"abc"))
- string1 = repr(el)
- string2 = str(el)
- confirm(string1 == string2)
- dom.unlink()
-
-# commented out until Fredrick's fix is checked in
-def _testElementReprAndStrUnicodeNS():
- dom = Document()
- el = dom.appendChild(
- dom.createElementNS(u"http://www.slashdot.org", u"slash:abc"))
- string1 = repr(el)
- string2 = str(el)
- confirm(string1 == string2)
- confirm(string1.find("slash:abc") != -1)
- dom.unlink()
-
-def testAttributeRepr():
- dom = Document()
- el = dom.appendChild(dom.createElement(u"abc"))
- node = el.setAttribute("abc", "def")
- confirm(str(node) == repr(node))
- dom.unlink()
-
-def testTextNodeRepr(): pass
-
-def testWriteXML():
- str = '<?xml version="1.0" ?><a b="c"/>'
- dom = parseString(str)
- domstr = dom.toxml()
- dom.unlink()
- confirm(str == domstr)
-
-def testAltNewline():
- str = '<?xml version="1.0" ?>\n<a b="c"/>\n'
- dom = parseString(str)
- domstr = dom.toprettyxml(newl="\r\n")
- dom.unlink()
- confirm(domstr == str.replace("\n", "\r\n"))
-
-def testProcessingInstruction():
- dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
- pi = dom.documentElement.firstChild
- confirm(pi.target == "mypi"
- and pi.data == "data \t\n "
- and pi.nodeName == "mypi"
- and pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE
- and pi.attributes is None
- and not pi.hasChildNodes()
- and len(pi.childNodes) == 0
- and pi.firstChild is None
- and pi.lastChild is None
- and pi.localName is None
- and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE)
-
-def testProcessingInstructionRepr(): pass
-
-def testTextRepr(): pass
-
-def testWriteText(): pass
-
-def testDocumentElement(): pass
-
-def testTooManyDocumentElements():
- doc = parseString("<doc/>")
- elem = doc.createElement("extra")
- try:
- doc.appendChild(elem)
- except xml.dom.HierarchyRequestErr:
- pass
- else:
- print("Failed to catch expected exception when" \
- " adding extra document element.")
- elem.unlink()
- doc.unlink()
-
-def testCreateElementNS(): pass
-
-def testCreateAttributeNS(): pass
-
-def testParse(): pass
-
-def testParseString(): pass
-
-def testComment(): pass
-
-def testAttrListItem(): pass
-
-def testAttrListItems(): pass
-
-def testAttrListItemNS(): pass
-
-def testAttrListKeys(): pass
-
-def testAttrListKeysNS(): pass
-
-def testRemoveNamedItem():
- doc = parseString("<doc a=''/>")
- e = doc.documentElement
- attrs = e.attributes
- a1 = e.getAttributeNode("a")
- a2 = attrs.removeNamedItem("a")
- confirm(a1.isSameNode(a2))
- try:
- attrs.removeNamedItem("a")
- except xml.dom.NotFoundErr:
- pass
-
-def testRemoveNamedItemNS():
- doc = parseString("<doc xmlns:a='http://xml.python.org/' a:b=''/>")
- e = doc.documentElement
- attrs = e.attributes
- a1 = e.getAttributeNodeNS("http://xml.python.org/", "b")
- a2 = attrs.removeNamedItemNS("http://xml.python.org/", "b")
- confirm(a1.isSameNode(a2))
- try:
- attrs.removeNamedItemNS("http://xml.python.org/", "b")
- except xml.dom.NotFoundErr:
- pass
-
-def testAttrListValues(): pass
-
-def testAttrListLength(): pass
-
-def testAttrList__getitem__(): pass
-
-def testAttrList__setitem__(): pass
-
-def testSetAttrValueandNodeValue(): pass
-
-def testParseElement(): pass
-
-def testParseAttributes(): pass
-
-def testParseElementNamespaces(): pass
-
-def testParseAttributeNamespaces(): pass
-
-def testParseProcessingInstructions(): pass
-
-def testChildNodes(): pass
-
-def testFirstChild(): pass
-
-def testHasChildNodes(): pass
-
-def testCloneElementShallow():
- dom, clone = _setupCloneElement(0)
- confirm(len(clone.childNodes) == 0
- and clone.childNodes.length == 0
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"/>'
- , "testCloneElementShallow")
- dom.unlink()
-
-def testCloneElementDeep():
- dom, clone = _setupCloneElement(1)
- confirm(len(clone.childNodes) == 1
- and clone.childNodes.length == 1
- and clone.parentNode is None
- and clone.toxml() == '<doc attr="value"><foo/></doc>'
- , "testCloneElementDeep")
- dom.unlink()
-
-def _setupCloneElement(deep):
- dom = parseString("<doc attr='value'><foo/></doc>")
- root = dom.documentElement
- clone = root.cloneNode(deep)
- _testCloneElementCopiesAttributes(
- root, clone, "testCloneElement" + (deep and "Deep" or "Shallow"))
- # mutilate the original so shared data is detected
- root.tagName = root.nodeName = "MODIFIED"
- root.setAttribute("attr", "NEW VALUE")
- root.setAttribute("added", "VALUE")
- return dom, clone
-
-def _testCloneElementCopiesAttributes(e1, e2, test):
- attrs1 = e1.attributes
- attrs2 = e2.attributes
- keys1 = sorted(attrs1.keys())
- keys2 = sorted(attrs2.keys())
- confirm(keys1 == keys2, "clone of element has same attribute keys")
- for i in range(len(keys1)):
- a1 = attrs1.item(i)
- a2 = attrs2.item(i)
- confirm(a1 is not a2
- and a1.value == a2.value
- and a1.nodeValue == a2.nodeValue
- and a1.namespaceURI == a2.namespaceURI
- and a1.localName == a2.localName
- , "clone of attribute node has proper attribute values")
- confirm(a2.ownerElement is e2,
- "clone of attribute node correctly owned")
-
-def testCloneDocumentShallow():
- doc = parseString("<?xml version='1.0'?>\n"
- "<!-- comment -->"
- "<!DOCTYPE doc [\n"
- "<!NOTATION notation SYSTEM 'http://xml.python.org/'>\n"
- "]>\n"
- "<doc attr='value'/>")
- doc2 = doc.cloneNode(0)
- confirm(doc2 is None,
- "testCloneDocumentShallow:"
- " shallow cloning of documents makes no sense!")
-
-def testCloneDocumentDeep():
- doc = parseString("<?xml version='1.0'?>\n"
- "<!-- comment -->"
- "<!DOCTYPE doc [\n"
- "<!NOTATION notation SYSTEM 'http://xml.python.org/'>\n"
- "]>\n"
- "<doc attr='value'/>")
- doc2 = doc.cloneNode(1)
- confirm(not (doc.isSameNode(doc2) or doc2.isSameNode(doc)),
- "testCloneDocumentDeep: document objects not distinct")
- confirm(len(doc.childNodes) == len(doc2.childNodes),
- "testCloneDocumentDeep: wrong number of Document children")
- confirm(doc2.documentElement.nodeType == Node.ELEMENT_NODE,
- "testCloneDocumentDeep: documentElement not an ELEMENT_NODE")
- confirm(doc2.documentElement.ownerDocument.isSameNode(doc2),
- "testCloneDocumentDeep: documentElement owner is not new document")
- confirm(not doc.documentElement.isSameNode(doc2.documentElement),
- "testCloneDocumentDeep: documentElement should not be shared")
- if doc.doctype is not None:
- # check the doctype iff the original DOM maintained it
- confirm(doc2.doctype.nodeType == Node.DOCUMENT_TYPE_NODE,
- "testCloneDocumentDeep: doctype not a DOCUMENT_TYPE_NODE")
- confirm(doc2.doctype.ownerDocument.isSameNode(doc2))
- confirm(not doc.doctype.isSameNode(doc2.doctype))
-
-def testCloneDocumentTypeDeepOk():
- doctype = create_nonempty_doctype()
- clone = doctype.cloneNode(1)
- confirm(clone is not None
- and clone.nodeName == doctype.nodeName
- and clone.name == doctype.name
- and clone.publicId == doctype.publicId
- and clone.systemId == doctype.systemId
- and len(clone.entities) == len(doctype.entities)
- and clone.entities.item(len(clone.entities)) is None
- and len(clone.notations) == len(doctype.notations)
- and clone.notations.item(len(clone.notations)) is None
- and len(clone.childNodes) == 0)
- for i in range(len(doctype.entities)):
- se = doctype.entities.item(i)
- ce = clone.entities.item(i)
- confirm((not se.isSameNode(ce))
- and (not ce.isSameNode(se))
- and ce.nodeName == se.nodeName
- and ce.notationName == se.notationName
- and ce.publicId == se.publicId
- and ce.systemId == se.systemId
- and ce.encoding == se.encoding
- and ce.actualEncoding == se.actualEncoding
- and ce.version == se.version)
- for i in range(len(doctype.notations)):
- sn = doctype.notations.item(i)
- cn = clone.notations.item(i)
- confirm((not sn.isSameNode(cn))
- and (not cn.isSameNode(sn))
- and cn.nodeName == sn.nodeName
- and cn.publicId == sn.publicId
- and cn.systemId == sn.systemId)
-
-def testCloneDocumentTypeDeepNotOk():
- doc = create_doc_with_doctype()
- clone = doc.doctype.cloneNode(1)
- confirm(clone is None, "testCloneDocumentTypeDeepNotOk")
-
-def testCloneDocumentTypeShallowOk():
- doctype = create_nonempty_doctype()
- clone = doctype.cloneNode(0)
- confirm(clone is not None
- and clone.nodeName == doctype.nodeName
- and clone.name == doctype.name
- and clone.publicId == doctype.publicId
- and clone.systemId == doctype.systemId
- and len(clone.entities) == 0
- and clone.entities.item(0) is None
- and len(clone.notations) == 0
- and clone.notations.item(0) is None
- and len(clone.childNodes) == 0)
-
-def testCloneDocumentTypeShallowNotOk():
- doc = create_doc_with_doctype()
- clone = doc.doctype.cloneNode(0)
- confirm(clone is None, "testCloneDocumentTypeShallowNotOk")
-
-def check_import_document(deep, testName):
- doc1 = parseString("<doc/>")
- doc2 = parseString("<doc/>")
- try:
- doc1.importNode(doc2, deep)
- except xml.dom.NotSupportedErr:
- pass
- else:
- raise Exception(testName +
- ": expected NotSupportedErr when importing a document")
-
-def testImportDocumentShallow():
- check_import_document(0, "testImportDocumentShallow")
-
-def testImportDocumentDeep():
- check_import_document(1, "testImportDocumentDeep")
-
# The tests of DocumentType importing use these helpers to construct
# the documents to work with, since not all DOM builders actually
# create the DocumentType nodes.
-
def create_doc_without_doctype(doctype=None):
return getDOMImplementation().createDocument(None, "doc", doctype)
@@ -722,673 +52,1263 @@ def create_doc_with_doctype():
doctype.notations.item(0).ownerDocument = doc
return doc
-def testImportDocumentTypeShallow():
- src = create_doc_with_doctype()
- target = create_doc_without_doctype()
- try:
- imported = target.importNode(src.doctype, 0)
- except xml.dom.NotSupportedErr:
- pass
- else:
- raise Exception(
- "testImportDocumentTypeShallow: expected NotSupportedErr")
-
-def testImportDocumentTypeDeep():
- src = create_doc_with_doctype()
- target = create_doc_without_doctype()
- try:
- imported = target.importNode(src.doctype, 1)
- except xml.dom.NotSupportedErr:
- pass
- else:
- raise Exception(
- "testImportDocumentTypeDeep: expected NotSupportedErr")
-
-# Testing attribute clones uses a helper, and should always be deep,
-# even if the argument to cloneNode is false.
-def check_clone_attribute(deep, testName):
- doc = parseString("<doc attr='value'/>")
- attr = doc.documentElement.getAttributeNode("attr")
- assert attr is not None
- clone = attr.cloneNode(deep)
- confirm(not clone.isSameNode(attr))
- confirm(not attr.isSameNode(clone))
- confirm(clone.ownerElement is None,
- testName + ": ownerElement should be None")
- confirm(clone.ownerDocument.isSameNode(attr.ownerDocument),
- testName + ": ownerDocument does not match")
- confirm(clone.specified,
- testName + ": cloned attribute must have specified == True")
-
-def testCloneAttributeShallow():
- check_clone_attribute(0, "testCloneAttributeShallow")
-
-def testCloneAttributeDeep():
- check_clone_attribute(1, "testCloneAttributeDeep")
-
-def check_clone_pi(deep, testName):
- doc = parseString("<?target data?><doc/>")
- pi = doc.firstChild
- assert pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE
- clone = pi.cloneNode(deep)
- confirm(clone.target == pi.target
- and clone.data == pi.data)
-
-def testClonePIShallow():
- check_clone_pi(0, "testClonePIShallow")
-
-def testClonePIDeep():
- check_clone_pi(1, "testClonePIDeep")
-
-def testNormalize():
- doc = parseString("<doc/>")
- root = doc.documentElement
- root.appendChild(doc.createTextNode("first"))
- root.appendChild(doc.createTextNode("second"))
- confirm(len(root.childNodes) == 2
- and root.childNodes.length == 2, "testNormalize -- preparation")
- doc.normalize()
- confirm(len(root.childNodes) == 1
- and root.childNodes.length == 1
- and root.firstChild is root.lastChild
- and root.firstChild.data == "firstsecond"
- , "testNormalize -- result")
- doc.unlink()
-
- doc = parseString("<doc/>")
- root = doc.documentElement
- root.appendChild(doc.createTextNode(""))
- doc.normalize()
- confirm(len(root.childNodes) == 0
- and root.childNodes.length == 0,
- "testNormalize -- single empty node removed")
- doc.unlink()
-
-def testSiblings():
- doc = parseString("<doc><?pi?>text?<elm/></doc>")
- root = doc.documentElement
- (pi, text, elm) = root.childNodes
-
- confirm(pi.nextSibling is text and
- pi.previousSibling is None and
- text.nextSibling is elm and
- text.previousSibling is pi and
- elm.nextSibling is None and
- elm.previousSibling is text, "testSiblings")
-
- doc.unlink()
-
-def testParents():
- doc = parseString("<doc><elm1><elm2/><elm2><elm3/></elm2></elm1></doc>")
- root = doc.documentElement
- elm1 = root.childNodes[0]
- (elm2a, elm2b) = elm1.childNodes
- elm3 = elm2b.childNodes[0]
-
- confirm(root.parentNode is doc and
- elm1.parentNode is root and
- elm2a.parentNode is elm1 and
- elm2b.parentNode is elm1 and
- elm3.parentNode is elm2b, "testParents")
-
- doc.unlink()
-
-def testNodeListItem():
- doc = parseString("<doc><e/><e/></doc>")
- children = doc.childNodes
- docelem = children[0]
- confirm(children[0] is children.item(0)
- and children.item(1) is None
- and docelem.childNodes.item(0) is docelem.childNodes[0]
- and docelem.childNodes.item(1) is docelem.childNodes[1]
- and docelem.childNodes.item(0).childNodes.item(0) is None,
- "test NodeList.item()")
- doc.unlink()
-
-def testSAX2DOM():
- from xml.dom import pulldom
-
- sax2dom = pulldom.SAX2DOM()
- sax2dom.startDocument()
- sax2dom.startElement("doc", {})
- sax2dom.characters("text")
- sax2dom.startElement("subelm", {})
- sax2dom.characters("text")
- sax2dom.endElement("subelm")
- sax2dom.characters("text")
- sax2dom.endElement("doc")
- sax2dom.endDocument()
-
- doc = sax2dom.document
- root = doc.documentElement
- (text1, elm1, text2) = root.childNodes
- text3 = elm1.childNodes[0]
-
- confirm(text1.previousSibling is None and
- text1.nextSibling is elm1 and
- elm1.previousSibling is text1 and
- elm1.nextSibling is text2 and
- text2.previousSibling is elm1 and
- text2.nextSibling is None and
- text3.previousSibling is None and
- text3.nextSibling is None, "testSAX2DOM - siblings")
-
- confirm(root.parentNode is doc and
- text1.parentNode is root and
- elm1.parentNode is root and
- text2.parentNode is root and
- text3.parentNode is elm1, "testSAX2DOM - parents")
-
- doc.unlink()
-
-def testEncodings():
- doc = parseString('<foo>&#x20ac;</foo>')
- confirm(doc.toxml() == u'<?xml version="1.0" ?><foo>\u20ac</foo>'
- and doc.toxml('utf-8') == '<?xml version="1.0" encoding="utf-8"?><foo>\xe2\x82\xac</foo>'
- and doc.toxml('iso-8859-15') == '<?xml version="1.0" encoding="iso-8859-15"?><foo>\xa4</foo>',
- "testEncodings - encoding EURO SIGN")
-
- # Verify that character decoding errors throw exceptions instead of crashing
- try:
- doc = parseString('<fran\xe7ais>Comment \xe7a va ? Tr\xe8s bien ?</fran\xe7ais>')
- except UnicodeDecodeError:
- pass
- else:
- print('parsing with bad encoding should raise a UnicodeDecodeError')
-
- doc.unlink()
-
-class UserDataHandler:
- called = 0
- def handle(self, operation, key, data, src, dst):
- dst.setUserData(key, data + 1, self)
- src.setUserData(key, None, None)
- self.called = 1
-
-def testUserData():
- dom = Document()
- n = dom.createElement('e')
- confirm(n.getUserData("foo") is None)
- n.setUserData("foo", None, None)
- confirm(n.getUserData("foo") is None)
- n.setUserData("foo", 12, 12)
- n.setUserData("bar", 13, 13)
- confirm(n.getUserData("foo") == 12)
- confirm(n.getUserData("bar") == 13)
- n.setUserData("foo", None, None)
- confirm(n.getUserData("foo") is None)
- confirm(n.getUserData("bar") == 13)
-
- handler = UserDataHandler()
- n.setUserData("bar", 12, handler)
- c = n.cloneNode(1)
- confirm(handler.called
- and n.getUserData("bar") is None
- and c.getUserData("bar") == 13)
- n.unlink()
- c.unlink()
- dom.unlink()
-
-def testRenameAttribute():
- doc = parseString("<doc a='v'/>")
- elem = doc.documentElement
- attrmap = elem.attributes
- attr = elem.attributes['a']
-
- # Simple renaming
- attr = doc.renameNode(attr, xml.dom.EMPTY_NAMESPACE, "b")
- confirm(attr.name == "b"
- and attr.nodeName == "b"
- and attr.localName is None
- and attr.namespaceURI == xml.dom.EMPTY_NAMESPACE
- and attr.prefix is None
- and attr.value == "v"
- and elem.getAttributeNode("a") is None
- and elem.getAttributeNode("b").isSameNode(attr)
- and attrmap["b"].isSameNode(attr)
- and attr.ownerDocument.isSameNode(doc)
- and attr.ownerElement.isSameNode(elem))
-
- # Rename to have a namespace, no prefix
- attr = doc.renameNode(attr, "http://xml.python.org/ns", "c")
- confirm(attr.name == "c"
- and attr.nodeName == "c"
- and attr.localName == "c"
- and attr.namespaceURI == "http://xml.python.org/ns"
- and attr.prefix is None
- and attr.value == "v"
- and elem.getAttributeNode("a") is None
- and elem.getAttributeNode("b") is None
- and elem.getAttributeNode("c").isSameNode(attr)
- and elem.getAttributeNodeNS(
- "http://xml.python.org/ns", "c").isSameNode(attr)
- and attrmap["c"].isSameNode(attr)
- and attrmap[("http://xml.python.org/ns", "c")].isSameNode(attr))
-
- # Rename to have a namespace, with prefix
- attr = doc.renameNode(attr, "http://xml.python.org/ns2", "p:d")
- confirm(attr.name == "p:d"
- and attr.nodeName == "p:d"
- and attr.localName == "d"
- and attr.namespaceURI == "http://xml.python.org/ns2"
- and attr.prefix == "p"
- and attr.value == "v"
- and elem.getAttributeNode("a") is None
- and elem.getAttributeNode("b") is None
- and elem.getAttributeNode("c") is None
- and elem.getAttributeNodeNS(
- "http://xml.python.org/ns", "c") is None
- and elem.getAttributeNode("p:d").isSameNode(attr)
- and elem.getAttributeNodeNS(
- "http://xml.python.org/ns2", "d").isSameNode(attr)
- and attrmap["p:d"].isSameNode(attr)
- and attrmap[("http://xml.python.org/ns2", "d")].isSameNode(attr))
-
- # Rename back to a simple non-NS node
- attr = doc.renameNode(attr, xml.dom.EMPTY_NAMESPACE, "e")
- confirm(attr.name == "e"
- and attr.nodeName == "e"
- and attr.localName is None
- and attr.namespaceURI == xml.dom.EMPTY_NAMESPACE
- and attr.prefix is None
- and attr.value == "v"
- and elem.getAttributeNode("a") is None
- and elem.getAttributeNode("b") is None
- and elem.getAttributeNode("c") is None
- and elem.getAttributeNode("p:d") is None
- and elem.getAttributeNodeNS(
- "http://xml.python.org/ns", "c") is None
- and elem.getAttributeNode("e").isSameNode(attr)
- and attrmap["e"].isSameNode(attr))
-
- try:
- doc.renameNode(attr, "http://xml.python.org/ns", "xmlns")
- except xml.dom.NamespaceErr:
- pass
- else:
- print("expected NamespaceErr")
-
- checkRenameNodeSharedConstraints(doc, attr)
- doc.unlink()
-
-def testRenameElement():
- doc = parseString("<doc/>")
- elem = doc.documentElement
-
- # Simple renaming
- elem = doc.renameNode(elem, xml.dom.EMPTY_NAMESPACE, "a")
- confirm(elem.tagName == "a"
- and elem.nodeName == "a"
- and elem.localName is None
- and elem.namespaceURI == xml.dom.EMPTY_NAMESPACE
- and elem.prefix is None
- and elem.ownerDocument.isSameNode(doc))
-
- # Rename to have a namespace, no prefix
- elem = doc.renameNode(elem, "http://xml.python.org/ns", "b")
- confirm(elem.tagName == "b"
- and elem.nodeName == "b"
- and elem.localName == "b"
- and elem.namespaceURI == "http://xml.python.org/ns"
- and elem.prefix is None
- and elem.ownerDocument.isSameNode(doc))
-
- # Rename to have a namespace, with prefix
- elem = doc.renameNode(elem, "http://xml.python.org/ns2", "p:c")
- confirm(elem.tagName == "p:c"
- and elem.nodeName == "p:c"
- and elem.localName == "c"
- and elem.namespaceURI == "http://xml.python.org/ns2"
- and elem.prefix == "p"
- and elem.ownerDocument.isSameNode(doc))
-
- # Rename back to a simple non-NS node
- elem = doc.renameNode(elem, xml.dom.EMPTY_NAMESPACE, "d")
- confirm(elem.tagName == "d"
- and elem.nodeName == "d"
- and elem.localName is None
- and elem.namespaceURI == xml.dom.EMPTY_NAMESPACE
- and elem.prefix is None
- and elem.ownerDocument.isSameNode(doc))
-
- checkRenameNodeSharedConstraints(doc, elem)
- doc.unlink()
-
-def checkRenameNodeSharedConstraints(doc, node):
- # Make sure illegal NS usage is detected:
- try:
- doc.renameNode(node, "http://xml.python.org/ns", "xmlns:foo")
- except xml.dom.NamespaceErr:
- pass
- else:
- print("expected NamespaceErr")
+class MinidomTest(unittest.TestCase):
+ def tearDown(self):
+ try:
+ Node.allnodes
+ except AttributeError:
+ # We don't actually have the minidom from the standard library,
+ # but are picking up the PyXML version from site-packages.
+ pass
+ else:
+ self.confirm(len(Node.allnodes) == 0,
+ "assertion: len(Node.allnodes) == 0")
+ if len(Node.allnodes):
+ print("Garbage left over:")
+ if verbose:
+ print(list(Node.allnodes.items())[0:10])
+ else:
+ # Don't print specific nodes if repeatable results
+ # are needed
+ print(len(Node.allnodes))
+ Node.allnodes = {}
- doc2 = parseString("<doc/>")
- try:
- doc2.renameNode(node, xml.dom.EMPTY_NAMESPACE, "foo")
- except xml.dom.WrongDocumentErr:
- pass
- else:
- print("expected WrongDocumentErr")
-
-def testRenameOther():
- # We have to create a comment node explicitly since not all DOM
- # builders used with minidom add comments to the DOM.
- doc = xml.dom.minidom.getDOMImplementation().createDocument(
- xml.dom.EMPTY_NAMESPACE, "e", None)
- node = doc.createComment("comment")
- try:
- doc.renameNode(node, xml.dom.EMPTY_NAMESPACE, "foo")
- except xml.dom.NotSupportedErr:
+ def confirm(self, test, testname = "Test"):
+ self.assertTrue(test, testname)
+
+ def checkWholeText(self, node, s):
+ t = node.wholeText
+ self.confirm(t == s, "looking for %s, found %s" % (repr(s), repr(t)))
+
+ def testParseFromFile(self):
+ dom = parse(StringIO(open(tstfile).read()))
+ dom.unlink()
+ self.confirm(isinstance(dom, Document))
+
+ def testGetElementsByTagName(self):
+ dom = parse(tstfile)
+ self.confirm(dom.getElementsByTagName("LI") == \
+ dom.documentElement.getElementsByTagName("LI"))
+ dom.unlink()
+
+ def testInsertBefore(self):
+ dom = parseString("<doc><foo/></doc>")
+ root = dom.documentElement
+ elem = root.childNodes[0]
+ nelem = dom.createElement("element")
+ root.insertBefore(nelem, elem)
+ self.confirm(len(root.childNodes) == 2
+ and root.childNodes.length == 2
+ and root.childNodes[0] is nelem
+ and root.childNodes.item(0) is nelem
+ and root.childNodes[1] is elem
+ and root.childNodes.item(1) is elem
+ and root.firstChild is nelem
+ and root.lastChild is elem
+ and root.toxml() == "<doc><element/><foo/></doc>"
+ , "testInsertBefore -- node properly placed in tree")
+ nelem = dom.createElement("element")
+ root.insertBefore(nelem, None)
+ self.confirm(len(root.childNodes) == 3
+ and root.childNodes.length == 3
+ and root.childNodes[1] is elem
+ and root.childNodes.item(1) is elem
+ and root.childNodes[2] is nelem
+ and root.childNodes.item(2) is nelem
+ and root.lastChild is nelem
+ and nelem.previousSibling is elem
+ and root.toxml() == "<doc><element/><foo/><element/></doc>"
+ , "testInsertBefore -- node properly placed in tree")
+ nelem2 = dom.createElement("bar")
+ root.insertBefore(nelem2, nelem)
+ self.confirm(len(root.childNodes) == 4
+ and root.childNodes.length == 4
+ and root.childNodes[2] is nelem2
+ and root.childNodes.item(2) is nelem2
+ and root.childNodes[3] is nelem
+ and root.childNodes.item(3) is nelem
+ and nelem2.nextSibling is nelem
+ and nelem.previousSibling is nelem2
+ and root.toxml() ==
+ "<doc><element/><foo/><bar/><element/></doc>"
+ , "testInsertBefore -- node properly placed in tree")
+ dom.unlink()
+
+ def _create_fragment_test_nodes(self):
+ dom = parseString("<doc/>")
+ orig = dom.createTextNode("original")
+ c1 = dom.createTextNode("foo")
+ c2 = dom.createTextNode("bar")
+ c3 = dom.createTextNode("bat")
+ dom.documentElement.appendChild(orig)
+ frag = dom.createDocumentFragment()
+ frag.appendChild(c1)
+ frag.appendChild(c2)
+ frag.appendChild(c3)
+ return dom, orig, c1, c2, c3, frag
+
+ def testInsertBeforeFragment(self):
+ dom, orig, c1, c2, c3, frag = self._create_fragment_test_nodes()
+ dom.documentElement.insertBefore(frag, None)
+ self.confirm(tuple(dom.documentElement.childNodes) ==
+ (orig, c1, c2, c3),
+ "insertBefore(<fragment>, None)")
+ frag.unlink()
+ dom.unlink()
+
+ dom, orig, c1, c2, c3, frag = self._create_fragment_test_nodes()
+ dom.documentElement.insertBefore(frag, orig)
+ self.confirm(tuple(dom.documentElement.childNodes) ==
+ (c1, c2, c3, orig),
+ "insertBefore(<fragment>, orig)")
+ frag.unlink()
+ dom.unlink()
+
+ def testAppendChild(self):
+ dom = parse(tstfile)
+ dom.documentElement.appendChild(dom.createComment(u"Hello"))
+ self.confirm(dom.documentElement.childNodes[-1].nodeName == "#comment")
+ self.confirm(dom.documentElement.childNodes[-1].data == "Hello")
+ dom.unlink()
+
+ def testAppendChildFragment(self):
+ dom, orig, c1, c2, c3, frag = self._create_fragment_test_nodes()
+ dom.documentElement.appendChild(frag)
+ self.confirm(tuple(dom.documentElement.childNodes) ==
+ (orig, c1, c2, c3),
+ "appendChild(<fragment>)")
+ frag.unlink()
+ dom.unlink()
+
+ def testReplaceChildFragment(self):
+ dom, orig, c1, c2, c3, frag = self._create_fragment_test_nodes()
+ dom.documentElement.replaceChild(frag, orig)
+ orig.unlink()
+ self.confirm(tuple(dom.documentElement.childNodes) == (c1, c2, c3),
+ "replaceChild(<fragment>)")
+ frag.unlink()
+ dom.unlink()
+
+ def testLegalChildren(self):
+ dom = Document()
+ elem = dom.createElement('element')
+ text = dom.createTextNode('text')
+ self.assertRaises(xml.dom.HierarchyRequestErr, dom.appendChild, text)
+
+ dom.appendChild(elem)
+ self.assertRaises(xml.dom.HierarchyRequestErr, dom.insertBefore, text,
+ elem)
+ self.assertRaises(xml.dom.HierarchyRequestErr, dom.replaceChild, text,
+ elem)
+
+ nodemap = elem.attributes
+ self.assertRaises(xml.dom.HierarchyRequestErr, nodemap.setNamedItem,
+ text)
+ self.assertRaises(xml.dom.HierarchyRequestErr, nodemap.setNamedItemNS,
+ text)
+
+ elem.appendChild(text)
+ dom.unlink()
+
+ def testNamedNodeMapSetItem(self):
+ dom = Document()
+ elem = dom.createElement('element')
+ attrs = elem.attributes
+ attrs["foo"] = "bar"
+ a = attrs.item(0)
+ self.confirm(a.ownerDocument is dom,
+ "NamedNodeMap.__setitem__() sets ownerDocument")
+ self.confirm(a.ownerElement is elem,
+ "NamedNodeMap.__setitem__() sets ownerElement")
+ self.confirm(a.value == "bar",
+ "NamedNodeMap.__setitem__() sets value")
+ self.confirm(a.nodeValue == "bar",
+ "NamedNodeMap.__setitem__() sets nodeValue")
+ elem.unlink()
+ dom.unlink()
+
+ def testNonZero(self):
+ dom = parse(tstfile)
+ self.confirm(dom)# should not be zero
+ dom.appendChild(dom.createComment("foo"))
+ self.confirm(not dom.childNodes[-1].childNodes)
+ dom.unlink()
+
+ def testUnlink(self):
+ dom = parse(tstfile)
+ dom.unlink()
+
+ def testElement(self):
+ dom = Document()
+ dom.appendChild(dom.createElement("abc"))
+ self.confirm(dom.documentElement)
+ dom.unlink()
+
+ def testAAA(self):
+ dom = parseString("<abc/>")
+ el = dom.documentElement
+ el.setAttribute("spam", "jam2")
+ self.confirm(el.toxml() == '<abc spam="jam2"/>', "testAAA")
+ a = el.getAttributeNode("spam")
+ self.confirm(a.ownerDocument is dom,
+ "setAttribute() sets ownerDocument")
+ self.confirm(a.ownerElement is dom.documentElement,
+ "setAttribute() sets ownerElement")
+ dom.unlink()
+
+ def testAAB(self):
+ dom = parseString("<abc/>")
+ el = dom.documentElement
+ el.setAttribute("spam", "jam")
+ el.setAttribute("spam", "jam2")
+ self.confirm(el.toxml() == '<abc spam="jam2"/>', "testAAB")
+ dom.unlink()
+
+ def testAddAttr(self):
+ dom = Document()
+ child = dom.appendChild(dom.createElement("abc"))
+
+ child.setAttribute("def", "ghi")
+ self.confirm(child.getAttribute("def") == "ghi")
+ self.confirm(child.attributes["def"].value == "ghi")
+
+ child.setAttribute("jkl", "mno")
+ self.confirm(child.getAttribute("jkl") == "mno")
+ self.confirm(child.attributes["jkl"].value == "mno")
+
+ self.confirm(len(child.attributes) == 2)
+
+ child.setAttribute("def", "newval")
+ self.confirm(child.getAttribute("def") == "newval")
+ self.confirm(child.attributes["def"].value == "newval")
+
+ self.confirm(len(child.attributes) == 2)
+ dom.unlink()
+
+ def testDeleteAttr(self):
+ dom = Document()
+ child = dom.appendChild(dom.createElement("abc"))
+
+ self.confirm(len(child.attributes) == 0)
+ child.setAttribute("def", "ghi")
+ self.confirm(len(child.attributes) == 1)
+ del child.attributes["def"]
+ self.confirm(len(child.attributes) == 0)
+ dom.unlink()
+
+ def testRemoveAttr(self):
+ dom = Document()
+ child = dom.appendChild(dom.createElement("abc"))
+
+ child.setAttribute("def", "ghi")
+ self.confirm(len(child.attributes) == 1)
+ child.removeAttribute("def")
+ self.confirm(len(child.attributes) == 0)
+ dom.unlink()
+
+ def testRemoveAttrNS(self):
+ dom = Document()
+ child = dom.appendChild(
+ dom.createElementNS("http://www.python.org", "python:abc"))
+ child.setAttributeNS("http://www.w3.org", "xmlns:python",
+ "http://www.python.org")
+ child.setAttributeNS("http://www.python.org", "python:abcattr", "foo")
+ self.confirm(len(child.attributes) == 2)
+ child.removeAttributeNS("http://www.python.org", "abcattr")
+ self.confirm(len(child.attributes) == 1)
+ dom.unlink()
+
+ def testRemoveAttributeNode(self):
+ dom = Document()
+ child = dom.appendChild(dom.createElement("foo"))
+ child.setAttribute("spam", "jam")
+ self.confirm(len(child.attributes) == 1)
+ node = child.getAttributeNode("spam")
+ child.removeAttributeNode(node)
+ self.confirm(len(child.attributes) == 0
+ and child.getAttributeNode("spam") is None)
+ dom.unlink()
+
+ def testChangeAttr(self):
+ dom = parseString("<abc/>")
+ el = dom.documentElement
+ el.setAttribute("spam", "jam")
+ self.confirm(len(el.attributes) == 1)
+ el.setAttribute("spam", "bam")
+ # Set this attribute to be an ID and make sure that doesn't change
+ # when changing the value:
+ el.setIdAttribute("spam")
+ self.confirm(len(el.attributes) == 1
+ and el.attributes["spam"].value == "bam"
+ and el.attributes["spam"].nodeValue == "bam"
+ and el.getAttribute("spam") == "bam"
+ and el.getAttributeNode("spam").isId)
+ el.attributes["spam"] = "ham"
+ self.confirm(len(el.attributes) == 1
+ and el.attributes["spam"].value == "ham"
+ and el.attributes["spam"].nodeValue == "ham"
+ and el.getAttribute("spam") == "ham"
+ and el.attributes["spam"].isId)
+ el.setAttribute("spam2", "bam")
+ self.confirm(len(el.attributes) == 2
+ and el.attributes["spam"].value == "ham"
+ and el.attributes["spam"].nodeValue == "ham"
+ and el.getAttribute("spam") == "ham"
+ and el.attributes["spam2"].value == "bam"
+ and el.attributes["spam2"].nodeValue == "bam"
+ and el.getAttribute("spam2") == "bam")
+ el.attributes["spam2"] = "bam2"
+ self.confirm(len(el.attributes) == 2
+ and el.attributes["spam"].value == "ham"
+ and el.attributes["spam"].nodeValue == "ham"
+ and el.getAttribute("spam") == "ham"
+ and el.attributes["spam2"].value == "bam2"
+ and el.attributes["spam2"].nodeValue == "bam2"
+ and el.getAttribute("spam2") == "bam2")
+ dom.unlink()
+
+ def testGetAttrList(self):
pass
- else:
- print("expected NotSupportedErr when renaming comment node")
- doc.unlink()
-
-def checkWholeText(node, s):
- t = node.wholeText
- confirm(t == s, "looking for %s, found %s" % (repr(s), repr(t)))
-
-def testWholeText():
- doc = parseString("<doc>a</doc>")
- elem = doc.documentElement
- text = elem.childNodes[0]
- assert text.nodeType == Node.TEXT_NODE
-
- checkWholeText(text, "a")
- elem.appendChild(doc.createTextNode("b"))
- checkWholeText(text, "ab")
- elem.insertBefore(doc.createCDATASection("c"), text)
- checkWholeText(text, "cab")
-
- # make sure we don't cross other nodes
- splitter = doc.createComment("comment")
- elem.appendChild(splitter)
- text2 = doc.createTextNode("d")
- elem.appendChild(text2)
- checkWholeText(text, "cab")
- checkWholeText(text2, "d")
-
- x = doc.createElement("x")
- elem.replaceChild(x, splitter)
- splitter = x
- checkWholeText(text, "cab")
- checkWholeText(text2, "d")
-
- x = doc.createProcessingInstruction("y", "z")
- elem.replaceChild(x, splitter)
- splitter = x
- checkWholeText(text, "cab")
- checkWholeText(text2, "d")
-
- elem.removeChild(splitter)
- checkWholeText(text, "cabd")
- checkWholeText(text2, "cabd")
-
-def testPatch1094164 ():
- doc = parseString("<doc><e/></doc>")
- elem = doc.documentElement
- e = elem.firstChild
- confirm(e.parentNode is elem, "Before replaceChild()")
- # Check that replacing a child with itself leaves the tree unchanged
- elem.replaceChild(e, e)
- confirm(e.parentNode is elem, "After replaceChild()")
-
-
-
-def testReplaceWholeText():
- def setup():
- doc = parseString("<doc>a<e/>d</doc>")
+
+ def testGetAttrValues(self): pass
+
+ def testGetAttrLength(self): pass
+
+ def testGetAttribute(self): pass
+
+ def testGetAttributeNS(self): pass
+
+ def testGetAttributeNode(self): pass
+
+ def testGetElementsByTagNameNS(self):
+ d="""<foo xmlns:minidom='http://pyxml.sf.net/minidom'>
+ <minidom:myelem/>
+ </foo>"""
+ dom = parseString(d)
+ elems = dom.getElementsByTagNameNS("http://pyxml.sf.net/minidom",
+ "myelem")
+ self.confirm(len(elems) == 1
+ and elems[0].namespaceURI == "http://pyxml.sf.net/minidom"
+ and elems[0].localName == "myelem"
+ and elems[0].prefix == "minidom"
+ and elems[0].tagName == "minidom:myelem"
+ and elems[0].nodeName == "minidom:myelem")
+ dom.unlink()
+
+ def get_empty_nodelist_from_elements_by_tagName_ns_helper(self, doc, nsuri,
+ lname):
+ nodelist = doc.getElementsByTagNameNS(nsuri, lname)
+ self.confirm(len(nodelist) == 0)
+
+ def testGetEmptyNodeListFromElementsByTagNameNS(self):
+ doc = parseString('<doc/>')
+ self.get_empty_nodelist_from_elements_by_tagName_ns_helper(
+ doc, 'http://xml.python.org/namespaces/a', 'localname')
+ self.get_empty_nodelist_from_elements_by_tagName_ns_helper(
+ doc, '*', 'splat')
+ self.get_empty_nodelist_from_elements_by_tagName_ns_helper(
+ doc, 'http://xml.python.org/namespaces/a', '*')
+
+ doc = parseString('<doc xmlns="http://xml.python.org/splat"><e/></doc>')
+ self.get_empty_nodelist_from_elements_by_tagName_ns_helper(
+ doc, "http://xml.python.org/splat", "not-there")
+ self.get_empty_nodelist_from_elements_by_tagName_ns_helper(
+ doc, "*", "not-there")
+ self.get_empty_nodelist_from_elements_by_tagName_ns_helper(
+ doc, "http://somewhere.else.net/not-there", "e")
+
+ def testElementReprAndStr(self):
+ dom = Document()
+ el = dom.appendChild(dom.createElement("abc"))
+ string1 = repr(el)
+ string2 = str(el)
+ self.confirm(string1 == string2)
+ dom.unlink()
+
+ def testElementReprAndStrUnicode(self):
+ dom = Document()
+ el = dom.appendChild(dom.createElement(u"abc"))
+ string1 = repr(el)
+ string2 = str(el)
+ self.confirm(string1 == string2)
+ dom.unlink()
+
+ def testElementReprAndStrUnicodeNS(self):
+ dom = Document()
+ el = dom.appendChild(
+ dom.createElementNS(u"http://www.slashdot.org", u"slash:abc"))
+ string1 = repr(el)
+ string2 = str(el)
+ self.confirm(string1 == string2)
+ self.confirm(string1.find("slash:abc") != -1)
+ dom.unlink()
+
+ def testAttributeRepr(self):
+ dom = Document()
+ el = dom.appendChild(dom.createElement(u"abc"))
+ node = el.setAttribute("abc", "def")
+ self.confirm(str(node) == repr(node))
+ dom.unlink()
+
+ def testTextNodeRepr(self): pass
+
+ def testWriteXML(self):
+ str = '<?xml version="1.0" ?><a b="c"/>'
+ dom = parseString(str)
+ domstr = dom.toxml()
+ dom.unlink()
+ self.confirm(str == domstr)
+
+ def testAltNewline(self):
+ str = '<?xml version="1.0" ?>\n<a b="c"/>\n'
+ dom = parseString(str)
+ domstr = dom.toprettyxml(newl="\r\n")
+ dom.unlink()
+ self.confirm(domstr == str.replace("\n", "\r\n"))
+
+ def testProcessingInstruction(self):
+ dom = parseString('<e><?mypi \t\n data \t\n ?></e>')
+ pi = dom.documentElement.firstChild
+ self.confirm(pi.target == "mypi"
+ and pi.data == "data \t\n "
+ and pi.nodeName == "mypi"
+ and pi.nodeType == Node.PROCESSING_INSTRUCTION_NODE
+ and pi.attributes is None
+ and not pi.hasChildNodes()
+ and len(pi.childNodes) == 0
+ and pi.firstChild is None
+ and pi.lastChild is None
+ and pi.localName is None
+ and pi.namespaceURI == xml.dom.EMPTY_NAMESPACE)
+
+ def testProcessingInstructionRepr(self): pass
+
+ def testTextRepr(self): pass
+
+ def testWriteText(self): pass
+
+ def testDocumentElement(self): pass
+
+ def testTooManyDocumentElements(self):
+ doc = parseString("<doc/>")
+ elem = doc.createElement("extra")
+ # Should raise an exception when adding an extra document element.
+ self.assertRaises(xml.dom.HierarchyRequestErr, doc.appendChild, elem)
+ elem.unlink()
+ doc.unlink()
+
+ def testCreateElementNS(self): pass
+
+ def testCreateAttributeNS(self): pass
+
+ def testParse(self): pass
+
+ def testParseString(self): pass
+
+ def testComment(self): pass
+
+ def testAttrListItem(self): pass
+
+ def testAttrListItems(self): pass
+
+ def testAttrListItemNS(self): pass
+
+ def testAttrListKeys(self): pass
+
+ def testAttrListKeysNS(self): pass
+
+ def testRemoveNamedItem(self):
+ doc = parseString("<doc a=''/>")
+ e = doc.documentElement
+ attrs = e.attributes
+ a1 = e.getAttributeNode("a")
+ a2 = attrs.removeNamedItem("a")
+ self.confirm(a1.isSameNode(a2))
+ self.assertRaises(xml.dom.NotFoundErr, attrs.removeNamedItem, "a")
+
+ def testRemoveNamedItemNS(self):
+ doc = parseString("<doc xmlns:a='http://xml.python.org/' a:b=''/>")
+ e = doc.documentElement
+ attrs = e.attributes
+ a1 = e.getAttributeNodeNS("http://xml.python.org/", "b")
+ a2 = attrs.removeNamedItemNS("http://xml.python.org/", "b")
+ self.confirm(a1.isSameNode(a2))
+ self.assertRaises(xml.dom.NotFoundErr, attrs.removeNamedItemNS,
+ "http://xml.python.org/", "b")
+
+ def testAttrListValues(self): pass
+
+ def testAttrListLength(self): pass
+
+ def testAttrList__getitem__(self): pass
+
+ def testAttrList__setitem__(self): pass
+
+ def testSetAttrValueandNodeValue(self): pass
+
+ def testParseElement(self): pass
+
+ def testParseAttributes(self): pass
+
+ def testParseElementNamespaces(self): pass
+
+ def testParseAttributeNamespaces(self): pass
+
+ def testParseProcessingInstructions(self): pass
+
+ def testChildNodes(self): pass
+
+ def testFirstChild(self): pass
+
+ def testHasChildNodes(self): pass
+
+ def _testCloneElementCopiesAttributes(self, e1, e2, test):
+ attrs1 = e1.attributes
+ attrs2 = e2.attributes
+ keys1 = list(attrs1.keys())
+ keys2 = list(attrs2.keys())
+ keys1.sort()
+ keys2.sort()
+ self.confirm(keys1 == keys2, "clone of element has same attribute keys")
+ for i in range(len(keys1)):
+ a1 = attrs1.item(i)
+ a2 = attrs2.item(i)
+ self.confirm(a1 is not a2
+ and a1.value == a2.value
+ and a1.nodeValue == a2.nodeValue
+ and a1.namespaceURI == a2.namespaceURI
+ and a1.localName == a2.localName
+ , "clone of attribute node has proper attribute values")
+ self.confirm(a2.ownerElement is e2,
+ "clone of attribute node correctly owned")
+
+ def _setupCloneElement(self, deep):
+ dom = parseString("<doc attr='value'><foo/></doc>")
+ root = dom.documentElement
+ clone = root.cloneNode(deep)
+ self._testCloneElementCopiesAttributes(
+ root, clone, "testCloneElement" + (deep and "Deep" or "Shallow"))
+ # mutilate the original so shared data is detected
+ root.tagName = root.nodeName = "MODIFIED"
+ root.setAttribute("attr", "NEW VALUE")
+ root.setAttribute("added", "VALUE")
+ return dom, clone
+
+ def testCloneElementShallow(self):
+ dom, clone = self._setupCloneElement(0)
+ self.confirm(len(clone.childNodes) == 0
+ and clone.childNodes.length == 0
+ and clone.parentNode is None
+ and clone.toxml() == '<doc attr="value"/>'
+ , "testCloneElementShallow")
+ dom.unlink()
+
+ def testCloneElementDeep(self):
+ dom, clone = self._setupCloneElement(1)
+ self.confirm(len(clone.childNodes) == 1
+ and clone.childNodes.length == 1
+ and clone.parentNode is None
+ and clone.toxml() == '<doc attr="value"><foo/></doc>'
+ , "testCloneElementDeep")
+ dom.unlink()
+
+ def testCloneDocumentShallow(self):
+ doc = parseString("<?xml version='1.0'?>\n"
+ "<!-- comment -->"
+ "<!DOCTYPE doc [\n"
+ "<!NOTATION notation SYSTEM 'http://xml.python.org/'>\n"
+ "]>\n"
+ "<doc attr='value'/>")
+ doc2 = doc.cloneNode(0)
+ self.confirm(doc2 is None,
+ "testCloneDocumentShallow:"
+ " shallow cloning of documents makes no sense!")
+
+ def testCloneDocumentDeep(self):
+ doc = parseString("<?xml version='1.0'?>\n"
+ "<!-- comment -->"
+ "<!DOCTYPE doc [\n"
+ "<!NOTATION notation SYSTEM 'http://xml.python.org/'>\n"
+ "]>\n"
+ "<doc attr='value'/>")
+ doc2 = doc.cloneNode(1)
+ self.confirm(not (doc.isSameNode(doc2) or doc2.isSameNode(doc)),
+ "testCloneDocumentDeep: document objects not distinct")
+ self.confirm(len(doc.childNodes) == len(doc2.childNodes),
+ "testCloneDocumentDeep: wrong number of Document children")
+ self.confirm(doc2.documentElement.nodeType == Node.ELEMENT_NODE,
+ "testCloneDocumentDeep: documentElement not an ELEMENT_NODE")
+ self.confirm(doc2.documentElement.ownerDocument.isSameNode(doc2),
+ "testCloneDocumentDeep: documentElement owner is not new document")
+ self.confirm(not doc.documentElement.isSameNode(doc2.documentElement),
+ "testCloneDocumentDeep: documentElement should not be shared")
+ if doc.doctype is not None:
+ # check the doctype iff the original DOM maintained it
+ self.confirm(doc2.doctype.nodeType == Node.DOCUMENT_TYPE_NODE,
+ "testCloneDocumentDeep: doctype not a DOCUMENT_TYPE_NODE")
+ self.confirm(doc2.doctype.ownerDocument.isSameNode(doc2))
+ self.confirm(not doc.doctype.isSameNode(doc2.doctype))
+
+ def testCloneDocumentTypeDeepOk(self):
+ doctype = create_nonempty_doctype()
+ clone = doctype.cloneNode(1)
+ self.confirm(clone is not None
+ and clone.nodeName == doctype.nodeName
+ and clone.name == doctype.name
+ and clone.publicId == doctype.publicId
+ and clone.systemId == doctype.systemId
+ and len(clone.entities) == len(doctype.entities)
+ and clone.entities.item(len(clone.entities)) is None
+ and len(clone.notations) == len(doctype.notations)
+ and clone.notations.item(len(clone.notations)) is None
+ and len(clone.childNodes) == 0)
+ for i in range(len(doctype.entities)):
+ se = doctype.entities.item(i)
+ ce = clone.entities.item(i)
+ self.confirm((not se.isSameNode(ce))
+ and (not ce.isSameNode(se))
+ and ce.nodeName == se.nodeName
+ and ce.notationName == se.notationName
+ and ce.publicId == se.publicId
+ and ce.systemId == se.systemId
+ and ce.encoding == se.encoding
+ and ce.actualEncoding == se.actualEncoding
+ and ce.version == se.version)
+ for i in range(len(doctype.notations)):
+ sn = doctype.notations.item(i)
+ cn = clone.notations.item(i)
+ self.confirm((not sn.isSameNode(cn))
+ and (not cn.isSameNode(sn))
+ and cn.nodeName == sn.nodeName
+ and cn.publicId == sn.publicId
+ and cn.systemId == sn.systemId)
+
+ def testCloneDocumentTypeDeepNotOk(self):
+ doc = create_doc_with_doctype()
+ clone = doc.doctype.cloneNode(1)
+ self.confirm(clone is None, "testCloneDocumentTypeDeepNotOk")
+
+ def testCloneDocumentTypeShallowOk(self):
+ doctype = create_nonempty_doctype()
+ clone = doctype.cloneNode(0)
+ self.confirm(clone is not None
+ and clone.nodeName == doctype.nodeName
+ and clone.name == doctype.name
+ and clone.publicId == doctype.publicId
+ and clone.systemId == doctype.systemId
+ and len(clone.entities) == 0
+ and clone.entities.item(0) is None
+ and len(clone.notations) == 0
+ and clone.notations.item(0) is None
+ and len(clone.childNodes) == 0)
+
+ def testCloneDocumentTypeShallowNotOk(self):
+ doc = create_doc_with_doctype()
+ clone = doc.doctype.cloneNode(0)
+ self.confirm(clone is None, "testCloneDocumentTypeShallowNotOk")
+
+ def check_import_document(self, deep, testName):
+ doc1 = parseString("<doc/>")
+ doc2 = parseString("<doc/>")
+ self.assertRaises(xml.dom.NotSupportedErr, doc1.importNode, doc2, deep)
+
+ def testImportDocumentShallow(self):
+ self.check_import_document(0, "testImportDocumentShallow")
+
+ def testImportDocumentDeep(self):
+ self.check_import_document(1, "testImportDocumentDeep")
+
+ def testImportDocumentTypeShallow(self):
+ src = create_doc_with_doctype()
+ target = create_doc_without_doctype()
+ self.assertRaises(xml.dom.NotSupportedErr, target.importNode,
+ src.doctype, 0)
+
+ def testImportDocumentTypeDeep(self):
+ src = create_doc_with_doctype()
+ target = create_doc_without_doctype()
+ self.assertRaises(xml.dom.NotSupportedErr, target.importNode,
+ src.doctype, 1)
+
+ # Testing attribute clones uses a helper, and should always be deep,
+ # even if the argument to cloneNode is false.
+ def check_clone_attribute(self, deep, testName):
+ doc = parseString("<doc attr='value'/>")
+ attr = doc.documentElement.getAttributeNode("attr")
+ self.failIfEqual(attr, None)
+ clone = attr.cloneNode(deep)
+ self.confirm(not clone.isSameNode(attr))
+ self.confirm(not attr.isSameNode(clone))
+ self.confirm(clone.ownerElement is None,
+ testName + ": ownerElement should be None")
+ self.confirm(clone.ownerDocument.isSameNode(attr.ownerDocument),
+ testName + ": ownerDocument does not match")
+ self.confirm(clone.specified,
+ testName + ": cloned attribute must have specified == True")
+
+ def testCloneAttributeShallow(self):
+ self.check_clone_attribute(0, "testCloneAttributeShallow")
+
+ def testCloneAttributeDeep(self):
+ self.check_clone_attribute(1, "testCloneAttributeDeep")
+
+ def check_clone_pi(self, deep, testName):
+ doc = parseString("<?target data?><doc/>")
+ pi = doc.firstChild
+ self.assertEquals(pi.nodeType, Node.PROCESSING_INSTRUCTION_NODE)
+ clone = pi.cloneNode(deep)
+ self.confirm(clone.target == pi.target
+ and clone.data == pi.data)
+
+ def testClonePIShallow(self):
+ self.check_clone_pi(0, "testClonePIShallow")
+
+ def testClonePIDeep(self):
+ self.check_clone_pi(1, "testClonePIDeep")
+
+ def testNormalize(self):
+ doc = parseString("<doc/>")
+ root = doc.documentElement
+ root.appendChild(doc.createTextNode("first"))
+ root.appendChild(doc.createTextNode("second"))
+ self.confirm(len(root.childNodes) == 2
+ and root.childNodes.length == 2,
+ "testNormalize -- preparation")
+ doc.normalize()
+ self.confirm(len(root.childNodes) == 1
+ and root.childNodes.length == 1
+ and root.firstChild is root.lastChild
+ and root.firstChild.data == "firstsecond"
+ , "testNormalize -- result")
+ doc.unlink()
+
+ doc = parseString("<doc/>")
+ root = doc.documentElement
+ root.appendChild(doc.createTextNode(""))
+ doc.normalize()
+ self.confirm(len(root.childNodes) == 0
+ and root.childNodes.length == 0,
+ "testNormalize -- single empty node removed")
+ doc.unlink()
+
+ def testSiblings(self):
+ doc = parseString("<doc><?pi?>text?<elm/></doc>")
+ root = doc.documentElement
+ (pi, text, elm) = root.childNodes
+
+ self.confirm(pi.nextSibling is text and
+ pi.previousSibling is None and
+ text.nextSibling is elm and
+ text.previousSibling is pi and
+ elm.nextSibling is None and
+ elm.previousSibling is text, "testSiblings")
+
+ doc.unlink()
+
+ def testParents(self):
+ doc = parseString(
+ "<doc><elm1><elm2/><elm2><elm3/></elm2></elm1></doc>")
+ root = doc.documentElement
+ elm1 = root.childNodes[0]
+ (elm2a, elm2b) = elm1.childNodes
+ elm3 = elm2b.childNodes[0]
+
+ self.confirm(root.parentNode is doc and
+ elm1.parentNode is root and
+ elm2a.parentNode is elm1 and
+ elm2b.parentNode is elm1 and
+ elm3.parentNode is elm2b, "testParents")
+ doc.unlink()
+
+ def testNodeListItem(self):
+ doc = parseString("<doc><e/><e/></doc>")
+ children = doc.childNodes
+ docelem = children[0]
+ self.confirm(children[0] is children.item(0)
+ and children.item(1) is None
+ and docelem.childNodes.item(0) is docelem.childNodes[0]
+ and docelem.childNodes.item(1) is docelem.childNodes[1]
+ and docelem.childNodes.item(0).childNodes.item(0) is None,
+ "test NodeList.item()")
+ doc.unlink()
+
+ def testSAX2DOM(self):
+ from xml.dom import pulldom
+
+ sax2dom = pulldom.SAX2DOM()
+ sax2dom.startDocument()
+ sax2dom.startElement("doc", {})
+ sax2dom.characters("text")
+ sax2dom.startElement("subelm", {})
+ sax2dom.characters("text")
+ sax2dom.endElement("subelm")
+ sax2dom.characters("text")
+ sax2dom.endElement("doc")
+ sax2dom.endDocument()
+
+ doc = sax2dom.document
+ root = doc.documentElement
+ (text1, elm1, text2) = root.childNodes
+ text3 = elm1.childNodes[0]
+
+ self.confirm(text1.previousSibling is None and
+ text1.nextSibling is elm1 and
+ elm1.previousSibling is text1 and
+ elm1.nextSibling is text2 and
+ text2.previousSibling is elm1 and
+ text2.nextSibling is None and
+ text3.previousSibling is None and
+ text3.nextSibling is None, "testSAX2DOM - siblings")
+
+ self.confirm(root.parentNode is doc and
+ text1.parentNode is root and
+ elm1.parentNode is root and
+ text2.parentNode is root and
+ text3.parentNode is elm1, "testSAX2DOM - parents")
+ doc.unlink()
+
+ def testEncodings(self):
+ doc = parseString('<foo>&#x20ac;</foo>')
+ self.confirm(doc.toxml() == u'<?xml version="1.0" ?><foo>\u20ac</foo>'
+ and doc.toxml('utf-8') ==
+ '<?xml version="1.0" encoding="utf-8"?><foo>\xe2\x82\xac</foo>'
+ and doc.toxml('iso-8859-15') ==
+ '<?xml version="1.0" encoding="iso-8859-15"?><foo>\xa4</foo>',
+ "testEncodings - encoding EURO SIGN")
+
+ # Verify that character decoding errors throw exceptions instead
+ # of crashing
+ self.assertRaises(UnicodeDecodeError, parseString,
+ '<fran\xe7ais>Comment \xe7a va ? Tr\xe8s bien ?</fran\xe7ais>')
+
+ doc.unlink()
+
+ class UserDataHandler:
+ called = 0
+ def handle(self, operation, key, data, src, dst):
+ dst.setUserData(key, data + 1, self)
+ src.setUserData(key, None, None)
+ self.called = 1
+
+ def testUserData(self):
+ dom = Document()
+ n = dom.createElement('e')
+ self.confirm(n.getUserData("foo") is None)
+ n.setUserData("foo", None, None)
+ self.confirm(n.getUserData("foo") is None)
+ n.setUserData("foo", 12, 12)
+ n.setUserData("bar", 13, 13)
+ self.confirm(n.getUserData("foo") == 12)
+ self.confirm(n.getUserData("bar") == 13)
+ n.setUserData("foo", None, None)
+ self.confirm(n.getUserData("foo") is None)
+ self.confirm(n.getUserData("bar") == 13)
+
+ handler = self.UserDataHandler()
+ n.setUserData("bar", 12, handler)
+ c = n.cloneNode(1)
+ self.confirm(handler.called
+ and n.getUserData("bar") is None
+ and c.getUserData("bar") == 13)
+ n.unlink()
+ c.unlink()
+ dom.unlink()
+
+ def checkRenameNodeSharedConstraints(self, doc, node):
+ # Make sure illegal NS usage is detected:
+ self.assertRaises(xml.dom.NamespaceErr, doc.renameNode, node,
+ "http://xml.python.org/ns", "xmlns:foo")
+ doc2 = parseString("<doc/>")
+ self.assertRaises(xml.dom.WrongDocumentErr, doc2.renameNode, node,
+ xml.dom.EMPTY_NAMESPACE, "foo")
+
+ def testRenameAttribute(self):
+ doc = parseString("<doc a='v'/>")
+ elem = doc.documentElement
+ attrmap = elem.attributes
+ attr = elem.attributes['a']
+
+ # Simple renaming
+ attr = doc.renameNode(attr, xml.dom.EMPTY_NAMESPACE, "b")
+ self.confirm(attr.name == "b"
+ and attr.nodeName == "b"
+ and attr.localName is None
+ and attr.namespaceURI == xml.dom.EMPTY_NAMESPACE
+ and attr.prefix is None
+ and attr.value == "v"
+ and elem.getAttributeNode("a") is None
+ and elem.getAttributeNode("b").isSameNode(attr)
+ and attrmap["b"].isSameNode(attr)
+ and attr.ownerDocument.isSameNode(doc)
+ and attr.ownerElement.isSameNode(elem))
+
+ # Rename to have a namespace, no prefix
+ attr = doc.renameNode(attr, "http://xml.python.org/ns", "c")
+ self.confirm(attr.name == "c"
+ and attr.nodeName == "c"
+ and attr.localName == "c"
+ and attr.namespaceURI == "http://xml.python.org/ns"
+ and attr.prefix is None
+ and attr.value == "v"
+ and elem.getAttributeNode("a") is None
+ and elem.getAttributeNode("b") is None
+ and elem.getAttributeNode("c").isSameNode(attr)
+ and elem.getAttributeNodeNS(
+ "http://xml.python.org/ns", "c").isSameNode(attr)
+ and attrmap["c"].isSameNode(attr)
+ and attrmap[("http://xml.python.org/ns", "c")].isSameNode(attr))
+
+ # Rename to have a namespace, with prefix
+ attr = doc.renameNode(attr, "http://xml.python.org/ns2", "p:d")
+ self.confirm(attr.name == "p:d"
+ and attr.nodeName == "p:d"
+ and attr.localName == "d"
+ and attr.namespaceURI == "http://xml.python.org/ns2"
+ and attr.prefix == "p"
+ and attr.value == "v"
+ and elem.getAttributeNode("a") is None
+ and elem.getAttributeNode("b") is None
+ and elem.getAttributeNode("c") is None
+ and elem.getAttributeNodeNS(
+ "http://xml.python.org/ns", "c") is None
+ and elem.getAttributeNode("p:d").isSameNode(attr)
+ and elem.getAttributeNodeNS(
+ "http://xml.python.org/ns2", "d").isSameNode(attr)
+ and attrmap["p:d"].isSameNode(attr)
+ and attrmap[("http://xml.python.org/ns2", "d")].isSameNode(attr))
+
+ # Rename back to a simple non-NS node
+ attr = doc.renameNode(attr, xml.dom.EMPTY_NAMESPACE, "e")
+ self.confirm(attr.name == "e"
+ and attr.nodeName == "e"
+ and attr.localName is None
+ and attr.namespaceURI == xml.dom.EMPTY_NAMESPACE
+ and attr.prefix is None
+ and attr.value == "v"
+ and elem.getAttributeNode("a") is None
+ and elem.getAttributeNode("b") is None
+ and elem.getAttributeNode("c") is None
+ and elem.getAttributeNode("p:d") is None
+ and elem.getAttributeNodeNS(
+ "http://xml.python.org/ns", "c") is None
+ and elem.getAttributeNode("e").isSameNode(attr)
+ and attrmap["e"].isSameNode(attr))
+
+ self.assertRaises(xml.dom.NamespaceErr, doc.renameNode, attr,
+ "http://xml.python.org/ns", "xmlns")
+ self.checkRenameNodeSharedConstraints(doc, attr)
+ doc.unlink()
+
+ def testRenameElement(self):
+ doc = parseString("<doc/>")
elem = doc.documentElement
- text1 = elem.firstChild
- text2 = elem.lastChild
- splitter = text1.nextSibling
- elem.insertBefore(doc.createTextNode("b"), splitter)
- elem.insertBefore(doc.createCDATASection("c"), text1)
- return doc, elem, text1, splitter, text2
-
- doc, elem, text1, splitter, text2 = setup()
- text = text1.replaceWholeText("new content")
- checkWholeText(text, "new content")
- checkWholeText(text2, "d")
- confirm(len(elem.childNodes) == 3)
-
- doc, elem, text1, splitter, text2 = setup()
- text = text2.replaceWholeText("new content")
- checkWholeText(text, "new content")
- checkWholeText(text1, "cab")
- confirm(len(elem.childNodes) == 5)
-
- doc, elem, text1, splitter, text2 = setup()
- text = text1.replaceWholeText("")
- checkWholeText(text2, "d")
- confirm(text is None
- and len(elem.childNodes) == 2)
-
-def testSchemaType():
- doc = parseString(
- "<!DOCTYPE doc [\n"
- " <!ENTITY e1 SYSTEM 'http://xml.python.org/e1'>\n"
- " <!ENTITY e2 SYSTEM 'http://xml.python.org/e2'>\n"
- " <!ATTLIST doc id ID #IMPLIED \n"
- " ref IDREF #IMPLIED \n"
- " refs IDREFS #IMPLIED \n"
- " enum (a|b) #IMPLIED \n"
- " ent ENTITY #IMPLIED \n"
- " ents ENTITIES #IMPLIED \n"
- " nm NMTOKEN #IMPLIED \n"
- " nms NMTOKENS #IMPLIED \n"
- " text CDATA #IMPLIED \n"
- " >\n"
- "]><doc id='name' notid='name' text='splat!' enum='b'"
- " ref='name' refs='name name' ent='e1' ents='e1 e2'"
- " nm='123' nms='123 abc' />")
- elem = doc.documentElement
- # We don't want to rely on any specific loader at this point, so
- # just make sure we can get to all the names, and that the
- # DTD-based namespace is right. The names can vary by loader
- # since each supports a different level of DTD information.
- t = elem.schemaType
- confirm(t.name is None
- and t.namespace == xml.dom.EMPTY_NAMESPACE)
- names = "id notid text enum ref refs ent ents nm nms".split()
- for name in names:
- a = elem.getAttributeNode(name)
- t = a.schemaType
- confirm(hasattr(t, "name")
- and t.namespace == xml.dom.EMPTY_NAMESPACE)
-def testSetIdAttribute():
- doc = parseString("<doc a1='v' a2='w'/>")
- e = doc.documentElement
- a1 = e.getAttributeNode("a1")
- a2 = e.getAttributeNode("a2")
- confirm(doc.getElementById("v") is None
- and not a1.isId
- and not a2.isId)
- e.setIdAttribute("a1")
- confirm(e.isSameNode(doc.getElementById("v"))
- and a1.isId
- and not a2.isId)
- e.setIdAttribute("a2")
- confirm(e.isSameNode(doc.getElementById("v"))
- and e.isSameNode(doc.getElementById("w"))
- and a1.isId
- and a2.isId)
- # replace the a1 node; the new node should *not* be an ID
- a3 = doc.createAttribute("a1")
- a3.value = "v"
- e.setAttributeNode(a3)
- confirm(doc.getElementById("v") is None
- and e.isSameNode(doc.getElementById("w"))
- and not a1.isId
- and a2.isId
- and not a3.isId)
- # renaming an attribute should not affect its ID-ness:
- doc.renameNode(a2, xml.dom.EMPTY_NAMESPACE, "an")
- confirm(e.isSameNode(doc.getElementById("w"))
- and a2.isId)
-
-def testSetIdAttributeNS():
- NS1 = "http://xml.python.org/ns1"
- NS2 = "http://xml.python.org/ns2"
- doc = parseString("<doc"
- " xmlns:ns1='" + NS1 + "'"
- " xmlns:ns2='" + NS2 + "'"
- " ns1:a1='v' ns2:a2='w'/>")
- e = doc.documentElement
- a1 = e.getAttributeNodeNS(NS1, "a1")
- a2 = e.getAttributeNodeNS(NS2, "a2")
- confirm(doc.getElementById("v") is None
- and not a1.isId
- and not a2.isId)
- e.setIdAttributeNS(NS1, "a1")
- confirm(e.isSameNode(doc.getElementById("v"))
- and a1.isId
- and not a2.isId)
- e.setIdAttributeNS(NS2, "a2")
- confirm(e.isSameNode(doc.getElementById("v"))
- and e.isSameNode(doc.getElementById("w"))
- and a1.isId
- and a2.isId)
- # replace the a1 node; the new node should *not* be an ID
- a3 = doc.createAttributeNS(NS1, "a1")
- a3.value = "v"
- e.setAttributeNode(a3)
- confirm(e.isSameNode(doc.getElementById("w")))
- confirm(not a1.isId)
- confirm(a2.isId)
- confirm(not a3.isId)
- confirm(doc.getElementById("v") is None)
- # renaming an attribute should not affect its ID-ness:
- doc.renameNode(a2, xml.dom.EMPTY_NAMESPACE, "an")
- confirm(e.isSameNode(doc.getElementById("w"))
- and a2.isId)
-
-def testSetIdAttributeNode():
- NS1 = "http://xml.python.org/ns1"
- NS2 = "http://xml.python.org/ns2"
- doc = parseString("<doc"
- " xmlns:ns1='" + NS1 + "'"
- " xmlns:ns2='" + NS2 + "'"
- " ns1:a1='v' ns2:a2='w'/>")
- e = doc.documentElement
- a1 = e.getAttributeNodeNS(NS1, "a1")
- a2 = e.getAttributeNodeNS(NS2, "a2")
- confirm(doc.getElementById("v") is None
- and not a1.isId
- and not a2.isId)
- e.setIdAttributeNode(a1)
- confirm(e.isSameNode(doc.getElementById("v"))
- and a1.isId
- and not a2.isId)
- e.setIdAttributeNode(a2)
- confirm(e.isSameNode(doc.getElementById("v"))
- and e.isSameNode(doc.getElementById("w"))
- and a1.isId
- and a2.isId)
- # replace the a1 node; the new node should *not* be an ID
- a3 = doc.createAttributeNS(NS1, "a1")
- a3.value = "v"
- e.setAttributeNode(a3)
- confirm(e.isSameNode(doc.getElementById("w")))
- confirm(not a1.isId)
- confirm(a2.isId)
- confirm(not a3.isId)
- confirm(doc.getElementById("v") is None)
- # renaming an attribute should not affect its ID-ness:
- doc.renameNode(a2, xml.dom.EMPTY_NAMESPACE, "an")
- confirm(e.isSameNode(doc.getElementById("w"))
- and a2.isId)
-
-def testPickledDocument():
- doc = parseString("<?xml version='1.0' encoding='us-ascii'?>\n"
- "<!DOCTYPE doc PUBLIC 'http://xml.python.org/public'"
- " 'http://xml.python.org/system' [\n"
- " <!ELEMENT e EMPTY>\n"
- " <!ENTITY ent SYSTEM 'http://xml.python.org/entity'>\n"
- "]><doc attr='value'> text\n"
- "<?pi sample?> <!-- comment --> <e/> </doc>")
- s = pickle.dumps(doc)
- doc2 = pickle.loads(s)
- stack = [(doc, doc2)]
- while stack:
- n1, n2 = stack.pop()
- confirm(n1.nodeType == n2.nodeType
- and len(n1.childNodes) == len(n2.childNodes)
- and n1.nodeName == n2.nodeName
- and not n1.isSameNode(n2)
- and not n2.isSameNode(n1))
- if n1.nodeType == Node.DOCUMENT_TYPE_NODE:
- len(n1.entities)
- len(n2.entities)
- len(n1.notations)
- len(n2.notations)
- confirm(len(n1.entities) == len(n2.entities)
- and len(n1.notations) == len(n2.notations))
- for i in range(len(n1.notations)):
- no1 = n1.notations.item(i)
- no2 = n1.notations.item(i)
- confirm(no1.name == no2.name
- and no1.publicId == no2.publicId
- and no1.systemId == no2.systemId)
- statck.append((no1, no2))
- for i in range(len(n1.entities)):
- e1 = n1.entities.item(i)
- e2 = n2.entities.item(i)
- confirm(e1.notationName == e2.notationName
- and e1.publicId == e2.publicId
- and e1.systemId == e2.systemId)
- stack.append((e1, e2))
- if n1.nodeType != Node.DOCUMENT_NODE:
- confirm(n1.ownerDocument.isSameNode(doc)
- and n2.ownerDocument.isSameNode(doc2))
- for i in range(len(n1.childNodes)):
- stack.append((n1.childNodes[i], n2.childNodes[i]))
-
-
-# --- MAIN PROGRAM
-
-names = sorted(globals().keys())
-
-failed = []
-
-try:
- Node.allnodes
-except AttributeError:
- # We don't actually have the minidom from the standard library,
- # but are picking up the PyXML version from site-packages.
- def check_allnodes():
- pass
-else:
- def check_allnodes():
- confirm(len(Node.allnodes) == 0,
- "assertion: len(Node.allnodes) == 0")
- if len(Node.allnodes):
- print("Garbage left over:")
- if verbose:
- print(Node.allnodes.items()[0:10])
- else:
- # Don't print specific nodes if repeatable results
- # are needed
- print(len(Node.allnodes))
- Node.allnodes = {}
-
-for name in names:
- if name.startswith("test"):
- func = globals()[name]
- try:
- func()
- check_allnodes()
- except:
- failed.append(name)
- print("Test Failed: ", name)
- sys.stdout.flush()
- traceback.print_exception(*sys.exc_info())
- print(repr(sys.exc_info()[1]))
- Node.allnodes = {}
+ # Simple renaming
+ elem = doc.renameNode(elem, xml.dom.EMPTY_NAMESPACE, "a")
+ self.confirm(elem.tagName == "a"
+ and elem.nodeName == "a"
+ and elem.localName is None
+ and elem.namespaceURI == xml.dom.EMPTY_NAMESPACE
+ and elem.prefix is None
+ and elem.ownerDocument.isSameNode(doc))
+
+ # Rename to have a namespace, no prefix
+ elem = doc.renameNode(elem, "http://xml.python.org/ns", "b")
+ self.confirm(elem.tagName == "b"
+ and elem.nodeName == "b"
+ and elem.localName == "b"
+ and elem.namespaceURI == "http://xml.python.org/ns"
+ and elem.prefix is None
+ and elem.ownerDocument.isSameNode(doc))
+
+ # Rename to have a namespace, with prefix
+ elem = doc.renameNode(elem, "http://xml.python.org/ns2", "p:c")
+ self.confirm(elem.tagName == "p:c"
+ and elem.nodeName == "p:c"
+ and elem.localName == "c"
+ and elem.namespaceURI == "http://xml.python.org/ns2"
+ and elem.prefix == "p"
+ and elem.ownerDocument.isSameNode(doc))
+
+ # Rename back to a simple non-NS node
+ elem = doc.renameNode(elem, xml.dom.EMPTY_NAMESPACE, "d")
+ self.confirm(elem.tagName == "d"
+ and elem.nodeName == "d"
+ and elem.localName is None
+ and elem.namespaceURI == xml.dom.EMPTY_NAMESPACE
+ and elem.prefix is None
+ and elem.ownerDocument.isSameNode(doc))
+
+ self.checkRenameNodeSharedConstraints(doc, elem)
+ doc.unlink()
+
+ def testRenameOther(self):
+ # We have to create a comment node explicitly since not all DOM
+ # builders used with minidom add comments to the DOM.
+ doc = xml.dom.minidom.getDOMImplementation().createDocument(
+ xml.dom.EMPTY_NAMESPACE, "e", None)
+ node = doc.createComment("comment")
+ self.assertRaises(xml.dom.NotSupportedErr, doc.renameNode, node,
+ xml.dom.EMPTY_NAMESPACE, "foo")
+ doc.unlink()
+
+ def testWholeText(self):
+ doc = parseString("<doc>a</doc>")
+ elem = doc.documentElement
+ text = elem.childNodes[0]
+ self.assertEquals(text.nodeType, Node.TEXT_NODE)
+
+ self.checkWholeText(text, "a")
+ elem.appendChild(doc.createTextNode("b"))
+ self.checkWholeText(text, "ab")
+ elem.insertBefore(doc.createCDATASection("c"), text)
+ self.checkWholeText(text, "cab")
+
+ # make sure we don't cross other nodes
+ splitter = doc.createComment("comment")
+ elem.appendChild(splitter)
+ text2 = doc.createTextNode("d")
+ elem.appendChild(text2)
+ self.checkWholeText(text, "cab")
+ self.checkWholeText(text2, "d")
+
+ x = doc.createElement("x")
+ elem.replaceChild(x, splitter)
+ splitter = x
+ self.checkWholeText(text, "cab")
+ self.checkWholeText(text2, "d")
+
+ x = doc.createProcessingInstruction("y", "z")
+ elem.replaceChild(x, splitter)
+ splitter = x
+ self.checkWholeText(text, "cab")
+ self.checkWholeText(text2, "d")
+
+ elem.removeChild(splitter)
+ self.checkWholeText(text, "cabd")
+ self.checkWholeText(text2, "cabd")
+
+ def testPatch1094164(self):
+ doc = parseString("<doc><e/></doc>")
+ elem = doc.documentElement
+ e = elem.firstChild
+ self.confirm(e.parentNode is elem, "Before replaceChild()")
+ # Check that replacing a child with itself leaves the tree unchanged
+ elem.replaceChild(e, e)
+ self.confirm(e.parentNode is elem, "After replaceChild()")
+
+ def testReplaceWholeText(self):
+ def setup():
+ doc = parseString("<doc>a<e/>d</doc>")
+ elem = doc.documentElement
+ text1 = elem.firstChild
+ text2 = elem.lastChild
+ splitter = text1.nextSibling
+ elem.insertBefore(doc.createTextNode("b"), splitter)
+ elem.insertBefore(doc.createCDATASection("c"), text1)
+ return doc, elem, text1, splitter, text2
+
+ doc, elem, text1, splitter, text2 = setup()
+ text = text1.replaceWholeText("new content")
+ self.checkWholeText(text, "new content")
+ self.checkWholeText(text2, "d")
+ self.confirm(len(elem.childNodes) == 3)
+
+ doc, elem, text1, splitter, text2 = setup()
+ text = text2.replaceWholeText("new content")
+ self.checkWholeText(text, "new content")
+ self.checkWholeText(text1, "cab")
+ self.confirm(len(elem.childNodes) == 5)
+
+ doc, elem, text1, splitter, text2 = setup()
+ text = text1.replaceWholeText("")
+ self.checkWholeText(text2, "d")
+ self.confirm(text is None
+ and len(elem.childNodes) == 2)
+
+ def testSchemaType(self):
+ doc = parseString(
+ "<!DOCTYPE doc [\n"
+ " <!ENTITY e1 SYSTEM 'http://xml.python.org/e1'>\n"
+ " <!ENTITY e2 SYSTEM 'http://xml.python.org/e2'>\n"
+ " <!ATTLIST doc id ID #IMPLIED \n"
+ " ref IDREF #IMPLIED \n"
+ " refs IDREFS #IMPLIED \n"
+ " enum (a|b) #IMPLIED \n"
+ " ent ENTITY #IMPLIED \n"
+ " ents ENTITIES #IMPLIED \n"
+ " nm NMTOKEN #IMPLIED \n"
+ " nms NMTOKENS #IMPLIED \n"
+ " text CDATA #IMPLIED \n"
+ " >\n"
+ "]><doc id='name' notid='name' text='splat!' enum='b'"
+ " ref='name' refs='name name' ent='e1' ents='e1 e2'"
+ " nm='123' nms='123 abc' />")
+ elem = doc.documentElement
+ # We don't want to rely on any specific loader at this point, so
+ # just make sure we can get to all the names, and that the
+ # DTD-based namespace is right. The names can vary by loader
+ # since each supports a different level of DTD information.
+ t = elem.schemaType
+ self.confirm(t.name is None
+ and t.namespace == xml.dom.EMPTY_NAMESPACE)
+ names = "id notid text enum ref refs ent ents nm nms".split()
+ for name in names:
+ a = elem.getAttributeNode(name)
+ t = a.schemaType
+ self.confirm(hasattr(t, "name")
+ and t.namespace == xml.dom.EMPTY_NAMESPACE)
+
+ def testSetIdAttribute(self):
+ doc = parseString("<doc a1='v' a2='w'/>")
+ e = doc.documentElement
+ a1 = e.getAttributeNode("a1")
+ a2 = e.getAttributeNode("a2")
+ self.confirm(doc.getElementById("v") is None
+ and not a1.isId
+ and not a2.isId)
+ e.setIdAttribute("a1")
+ self.confirm(e.isSameNode(doc.getElementById("v"))
+ and a1.isId
+ and not a2.isId)
+ e.setIdAttribute("a2")
+ self.confirm(e.isSameNode(doc.getElementById("v"))
+ and e.isSameNode(doc.getElementById("w"))
+ and a1.isId
+ and a2.isId)
+ # replace the a1 node; the new node should *not* be an ID
+ a3 = doc.createAttribute("a1")
+ a3.value = "v"
+ e.setAttributeNode(a3)
+ self.confirm(doc.getElementById("v") is None
+ and e.isSameNode(doc.getElementById("w"))
+ and not a1.isId
+ and a2.isId
+ and not a3.isId)
+ # renaming an attribute should not affect its ID-ness:
+ doc.renameNode(a2, xml.dom.EMPTY_NAMESPACE, "an")
+ self.confirm(e.isSameNode(doc.getElementById("w"))
+ and a2.isId)
+
+ def testSetIdAttributeNS(self):
+ NS1 = "http://xml.python.org/ns1"
+ NS2 = "http://xml.python.org/ns2"
+ doc = parseString("<doc"
+ " xmlns:ns1='" + NS1 + "'"
+ " xmlns:ns2='" + NS2 + "'"
+ " ns1:a1='v' ns2:a2='w'/>")
+ e = doc.documentElement
+ a1 = e.getAttributeNodeNS(NS1, "a1")
+ a2 = e.getAttributeNodeNS(NS2, "a2")
+ self.confirm(doc.getElementById("v") is None
+ and not a1.isId
+ and not a2.isId)
+ e.setIdAttributeNS(NS1, "a1")
+ self.confirm(e.isSameNode(doc.getElementById("v"))
+ and a1.isId
+ and not a2.isId)
+ e.setIdAttributeNS(NS2, "a2")
+ self.confirm(e.isSameNode(doc.getElementById("v"))
+ and e.isSameNode(doc.getElementById("w"))
+ and a1.isId
+ and a2.isId)
+ # replace the a1 node; the new node should *not* be an ID
+ a3 = doc.createAttributeNS(NS1, "a1")
+ a3.value = "v"
+ e.setAttributeNode(a3)
+ self.confirm(e.isSameNode(doc.getElementById("w")))
+ self.confirm(not a1.isId)
+ self.confirm(a2.isId)
+ self.confirm(not a3.isId)
+ self.confirm(doc.getElementById("v") is None)
+ # renaming an attribute should not affect its ID-ness:
+ doc.renameNode(a2, xml.dom.EMPTY_NAMESPACE, "an")
+ self.confirm(e.isSameNode(doc.getElementById("w"))
+ and a2.isId)
+
+ def testSetIdAttributeNode(self):
+ NS1 = "http://xml.python.org/ns1"
+ NS2 = "http://xml.python.org/ns2"
+ doc = parseString("<doc"
+ " xmlns:ns1='" + NS1 + "'"
+ " xmlns:ns2='" + NS2 + "'"
+ " ns1:a1='v' ns2:a2='w'/>")
+ e = doc.documentElement
+ a1 = e.getAttributeNodeNS(NS1, "a1")
+ a2 = e.getAttributeNodeNS(NS2, "a2")
+ self.confirm(doc.getElementById("v") is None
+ and not a1.isId
+ and not a2.isId)
+ e.setIdAttributeNode(a1)
+ self.confirm(e.isSameNode(doc.getElementById("v"))
+ and a1.isId
+ and not a2.isId)
+ e.setIdAttributeNode(a2)
+ self.confirm(e.isSameNode(doc.getElementById("v"))
+ and e.isSameNode(doc.getElementById("w"))
+ and a1.isId
+ and a2.isId)
+ # replace the a1 node; the new node should *not* be an ID
+ a3 = doc.createAttributeNS(NS1, "a1")
+ a3.value = "v"
+ e.setAttributeNode(a3)
+ self.confirm(e.isSameNode(doc.getElementById("w")))
+ self.confirm(not a1.isId)
+ self.confirm(a2.isId)
+ self.confirm(not a3.isId)
+ self.confirm(doc.getElementById("v") is None)
+ # renaming an attribute should not affect its ID-ness:
+ doc.renameNode(a2, xml.dom.EMPTY_NAMESPACE, "an")
+ self.confirm(e.isSameNode(doc.getElementById("w"))
+ and a2.isId)
+
+ def testPickledDocument(self):
+ doc = parseString("<?xml version='1.0' encoding='us-ascii'?>\n"
+ "<!DOCTYPE doc PUBLIC 'http://xml.python.org/public'"
+ " 'http://xml.python.org/system' [\n"
+ " <!ELEMENT e EMPTY>\n"
+ " <!ENTITY ent SYSTEM 'http://xml.python.org/entity'>\n"
+ "]><doc attr='value'> text\n"
+ "<?pi sample?> <!-- comment --> <e/> </doc>")
+ s = pickle.dumps(doc)
+ doc2 = pickle.loads(s)
+ stack = [(doc, doc2)]
+ while stack:
+ n1, n2 = stack.pop()
+ self.confirm(n1.nodeType == n2.nodeType
+ and len(n1.childNodes) == len(n2.childNodes)
+ and n1.nodeName == n2.nodeName
+ and not n1.isSameNode(n2)
+ and not n2.isSameNode(n1))
+ if n1.nodeType == Node.DOCUMENT_TYPE_NODE:
+ len(n1.entities)
+ len(n2.entities)
+ len(n1.notations)
+ len(n2.notations)
+ self.confirm(len(n1.entities) == len(n2.entities)
+ and len(n1.notations) == len(n2.notations))
+ for i in range(len(n1.notations)):
+ no1 = n1.notations.item(i)
+ no2 = n1.notations.item(i)
+ self.confirm(no1.name == no2.name
+ and no1.publicId == no2.publicId
+ and no1.systemId == no2.systemId)
+ statck.append((no1, no2))
+ for i in range(len(n1.entities)):
+ e1 = n1.entities.item(i)
+ e2 = n2.entities.item(i)
+ self.confirm(e1.notationName == e2.notationName
+ and e1.publicId == e2.publicId
+ and e1.systemId == e2.systemId)
+ stack.append((e1, e2))
+ if n1.nodeType != Node.DOCUMENT_NODE:
+ self.confirm(n1.ownerDocument.isSameNode(doc)
+ and n2.ownerDocument.isSameNode(doc2))
+ for i in range(len(n1.childNodes)):
+ stack.append((n1.childNodes[i], n2.childNodes[i]))
+
+def test_main():
+ run_unittest(MinidomTest)
-if failed:
- print("\n\n\n**** Check for failures in these tests:")
- for name in failed:
- print(" " + name)
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_module.py b/Lib/test/test_module.py
index 7911a0e..cc8b192 100644
--- a/Lib/test/test_module.py
+++ b/Lib/test/test_module.py
@@ -1,48 +1,61 @@
# Test the module type
-
-from test.test_support import verify, vereq, verbose, TestFailed
+import unittest
+from test.test_support import verbose, run_unittest
import sys
-module = type(sys)
-
-# An uninitialized module has no __dict__ or __name__, and __doc__ is None
-foo = module.__new__(module)
-verify(foo.__dict__ is None)
-try:
- s = foo.__name__
-except AttributeError:
- pass
-else:
- raise TestFailed, "__name__ = %s" % repr(s)
-vereq(foo.__doc__, module.__doc__)
-
-# Regularly initialized module, no docstring
-foo = module("foo")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, None)
-vereq(foo.__dict__, {"__name__": "foo", "__doc__": None})
-
-# ASCII docstring
-foo = module("foo", "foodoc")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, "foodoc")
-vereq(foo.__dict__, {"__name__": "foo", "__doc__": "foodoc"})
-
-# Unicode docstring
-foo = module("foo", u"foodoc\u1234")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, u"foodoc\u1234")
-vereq(foo.__dict__, {"__name__": "foo", "__doc__": u"foodoc\u1234"})
-
-# Reinitialization should not replace the __dict__
-foo.bar = 42
-d = foo.__dict__
-foo.__init__("foo", "foodoc")
-vereq(foo.__name__, "foo")
-vereq(foo.__doc__, "foodoc")
-vereq(foo.bar, 42)
-vereq(foo.__dict__, {"__name__": "foo", "__doc__": "foodoc", "bar": 42})
-verify(foo.__dict__ is d)
-
-if verbose:
- print("All OK")
+ModuleType = type(sys)
+
+class ModuleTests(unittest.TestCase):
+ def test_uninitialized(self):
+ # An uninitialized module has no __dict__ or __name__,
+ # and __doc__ is None
+ foo = ModuleType.__new__(ModuleType)
+ self.failUnless(foo.__dict__ is None)
+ try:
+ s = foo.__name__
+ self.fail("__name__ = %s" % repr(s))
+ except AttributeError:
+ pass
+ self.assertEqual(foo.__doc__, ModuleType.__doc__)
+
+ def test_no_docstring(self):
+ # Regularly initialized module, no docstring
+ foo = ModuleType("foo")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, None)
+ self.assertEqual(foo.__dict__, {"__name__": "foo", "__doc__": None})
+
+ def test_ascii_docstring(self):
+ # ASCII docstring
+ foo = ModuleType("foo", "foodoc")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, "foodoc")
+ self.assertEqual(foo.__dict__,
+ {"__name__": "foo", "__doc__": "foodoc"})
+
+ def test_unicode_docstring(self):
+ # Unicode docstring
+ foo = ModuleType("foo", u"foodoc\u1234")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, u"foodoc\u1234")
+ self.assertEqual(foo.__dict__,
+ {"__name__": "foo", "__doc__": u"foodoc\u1234"})
+
+ def test_reinit(self):
+ # Reinitialization should not replace the __dict__
+ foo = ModuleType("foo", u"foodoc\u1234")
+ foo.bar = 42
+ d = foo.__dict__
+ foo.__init__("foo", "foodoc")
+ self.assertEqual(foo.__name__, "foo")
+ self.assertEqual(foo.__doc__, "foodoc")
+ self.assertEqual(foo.bar, 42)
+ self.assertEqual(foo.__dict__,
+ {"__name__": "foo", "__doc__": "foodoc", "bar": 42})
+ self.failUnless(foo.__dict__ is d)
+
+def test_main():
+ run_unittest(ModuleTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_multibytecodec.py b/Lib/test/test_multibytecodec.py
index 2ac7061..c5615a8 100644
--- a/Lib/test/test_multibytecodec.py
+++ b/Lib/test/test_multibytecodec.py
@@ -219,13 +219,7 @@ class Test_ISO2022(unittest.TestCase):
myunichr(x).encode('iso_2022_jp', 'ignore')
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(Test_MultibyteCodec))
- suite.addTest(unittest.makeSuite(Test_IncrementalEncoder))
- suite.addTest(unittest.makeSuite(Test_IncrementalDecoder))
- suite.addTest(unittest.makeSuite(Test_StreamWriter))
- suite.addTest(unittest.makeSuite(Test_ISO2022))
- test_support.run_suite(suite)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_normalization.py b/Lib/test/test_normalization.py
index d890067..b571bdc 100644
--- a/Lib/test/test_normalization.py
+++ b/Lib/test/test_normalization.py
@@ -1,5 +1,6 @@
-from test.test_support import (verbose, TestFailed, TestSkipped, verify,
- open_urlresource)
+from test.test_support import run_unittest, open_urlresource
+import unittest
+
import sys
import os
from unicodedata import normalize
@@ -29,60 +30,67 @@ def unistr(data):
raise RangeError
return u"".join([unichr(x) for x in data])
-def test_main():
- part1_data = {}
- for line in open_urlresource(TESTDATAURL):
- if '#' in line:
- line = line.split('#')[0]
- line = line.strip()
- if not line:
- continue
- if line.startswith("@Part"):
- part = line.split()[0]
- continue
- if part == "@Part3":
- # XXX we don't support PRI #29 yet, so skip these tests for now
- continue
- try:
- c1,c2,c3,c4,c5 = [unistr(x) for x in line.split(';')[:-1]]
- except RangeError:
- # Skip unsupported characters;
- # try atleast adding c1 if we are in part1
+class NormalizationTest(unittest.TestCase):
+ def test_main(self):
+ part1_data = {}
+ for line in open_urlresource(TESTDATAURL):
+ if '#' in line:
+ line = line.split('#')[0]
+ line = line.strip()
+ if not line:
+ continue
+ if line.startswith("@Part"):
+ part = line.split()[0]
+ continue
+ if part == "@Part3":
+ # XXX we don't support PRI #29 yet, so skip these tests for now
+ continue
+ try:
+ c1,c2,c3,c4,c5 = [unistr(x) for x in line.split(';')[:-1]]
+ except RangeError:
+ # Skip unsupported characters;
+ # try atleast adding c1 if we are in part1
+ if part == "@Part1":
+ try:
+ c1 = unistr(line.split(';')[0])
+ except RangeError:
+ pass
+ else:
+ part1_data[c1] = 1
+ continue
+
+ # Perform tests
+ self.failUnless(c2 == NFC(c1) == NFC(c2) == NFC(c3), line)
+ self.failUnless(c4 == NFC(c4) == NFC(c5), line)
+ self.failUnless(c3 == NFD(c1) == NFD(c2) == NFD(c3), line)
+ self.failUnless(c5 == NFD(c4) == NFD(c5), line)
+ self.failUnless(c4 == NFKC(c1) == NFKC(c2) == \
+ NFKC(c3) == NFKC(c4) == NFKC(c5),
+ line)
+ self.failUnless(c5 == NFKD(c1) == NFKD(c2) == \
+ NFKD(c3) == NFKD(c4) == NFKD(c5),
+ line)
+
+ # Record part 1 data
if part == "@Part1":
- try:
- c1=unistr(line.split(';')[0])
- except RangeError:
- pass
- else:
- part1_data[c1] = 1
- continue
-
- if verbose:
- print(line)
-
- # Perform tests
- verify(c2 == NFC(c1) == NFC(c2) == NFC(c3), line)
- verify(c4 == NFC(c4) == NFC(c5), line)
- verify(c3 == NFD(c1) == NFD(c2) == NFD(c3), line)
- verify(c5 == NFD(c4) == NFD(c5), line)
- verify(c4 == NFKC(c1) == NFKC(c2) == NFKC(c3) == NFKC(c4) == NFKC(c5),
- line)
- verify(c5 == NFKD(c1) == NFKD(c2) == NFKD(c3) == NFKD(c4) == NFKD(c5),
- line)
-
- # Record part 1 data
- if part == "@Part1":
- part1_data[c1] = 1
-
- # Perform tests for all other data
- for c in range(sys.maxunicode+1):
- X = unichr(c)
- if X in part1_data:
- continue
- assert X == NFC(X) == NFD(X) == NFKC(X) == NFKD(X), c
-
- # Check for bug 834676
- normalize('NFC',u'\ud55c\uae00')
+ part1_data[c1] = 1
+
+ # Perform tests for all other data
+ for c in range(sys.maxunicode+1):
+ X = unichr(c)
+ if X in part1_data:
+ continue
+ self.failUnless(X == NFC(X) == NFD(X) == NFKC(X) == NFKD(X), c)
+
+ def test_bug_834676(self):
+ # Check for bug 834676
+ normalize('NFC', u'\ud55c\uae00')
+
+
+def test_main():
+ # Hit the exception early
+ open_urlresource(TESTDATAURL)
+ run_unittest(NormalizationTest)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py
index 939886d..c6dbf2e 100644
--- a/Lib/test/test_ntpath.py
+++ b/Lib/test/test_ntpath.py
@@ -18,13 +18,14 @@ def tester(fn, wantResult):
tester('ntpath.splitext("foo.ext")', ('foo', '.ext'))
tester('ntpath.splitext("/foo/foo.ext")', ('/foo/foo', '.ext'))
-tester('ntpath.splitext(".ext")', ('', '.ext'))
+tester('ntpath.splitext(".ext")', ('.ext', ''))
tester('ntpath.splitext("\\foo.ext\\foo")', ('\\foo.ext\\foo', ''))
tester('ntpath.splitext("foo.ext\\")', ('foo.ext\\', ''))
tester('ntpath.splitext("")', ('', ''))
tester('ntpath.splitext("foo.bar.ext")', ('foo.bar', '.ext'))
tester('ntpath.splitext("xx/foo.bar.ext")', ('xx/foo.bar', '.ext'))
tester('ntpath.splitext("xx\\foo.bar.ext")', ('xx\\foo.bar', '.ext'))
+tester('ntpath.splitext("c:a/b\\c.d")', ('c:a/b\\c', '.d'))
tester('ntpath.splitdrive("c:\\foo\\bar")',
('c:', '\\foo\\bar'))
@@ -133,6 +134,13 @@ try:
tester('ntpath.expandvars("${{foo}}")', "baz1}")
tester('ntpath.expandvars("$foo$foo")', "barbar")
tester('ntpath.expandvars("$bar$bar")', "$bar$bar")
+ tester('ntpath.expandvars("%foo% bar")', "bar bar")
+ tester('ntpath.expandvars("%foo%bar")', "barbar")
+ tester('ntpath.expandvars("%foo%%foo%")', "barbar")
+ tester('ntpath.expandvars("%%foo%%foo%foo%")', "%foo%foobar")
+ tester('ntpath.expandvars("%?bar%")', "%?bar%")
+ tester('ntpath.expandvars("%foo%%bar")', "bar%bar")
+ tester('ntpath.expandvars("\'%foo%\'%bar")', "\'%foo%\'%bar")
finally:
os.environ.clear()
os.environ.update(oldenv)
@@ -149,6 +157,16 @@ except ImportError:
else:
tester('ntpath.abspath("C:\\")', "C:\\")
+currentdir = os.path.split(os.getcwd())[-1]
+tester('ntpath.relpath("a")', 'a')
+tester('ntpath.relpath(os.path.abspath("a"))', 'a')
+tester('ntpath.relpath("a/b")', 'a\\b')
+tester('ntpath.relpath("../a/b")', '..\\a\\b')
+tester('ntpath.relpath("a", "../b")', '..\\'+currentdir+'\\a')
+tester('ntpath.relpath("a/b", "../c")', '..\\'+currentdir+'\\a\\b')
+tester('ntpath.relpath("a", "b/c")', '..\\..\\a')
+tester('ntpath.relpath("//conky/mountpoint/a", "//conky/mountpoint/b/c")', '..\\..\\a')
+
if errors:
raise TestFailed(str(errors) + " errors.")
elif verbose:
diff --git a/Lib/test/test_operations.py b/Lib/test/test_operations.py
deleted file mode 100644
index e8b1ae8..0000000
--- a/Lib/test/test_operations.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Python test set -- part 3, built-in operations.
-
-
-print('3. Operations')
-print('XXX Mostly not yet implemented')
-
-
-print('3.1 Dictionary lookups fail if __cmp__() raises an exception')
-
-class BadDictKey:
-
- def __hash__(self):
- return hash(self.__class__)
-
- def __eq__(self, other):
- if isinstance(other, self.__class__):
- print("raising error")
- raise RuntimeError, "gotcha"
- return other
-
-d = {}
-x1 = BadDictKey()
-x2 = BadDictKey()
-d[x1] = 1
-for stmt in ['d[x2] = 2',
- 'z = d[x2]',
- 'x2 in d',
- 'd.get(x2)',
- 'd.setdefault(x2, 42)',
- 'd.pop(x2)',
- 'd.update({x2: 2})']:
- try:
- exec(stmt)
- except RuntimeError:
- print("%s: caught the RuntimeError outside" % (stmt,))
- else:
- print("%s: No exception passed through!" % (stmt,)) # old CPython behavior
-
-
-# Dict resizing bug, found by Jack Jansen in 2.2 CVS development.
-# This version got an assert failure in debug build, infinite loop in
-# release build. Unfortunately, provoking this kind of stuff requires
-# a mix of inserts and deletes hitting exactly the right hash codes in
-# exactly the right order, and I can't think of a randomized approach
-# that would be *likely* to hit a failing case in reasonable time.
-
-d = {}
-for i in range(5):
- d[i] = i
-for i in range(5):
- del d[i]
-for i in range(5, 9): # i==8 was the problem
- d[i] = i
-
-
-# Another dict resizing bug (SF bug #1456209).
-# This caused Segmentation faults or Illegal instructions.
-
-class X(object):
- def __hash__(self):
- return 5
- def __eq__(self, other):
- if resizing:
- d.clear()
- return False
-d = {}
-resizing = False
-d[X()] = 1
-d[X()] = 2
-d[X()] = 3
-d[X()] = 4
-d[X()] = 5
-# now trigger a resize
-resizing = True
-d[9] = 6
-
-print('resize bugs not triggered.')
diff --git a/Lib/test/test_operator.py b/Lib/test/test_operator.py
index f9519b2..8d70564 100644
--- a/Lib/test/test_operator.py
+++ b/Lib/test/test_operator.py
@@ -143,6 +143,8 @@ class OperatorTestCase(unittest.TestCase):
self.failUnlessRaises(TypeError, operator.delslice, a, None, None)
self.failUnless(operator.delslice(a, 2, 8) is None)
self.assert_(a == [0, 1, 8, 9])
+ operator.delslice(a, 0, test_support.MAX_Py_ssize_t)
+ self.assert_(a == [])
def test_floordiv(self):
self.failUnlessRaises(TypeError, operator.floordiv, 5)
@@ -165,6 +167,8 @@ class OperatorTestCase(unittest.TestCase):
self.failUnlessRaises(TypeError, operator.getslice)
self.failUnlessRaises(TypeError, operator.getslice, a, None, None)
self.failUnless(operator.getslice(a, 4, 6) == [4, 5])
+ b = operator.getslice(a, 0, test_support.MAX_Py_ssize_t)
+ self.assert_(b == a)
def test_indexOf(self):
self.failUnlessRaises(TypeError, operator.indexOf)
@@ -300,6 +304,8 @@ class OperatorTestCase(unittest.TestCase):
self.failUnlessRaises(TypeError, operator.setslice, a, None, None, None)
self.failUnless(operator.setslice(a, 1, 3, [2, 1]) is None)
self.assert_(a == [0, 2, 1, 3])
+ operator.setslice(a, 0, test_support.MAX_Py_ssize_t, [])
+ self.assert_(a == [])
def test_sub(self):
self.failUnlessRaises(TypeError, operator.sub)
diff --git a/Lib/test/test_optparse.py b/Lib/test/test_optparse.py
index 6ec2902..88e3a1f 100644
--- a/Lib/test/test_optparse.py
+++ b/Lib/test/test_optparse.py
@@ -1631,18 +1631,8 @@ class TestParseNumber(BaseTest):
"option -l: invalid integer value: '0x12x'")
-def _testclasses():
- mod = sys.modules[__name__]
- return [getattr(mod, name) for name in dir(mod) if name.startswith('Test')]
-
-def suite():
- suite = unittest.TestSuite()
- for testclass in _testclasses():
- suite.addTest(unittest.makeSuite(testclass))
- return suite
-
def test_main():
- test_support.run_suite(suite())
+ test_support.run_unittest(__name__)
if __name__ == '__main__':
- unittest.main()
+ test_main()
diff --git a/Lib/test/test_os.py b/Lib/test/test_os.py
index a7fc1da..ed044f6 100644
--- a/Lib/test/test_os.py
+++ b/Lib/test/test_os.py
@@ -240,6 +240,15 @@ class StatAttributeTests(unittest.TestCase):
os.utime(self.fname, (t1, t1))
self.assertEquals(os.stat(self.fname).st_mtime, t1)
+ def test_1686475(self):
+ # Verify that an open file can be stat'ed
+ try:
+ os.stat(r"c:\pagefile.sys")
+ except WindowsError as e:
+ if e == 2: # file does not exist; cannot run test
+ return
+ self.fail("Could not stat pagefile.sys")
+
from test import mapping_tests
class EnvironTests(mapping_tests.BasicTestMappingProtocol):
@@ -272,75 +281,104 @@ class WalkTests(unittest.TestCase):
from os.path import join
# Build:
- # TESTFN/ a file kid and two directory kids
+ # TESTFN/
+ # TEST1/ a file kid and two directory kids
# tmp1
# SUB1/ a file kid and a directory kid
- # tmp2
- # SUB11/ no kids
- # SUB2/ just a file kid
- # tmp3
- sub1_path = join(test_support.TESTFN, "SUB1")
+ # tmp2
+ # SUB11/ no kids
+ # SUB2/ a file kid and a dirsymlink kid
+ # tmp3
+ # link/ a symlink to TESTFN.2
+ # TEST2/
+ # tmp4 a lone file
+ walk_path = join(test_support.TESTFN, "TEST1")
+ sub1_path = join(walk_path, "SUB1")
sub11_path = join(sub1_path, "SUB11")
- sub2_path = join(test_support.TESTFN, "SUB2")
- tmp1_path = join(test_support.TESTFN, "tmp1")
+ sub2_path = join(walk_path, "SUB2")
+ tmp1_path = join(walk_path, "tmp1")
tmp2_path = join(sub1_path, "tmp2")
tmp3_path = join(sub2_path, "tmp3")
+ link_path = join(sub2_path, "link")
+ t2_path = join(test_support.TESTFN, "TEST2")
+ tmp4_path = join(test_support.TESTFN, "TEST2", "tmp4")
# Create stuff.
os.makedirs(sub11_path)
os.makedirs(sub2_path)
- for path in tmp1_path, tmp2_path, tmp3_path:
+ os.makedirs(t2_path)
+ for path in tmp1_path, tmp2_path, tmp3_path, tmp4_path:
f = open(path, "w")
f.write("I'm " + path + " and proud of it. Blame test_os.\n")
f.close()
+ if hasattr(os, "symlink"):
+ os.symlink(os.path.abspath(t2_path), link_path)
+ sub2_tree = (sub2_path, ["link"], ["tmp3"])
+ else:
+ sub2_tree = (sub2_path, [], ["tmp3"])
# Walk top-down.
- all = list(os.walk(test_support.TESTFN))
+ all = list(os.walk(walk_path))
self.assertEqual(len(all), 4)
# We can't know which order SUB1 and SUB2 will appear in.
# Not flipped: TESTFN, SUB1, SUB11, SUB2
# flipped: TESTFN, SUB2, SUB1, SUB11
flipped = all[0][1][0] != "SUB1"
all[0][1].sort()
- self.assertEqual(all[0], (test_support.TESTFN, ["SUB1", "SUB2"], ["tmp1"]))
+ self.assertEqual(all[0], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[1 + flipped], (sub1_path, ["SUB11"], ["tmp2"]))
self.assertEqual(all[2 + flipped], (sub11_path, [], []))
- self.assertEqual(all[3 - 2 * flipped], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[3 - 2 * flipped], sub2_tree)
# Prune the search.
all = []
- for root, dirs, files in os.walk(test_support.TESTFN):
+ for root, dirs, files in os.walk(walk_path):
all.append((root, dirs, files))
# Don't descend into SUB1.
if 'SUB1' in dirs:
# Note that this also mutates the dirs we appended to all!
dirs.remove('SUB1')
self.assertEqual(len(all), 2)
- self.assertEqual(all[0], (test_support.TESTFN, ["SUB2"], ["tmp1"]))
- self.assertEqual(all[1], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[0], (walk_path, ["SUB2"], ["tmp1"]))
+ self.assertEqual(all[1], sub2_tree)
# Walk bottom-up.
- all = list(os.walk(test_support.TESTFN, topdown=False))
+ all = list(os.walk(walk_path, topdown=False))
self.assertEqual(len(all), 4)
# We can't know which order SUB1 and SUB2 will appear in.
# Not flipped: SUB11, SUB1, SUB2, TESTFN
# flipped: SUB2, SUB11, SUB1, TESTFN
flipped = all[3][1][0] != "SUB1"
all[3][1].sort()
- self.assertEqual(all[3], (test_support.TESTFN, ["SUB1", "SUB2"], ["tmp1"]))
+ self.assertEqual(all[3], (walk_path, ["SUB1", "SUB2"], ["tmp1"]))
self.assertEqual(all[flipped], (sub11_path, [], []))
self.assertEqual(all[flipped + 1], (sub1_path, ["SUB11"], ["tmp2"]))
- self.assertEqual(all[2 - 2 * flipped], (sub2_path, [], ["tmp3"]))
+ self.assertEqual(all[2 - 2 * flipped], sub2_tree)
+
+ if hasattr(os, "symlink"):
+ # Walk, following symlinks.
+ for root, dirs, files in os.walk(walk_path, followlinks=True):
+ if root == link_path:
+ self.assertEqual(dirs, [])
+ self.assertEqual(files, ["tmp4"])
+ break
+ else:
+ self.fail("Didn't follow symlink with followlinks=True")
+ def tearDown(self):
# Tear everything down. This is a decent use for bottom-up on
# Windows, which doesn't have a recursive delete command. The
# (not so) subtlety is that rmdir will fail unless the dir's
# kids are removed first, so bottom up is essential.
for root, dirs, files in os.walk(test_support.TESTFN, topdown=False):
for name in files:
- os.remove(join(root, name))
+ os.remove(os.path.join(root, name))
for name in dirs:
- os.rmdir(join(root, name))
+ dirname = os.path.join(root, name)
+ if not os.path.islink(dirname):
+ os.rmdir(dirname)
+ else:
+ os.remove(dirname)
os.rmdir(test_support.TESTFN)
class MakedirTests (unittest.TestCase):
diff --git a/Lib/test/test_ossaudiodev.py b/Lib/test/test_ossaudiodev.py
index 6ca1f74..eb15e88 100644
--- a/Lib/test/test_ossaudiodev.py
+++ b/Lib/test/test_ossaudiodev.py
@@ -1,7 +1,7 @@
from test import test_support
test_support.requires('audio')
-from test.test_support import verbose, findfile, TestFailed, TestSkipped
+from test.test_support import verbose, findfile, TestSkipped
import errno
import fcntl
@@ -12,6 +12,7 @@ import select
import sunaudio
import time
import audioop
+import unittest
# Arggh, AFMT_S16_NE not defined on all platforms -- seems to be a
# fairly recent addition to OSS.
@@ -33,131 +34,143 @@ def read_sound_file(path):
fp.close()
if enc != SND_FORMAT_MULAW_8:
- print("Expect .au file with 8-bit mu-law samples")
- return
+ raise RuntimeError("Expect .au file with 8-bit mu-law samples")
# Convert the data to 16-bit signed.
data = audioop.ulaw2lin(data, 2)
return (data, rate, 16, nchannels)
-# version of assert that still works with -O
-def _assert(expr, message=None):
- if not expr:
- raise AssertionError(message or "assertion failed")
+class OSSAudioDevTests(unittest.TestCase):
-def play_sound_file(data, rate, ssize, nchannels):
- try:
- dsp = ossaudiodev.open('w')
- except IOError as msg:
- if msg.args[0] in (errno.EACCES, errno.ENOENT, errno.ENODEV, errno.EBUSY):
- raise TestSkipped, msg
- raise TestFailed, msg
-
- # at least check that these methods can be invoked
- dsp.bufsize()
- dsp.obufcount()
- dsp.obuffree()
- dsp.getptr()
- dsp.fileno()
-
- # Make sure the read-only attributes work.
- _assert(dsp.closed is False, "dsp.closed is not False")
- _assert(dsp.name == "/dev/dsp")
- _assert(dsp.mode == 'w', "bad dsp.mode: %r" % dsp.mode)
-
- # And make sure they're really read-only.
- for attr in ('closed', 'name', 'mode'):
+ def play_sound_file(self, data, rate, ssize, nchannels):
try:
- setattr(dsp, attr, 42)
- raise RuntimeError("dsp.%s not read-only" % attr)
- except TypeError:
- pass
-
- # Compute expected running time of sound sample (in seconds).
- expected_time = float(len(data)) / (ssize/8) / nchannels / rate
-
- # set parameters based on .au file headers
- dsp.setparameters(AFMT_S16_NE, nchannels, rate)
- print(("playing test sound file (expected running time: %.2f sec)"
- % expected_time))
- t1 = time.time()
- dsp.write(data)
- dsp.close()
- t2 = time.time()
- elapsed_time = t2 - t1
-
- percent_diff = (abs(elapsed_time - expected_time) / expected_time) * 100
- _assert(percent_diff <= 10.0, \
- ("elapsed time (%.2f sec) > 10%% off of expected time (%.2f sec)"
- % (elapsed_time, expected_time)))
-
-def test_setparameters(dsp):
- # Two configurations for testing:
- # config1 (8-bit, mono, 8 kHz) should work on even the most
- # ancient and crufty sound card, but maybe not on special-
- # purpose high-end hardware
- # config2 (16-bit, stereo, 44.1kHz) should work on all but the
- # most ancient and crufty hardware
- config1 = (ossaudiodev.AFMT_U8, 1, 8000)
- config2 = (AFMT_S16_NE, 2, 44100)
-
- for config in [config1, config2]:
- (fmt, channels, rate) = config
- if (dsp.setfmt(fmt) == fmt and
- dsp.channels(channels) == channels and
- dsp.speed(rate) == rate):
- break
- else:
- raise RuntimeError("unable to set audio sampling parameters: "
- "you must have really weird audio hardware")
-
- # setparameters() should be able to set this configuration in
- # either strict or non-strict mode.
- result = dsp.setparameters(fmt, channels, rate, False)
- _assert(result == (fmt, channels, rate),
- "setparameters%r: returned %r" % (config, result))
- result = dsp.setparameters(fmt, channels, rate, True)
- _assert(result == (fmt, channels, rate),
- "setparameters%r: returned %r" % (config, result))
-
-def test_bad_setparameters(dsp):
-
- # Now try some configurations that are presumably bogus: eg. 300
- # channels currently exceeds even Hollywood's ambitions, and
- # negative sampling rate is utter nonsense. setparameters() should
- # accept these in non-strict mode, returning something other than
- # was requested, but should barf in strict mode.
- fmt = AFMT_S16_NE
- rate = 44100
- channels = 2
- for config in [(fmt, 300, rate), # ridiculous nchannels
- (fmt, -5, rate), # impossible nchannels
- (fmt, channels, -50), # impossible rate
- ]:
- (fmt, channels, rate) = config
+ dsp = ossaudiodev.open('w')
+ except IOError as msg:
+ if msg.args[0] in (errno.EACCES, errno.ENOENT,
+ errno.ENODEV, errno.EBUSY):
+ raise TestSkipped(msg)
+ raise
+
+ # at least check that these methods can be invoked
+ dsp.bufsize()
+ dsp.obufcount()
+ dsp.obuffree()
+ dsp.getptr()
+ dsp.fileno()
+
+ # Make sure the read-only attributes work.
+ self.failUnless(dsp.close)
+ self.assertEqual(dsp.name, "/dev/dsp")
+ self.assertEqual(dsp.mode, "w", "bad dsp.mode: %r" % dsp.mode)
+
+ # And make sure they're really read-only.
+ for attr in ('closed', 'name', 'mode'):
+ try:
+ setattr(dsp, attr, 42)
+ except TypeError:
+ pass
+ else:
+ self.fail("dsp.%s not read-only" % attr)
+
+ # Compute expected running time of sound sample (in seconds).
+ expected_time = float(len(data)) / (ssize/8) / nchannels / rate
+
+ # set parameters based on .au file headers
+ dsp.setparameters(AFMT_S16_NE, nchannels, rate)
+ print ("playing test sound file (expected running time: %.2f sec)"
+ % expected_time)
+ t1 = time.time()
+ dsp.write(data)
+ dsp.close()
+ t2 = time.time()
+ elapsed_time = t2 - t1
+
+ percent_diff = (abs(elapsed_time - expected_time) / expected_time) * 100
+ self.failUnless(percent_diff <= 10.0,
+ "elapsed time > 10% off of expected time")
+
+ def set_parameters(self, dsp):
+ # Two configurations for testing:
+ # config1 (8-bit, mono, 8 kHz) should work on even the most
+ # ancient and crufty sound card, but maybe not on special-
+ # purpose high-end hardware
+ # config2 (16-bit, stereo, 44.1kHz) should work on all but the
+ # most ancient and crufty hardware
+ config1 = (ossaudiodev.AFMT_U8, 1, 8000)
+ config2 = (AFMT_S16_NE, 2, 44100)
+
+ for config in [config1, config2]:
+ (fmt, channels, rate) = config
+ if (dsp.setfmt(fmt) == fmt and
+ dsp.channels(channels) == channels and
+ dsp.speed(rate) == rate):
+ break
+ else:
+ raise RuntimeError("unable to set audio sampling parameters: "
+ "you must have really weird audio hardware")
+
+ # setparameters() should be able to set this configuration in
+ # either strict or non-strict mode.
result = dsp.setparameters(fmt, channels, rate, False)
- _assert(result != config,
- "setparameters: unexpectedly got requested configuration")
-
+ self.assertEqual(result, (fmt, channels, rate),
+ "setparameters%r: returned %r" % (config, result))
+
+ result = dsp.setparameters(fmt, channels, rate, True)
+ self.assertEqual(result, (fmt, channels, rate),
+ "setparameters%r: returned %r" % (config, result))
+
+ def set_bad_parameters(self, dsp):
+
+ # Now try some configurations that are presumably bogus: eg. 300
+ # channels currently exceeds even Hollywood's ambitions, and
+ # negative sampling rate is utter nonsense. setparameters() should
+ # accept these in non-strict mode, returning something other than
+ # was requested, but should barf in strict mode.
+ fmt = AFMT_S16_NE
+ rate = 44100
+ channels = 2
+ for config in [(fmt, 300, rate), # ridiculous nchannels
+ (fmt, -5, rate), # impossible nchannels
+ (fmt, channels, -50), # impossible rate
+ ]:
+ (fmt, channels, rate) = config
+ result = dsp.setparameters(fmt, channels, rate, False)
+ self.failIfEqual(result, config,
+ "unexpectedly got requested configuration")
+
+ try:
+ result = dsp.setparameters(fmt, channels, rate, True)
+ except ossaudiodev.OSSAudioError as err:
+ pass
+ else:
+ self.fail("expected OSSAudioError")
+
+ def test_playback(self):
+ sound_info = read_sound_file(findfile('audiotest.au'))
+ self.play_sound_file(*sound_info)
+
+ def test_set_parameters(self):
+ dsp = ossaudiodev.open("w")
try:
- result = dsp.setparameters(fmt, channels, rate, True)
- raise AssertionError("setparameters: expected OSSAudioError")
- except ossaudiodev.OSSAudioError as err:
- print("setparameters: got OSSAudioError as expected")
+ self.set_parameters(dsp)
-def test():
- (data, rate, ssize, nchannels) = read_sound_file(findfile('audiotest.au'))
- play_sound_file(data, rate, ssize, nchannels)
+ # Disabled because it fails under Linux 2.6 with ALSA's OSS
+ # emulation layer.
+ #self.set_bad_parameters(dsp)
+ finally:
+ dsp.close()
+ self.failUnless(dsp.closed)
- dsp = ossaudiodev.open("w")
- try:
- test_setparameters(dsp)
-
- # Disabled because it fails under Linux 2.6 with ALSA's OSS
- # emulation layer.
- #test_bad_setparameters(dsp)
- finally:
- dsp.close()
- _assert(dsp.closed is True, "dsp.closed is not True")
-test()
+def test_main():
+ try:
+ dsp = ossaudiodev.open('w')
+ except IOError as msg:
+ if msg.args[0] in (errno.EACCES, errno.ENOENT,
+ errno.ENODEV, errno.EBUSY):
+ raise TestSkipped(msg)
+ raise
+ test_support.run_unittest(__name__)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_peepholer.py b/Lib/test/test_peepholer.py
index 9ed814a..1611e39 100644
--- a/Lib/test/test_peepholer.py
+++ b/Lib/test/test_peepholer.py
@@ -49,6 +49,11 @@ class TestTranforms(unittest.TestCase):
self.assert_(elem not in asm)
for elem in ('LOAD_CONST', '(None)'):
self.assert_(elem in asm)
+ def f():
+ 'Adding a docstring made this test fail in Py2.5.0'
+ return None
+ self.assert_('LOAD_CONST' in disassemble(f))
+ self.assert_('LOAD_GLOBAL' not in disassemble(f))
def test_while_one(self):
# Skip over: LOAD_CONST trueconst JUMP_IF_FALSE xx POP_TOP
@@ -195,14 +200,14 @@ class TestTranforms(unittest.TestCase):
# There should be one jump for the while loop.
self.assertEqual(asm.split().count('JUMP_ABSOLUTE'), 1)
self.assertEqual(asm.split().count('RETURN_VALUE'), 2)
-
+
def test_make_function_doesnt_bail(self):
def f():
- def g()->1+1:
+ def g()->1+1:
pass
return g
asm = disassemble(f)
- self.assert_('BINARY_ADD' not in asm)
+ self.assert_('BINARY_ADD' not in asm)
def test_main(verbose=None):
diff --git a/Lib/test/test_popen2.py b/Lib/test/test_popen2.py
index 008a67a..31f22d6 100644
--- a/Lib/test/test_popen2.py
+++ b/Lib/test/test_popen2.py
@@ -1,78 +1,92 @@
#! /usr/bin/env python
-"""Test script for popen2.py
- Christian Tismer
-"""
+"""Test script for popen2.py"""
import os
import sys
-from test.test_support import TestSkipped, reap_children
-
-# popen2 contains its own testing routine
-# which is especially useful to see if open files
-# like stdin can be read successfully by a forked
-# subprocess.
-
-def main():
- print("Test popen2 module:")
- if (sys.platform[:4] == 'beos' or sys.platform[:6] == 'atheos') \
- and __name__ != '__main__':
- # Locks get messed up or something. Generally we're supposed
- # to avoid mixing "posix" fork & exec with native threads, and
- # they may be right about that after all.
- raise TestSkipped, "popen2() doesn't work during import on " + sys.platform
- try:
- from os import popen
- except ImportError:
- # if we don't have os.popen, check that
- # we have os.fork. if not, skip the test
- # (by raising an ImportError)
- from os import fork
- import popen2
- popen2._test()
-
-
-def _test():
- # same test as popen2._test(), but using the os.popen*() API
- print("Testing os module:")
- import popen2
- # When the test runs, there shouldn't be any open pipes
- popen2._cleanup()
- assert not popen2._active, "Active pipes when test starts " + repr([c.cmd for c in popen2._active])
- cmd = "cat"
- teststr = "ab cd\n"
+import unittest
+import popen2
+
+from test.test_support import TestSkipped, run_unittest, reap_children
+
+if sys.platform[:4] == 'beos' or sys.platform[:6] == 'atheos':
+ # Locks get messed up or something. Generally we're supposed
+ # to avoid mixing "posix" fork & exec with native threads, and
+ # they may be right about that after all.
+ raise TestSkipped("popen2() doesn't work on " + sys.platform)
+
+# if we don't have os.popen, check that
+# we have os.fork. if not, skip the test
+# (by raising an ImportError)
+try:
+ from os import popen
+ del popen
+except ImportError:
+ from os import fork
+ del fork
+
+class Popen2Test(unittest.TestCase):
+ cmd = "cat"
if os.name == "nt":
cmd = "more"
+ teststr = "ab cd\n"
# "more" doesn't act the same way across Windows flavors,
# sometimes adding an extra newline at the start or the
# end. So we strip whitespace off both ends for comparison.
expected = teststr.strip()
- print("testing popen2...")
- w, r = os.popen2(cmd)
- w.write(teststr)
- w.close()
- got = r.read()
- if got.strip() != expected:
- raise ValueError("wrote %r read %r" % (teststr, got))
- print("testing popen3...")
- try:
- w, r, e = os.popen3([cmd])
- except:
- w, r, e = os.popen3(cmd)
- w.write(teststr)
- w.close()
- got = r.read()
- if got.strip() != expected:
- raise ValueError("wrote %r read %r" % (teststr, got))
- got = e.read()
- if got:
- raise ValueError("unexpected %r on stderr" % (got,))
- for inst in popen2._active[:]:
- inst.wait()
- popen2._cleanup()
- if popen2._active:
- raise ValueError("_active not empty")
- print("All OK")
-
-main()
-_test()
-reap_children()
+
+ def setUp(self):
+ popen2._cleanup()
+ # When the test runs, there shouldn't be any open pipes
+ self.assertFalse(popen2._active, "Active pipes when test starts" +
+ repr([c.cmd for c in popen2._active]))
+
+ def tearDown(self):
+ for inst in popen2._active:
+ inst.wait()
+ popen2._cleanup()
+ self.assertFalse(popen2._active, "_active not empty")
+ reap_children()
+
+ def validate_output(self, teststr, expected_out, r, w, e=None):
+ w.write(teststr)
+ w.close()
+ got = r.read()
+ self.assertEquals(expected_out, got.strip(), "wrote %r read %r" %
+ (teststr, got))
+
+ if e is not None:
+ got = e.read()
+ self.assertFalse(got, "unexpected %r on stderr" % got)
+
+ def test_popen2(self):
+ r, w = popen2.popen2(self.cmd)
+ self.validate_output(self.teststr, self.expected, r, w)
+
+ def test_popen3(self):
+ if os.name == 'posix':
+ r, w, e = popen2.popen3([self.cmd])
+ self.validate_output(self.teststr, self.expected, r, w, e)
+
+ r, w, e = popen2.popen3(self.cmd)
+ self.validate_output(self.teststr, self.expected, r, w, e)
+
+ def test_os_popen2(self):
+ # same test as test_popen2(), but using the os.popen*() API
+ w, r = os.popen2(self.cmd)
+ self.validate_output(self.teststr, self.expected, r, w)
+
+ def test_os_popen3(self):
+ # same test as test_popen3(), but using the os.popen*() API
+ if os.name == 'posix':
+ w, r, e = os.popen3([self.cmd])
+ self.validate_output(self.teststr, self.expected, r, w, e)
+
+ w, r, e = os.popen3(self.cmd)
+ self.validate_output(self.teststr, self.expected, r, w, e)
+
+
+def test_main():
+ run_unittest(Popen2Test)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_poplib.py b/Lib/test/test_poplib.py
new file mode 100644
index 0000000..35ff636
--- /dev/null
+++ b/Lib/test/test_poplib.py
@@ -0,0 +1,71 @@
+import socket
+import threading
+import poplib
+import time
+
+from unittest import TestCase
+from test import test_support
+
+
+def server(evt):
+ serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ serv.settimeout(3)
+ serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ serv.bind(("", 9091))
+ serv.listen(5)
+ try:
+ conn, addr = serv.accept()
+ except socket.timeout:
+ pass
+ else:
+ conn.send("+ Hola mundo\n")
+ conn.close()
+ finally:
+ serv.close()
+ evt.set()
+
+class GeneralTests(TestCase):
+
+ def setUp(self):
+ self.evt = threading.Event()
+ threading.Thread(target=server, args=(self.evt,)).start()
+ time.sleep(.1)
+
+ def tearDown(self):
+ self.evt.wait()
+
+ def testBasic(self):
+ # connects
+ pop = poplib.POP3("localhost", 9091)
+ pop.sock.close()
+
+ def testTimeoutDefault(self):
+ # default
+ pop = poplib.POP3("localhost", 9091)
+ self.assertTrue(pop.sock.gettimeout() is None)
+ pop.sock.close()
+
+ def testTimeoutValue(self):
+ # a value
+ pop = poplib.POP3("localhost", 9091, timeout=30)
+ self.assertEqual(pop.sock.gettimeout(), 30)
+ pop.sock.close()
+
+ def testTimeoutNone(self):
+ # None, having other default
+ previous = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(30)
+ try:
+ pop = poplib.POP3("localhost", 9091, timeout=None)
+ finally:
+ socket.setdefaulttimeout(previous)
+ self.assertEqual(pop.sock.gettimeout(), 30)
+ pop.sock.close()
+
+
+
+def test_main(verbose=None):
+ test_support.run_unittest(GeneralTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_posixpath.py b/Lib/test/test_posixpath.py
index 20a1fc5..0abf464 100644
--- a/Lib/test/test_posixpath.py
+++ b/Lib/test/test_posixpath.py
@@ -2,15 +2,29 @@ import unittest
from test import test_support
import posixpath, os
-from posixpath import realpath, abspath, join, dirname, basename
+from posixpath import realpath, abspath, join, dirname, basename, relpath
# An absolute path to a temporary filename for testing. We can't rely on TESTFN
# being an absolute path, so we need this.
ABSTFN = abspath(test_support.TESTFN)
+def safe_rmdir(dirname):
+ try:
+ os.rmdir(dirname)
+ except OSError:
+ pass
+
class PosixPathTest(unittest.TestCase):
+ def setUp(self):
+ self.tearDown()
+
+ def tearDown(self):
+ for suffix in ["", "1", "2"]:
+ test_support.unlink(test_support.TESTFN + suffix)
+ safe_rmdir(test_support.TESTFN + suffix)
+
def assertIs(self, a, b):
self.assert_(a is b)
@@ -43,15 +57,27 @@ class PosixPathTest(unittest.TestCase):
self.assertRaises(TypeError, posixpath.split)
- def test_splitext(self):
- self.assertEqual(posixpath.splitext("foo.ext"), ("foo", ".ext"))
- self.assertEqual(posixpath.splitext("/foo/foo.ext"), ("/foo/foo", ".ext"))
- self.assertEqual(posixpath.splitext(".ext"), ("", ".ext"))
- self.assertEqual(posixpath.splitext("/foo.ext/foo"), ("/foo.ext/foo", ""))
- self.assertEqual(posixpath.splitext("foo.ext/"), ("foo.ext/", ""))
- self.assertEqual(posixpath.splitext(""), ("", ""))
- self.assertEqual(posixpath.splitext("foo.bar.ext"), ("foo.bar", ".ext"))
+ def splitextTest(self, path, filename, ext):
+ self.assertEqual(posixpath.splitext(path), (filename, ext))
+ self.assertEqual(posixpath.splitext("/" + path), ("/" + filename, ext))
+ self.assertEqual(posixpath.splitext("abc/" + path), ("abc/" + filename, ext))
+ self.assertEqual(posixpath.splitext("abc.def/" + path), ("abc.def/" + filename, ext))
+ self.assertEqual(posixpath.splitext("/abc.def/" + path), ("/abc.def/" + filename, ext))
+ self.assertEqual(posixpath.splitext(path + "/"), (filename + ext + "/", ""))
+ def test_splitext(self):
+ self.splitextTest("foo.bar", "foo", ".bar")
+ self.splitextTest("foo.boo.bar", "foo.boo", ".bar")
+ self.splitextTest("foo.boo.biff.bar", "foo.boo.biff", ".bar")
+ self.splitextTest(".csh.rc", ".csh", ".rc")
+ self.splitextTest("nodots", "nodots", "")
+ self.splitextTest(".cshrc", ".cshrc", "")
+ self.splitextTest("...manydots", "...manydots", "")
+ self.splitextTest("...manydots.ext", "...manydots", ".ext")
+ self.splitextTest(".", ".", "")
+ self.splitextTest("..", "..", "")
+ self.splitextTest("........", "........", "")
+ self.splitextTest("", "", "")
self.assertRaises(TypeError, posixpath.splitext)
def test_isabs(self):
@@ -113,7 +139,6 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.closed:
f.close()
- os.remove(test_support.TESTFN)
def test_time(self):
f = open(test_support.TESTFN, "wb")
@@ -135,7 +160,6 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.closed:
f.close()
- os.remove(test_support.TESTFN)
def test_islink(self):
self.assertIs(posixpath.islink(test_support.TESTFN + "1"), False)
@@ -154,14 +178,6 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.close():
f.close()
- try:
- os.remove(test_support.TESTFN + "1")
- except os.error:
- pass
- try:
- os.remove(test_support.TESTFN + "2")
- except os.error:
- pass
self.assertRaises(TypeError, posixpath.islink)
@@ -176,10 +192,6 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.close():
f.close()
- try:
- os.remove(test_support.TESTFN)
- except os.error:
- pass
self.assertRaises(TypeError, posixpath.exists)
@@ -197,14 +209,6 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.close():
f.close()
- try:
- os.remove(test_support.TESTFN)
- except os.error:
- pass
- try:
- os.rmdir(test_support.TESTFN)
- except os.error:
- pass
self.assertRaises(TypeError, posixpath.isdir)
@@ -222,67 +226,51 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.close():
f.close()
- try:
- os.remove(test_support.TESTFN)
- except os.error:
- pass
- try:
- os.rmdir(test_support.TESTFN)
- except os.error:
- pass
self.assertRaises(TypeError, posixpath.isdir)
- def test_samefile(self):
- f = open(test_support.TESTFN + "1", "wb")
- try:
- f.write("foo")
- f.close()
+ def test_samefile(self):
+ f = open(test_support.TESTFN + "1", "wb")
+ try:
+ f.write("foo")
+ f.close()
+ self.assertIs(
+ posixpath.samefile(
+ test_support.TESTFN + "1",
+ test_support.TESTFN + "1"
+ ),
+ True
+ )
+ # If we don't have links, assume that os.stat doesn't return resonable
+ # inode information and thus, that samefile() doesn't work
+ if hasattr(os, "symlink"):
+ os.symlink(
+ test_support.TESTFN + "1",
+ test_support.TESTFN + "2"
+ )
self.assertIs(
posixpath.samefile(
test_support.TESTFN + "1",
- test_support.TESTFN + "1"
+ test_support.TESTFN + "2"
),
True
)
- # If we don't have links, assume that os.stat doesn't return resonable
- # inode information and thus, that samefile() doesn't work
- if hasattr(os, "symlink"):
- os.symlink(
+ os.remove(test_support.TESTFN + "2")
+ f = open(test_support.TESTFN + "2", "wb")
+ f.write("bar")
+ f.close()
+ self.assertIs(
+ posixpath.samefile(
test_support.TESTFN + "1",
test_support.TESTFN + "2"
- )
- self.assertIs(
- posixpath.samefile(
- test_support.TESTFN + "1",
- test_support.TESTFN + "2"
- ),
- True
- )
- os.remove(test_support.TESTFN + "2")
- f = open(test_support.TESTFN + "2", "wb")
- f.write("bar")
- f.close()
- self.assertIs(
- posixpath.samefile(
- test_support.TESTFN + "1",
- test_support.TESTFN + "2"
- ),
- False
- )
- finally:
- if not f.close():
- f.close()
- try:
- os.remove(test_support.TESTFN + "1")
- except os.error:
- pass
- try:
- os.remove(test_support.TESTFN + "2")
- except os.error:
- pass
+ ),
+ False
+ )
+ finally:
+ if not f.close():
+ f.close()
- self.assertRaises(TypeError, posixpath.samefile)
+ self.assertRaises(TypeError, posixpath.samefile)
def test_samestat(self):
f = open(test_support.TESTFN + "1", "wb")
@@ -322,14 +310,6 @@ class PosixPathTest(unittest.TestCase):
finally:
if not f.close():
f.close()
- try:
- os.remove(test_support.TESTFN + "1")
- except os.error:
- pass
- try:
- os.remove(test_support.TESTFN + "2")
- except os.error:
- pass
self.assertRaises(TypeError, posixpath.samestat)
@@ -409,7 +389,7 @@ class PosixPathTest(unittest.TestCase):
os.symlink(ABSTFN+"1", ABSTFN)
self.assertEqual(realpath(ABSTFN), ABSTFN+"1")
finally:
- self.safe_remove(ABSTFN)
+ test_support.unlink(ABSTFN)
def test_realpath_symlink_loops(self):
# Bug #930024, return the path unchanged if we get into an infinite
@@ -429,9 +409,9 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(realpath(basename(ABSTFN)), ABSTFN)
finally:
os.chdir(old_path)
- self.safe_remove(ABSTFN)
- self.safe_remove(ABSTFN+"1")
- self.safe_remove(ABSTFN+"2")
+ test_support.unlink(ABSTFN)
+ test_support.unlink(ABSTFN+"1")
+ test_support.unlink(ABSTFN+"2")
def test_realpath_resolve_parents(self):
# We also need to resolve any symlinks in the parents of a relative
@@ -448,9 +428,9 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(realpath("a"), ABSTFN + "/y/a")
finally:
os.chdir(old_path)
- self.safe_remove(ABSTFN + "/k")
- self.safe_rmdir(ABSTFN + "/y")
- self.safe_rmdir(ABSTFN)
+ test_support.unlink(ABSTFN + "/k")
+ safe_rmdir(ABSTFN + "/y")
+ safe_rmdir(ABSTFN)
def test_realpath_resolve_before_normalizing(self):
# Bug #990669: Symbolic links should be resolved before we
@@ -474,10 +454,10 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(realpath(basename(ABSTFN) + "/link-y/.."), ABSTFN + "/k")
finally:
os.chdir(old_path)
- self.safe_remove(ABSTFN + "/link-y")
- self.safe_rmdir(ABSTFN + "/k/y")
- self.safe_rmdir(ABSTFN + "/k")
- self.safe_rmdir(ABSTFN)
+ test_support.unlink(ABSTFN + "/link-y")
+ safe_rmdir(ABSTFN + "/k/y")
+ safe_rmdir(ABSTFN + "/k")
+ safe_rmdir(ABSTFN)
def test_realpath_resolve_first(self):
# Bug #1213894: The first component of the path, if not absolute,
@@ -495,20 +475,24 @@ class PosixPathTest(unittest.TestCase):
self.assertEqual(realpath(base + "link/k"), ABSTFN + "/k")
finally:
os.chdir(old_path)
- self.safe_remove(ABSTFN + "link")
- self.safe_rmdir(ABSTFN + "/k")
- self.safe_rmdir(ABSTFN)
-
- # Convenience functions for removing temporary files.
- def pass_os_error(self, func, filename):
- try: func(filename)
- except OSError: pass
+ test_support.unlink(ABSTFN + "link")
+ safe_rmdir(ABSTFN + "/k")
+ safe_rmdir(ABSTFN)
- def safe_remove(self, filename):
- self.pass_os_error(os.remove, filename)
-
- def safe_rmdir(self, dirname):
- self.pass_os_error(os.rmdir, dirname)
+ def test_relpath(self):
+ (real_getcwd, os.getcwd) = (os.getcwd, lambda: r"/home/user/bar")
+ try:
+ curdir = os.path.split(os.getcwd())[-1]
+ self.assertRaises(ValueError, posixpath.relpath, "")
+ self.assertEqual(posixpath.relpath("a"), "a")
+ self.assertEqual(posixpath.relpath(posixpath.abspath("a")), "a")
+ self.assertEqual(posixpath.relpath("a/b"), "a/b")
+ self.assertEqual(posixpath.relpath("../a/b"), "../a/b")
+ self.assertEqual(posixpath.relpath("a", "../b"), "../"+curdir+"/a")
+ self.assertEqual(posixpath.relpath("a/b", "../c"), "../"+curdir+"/a/b")
+ self.assertEqual(posixpath.relpath("a", "b/c"), "../../a")
+ finally:
+ os.getcwd = real_getcwd
def test_main():
test_support.run_unittest(PosixPathTest)
diff --git a/Lib/test/test_pty.py b/Lib/test/test_pty.py
index 123c3f8..5ce387b 100644
--- a/Lib/test/test_pty.py
+++ b/Lib/test/test_pty.py
@@ -1,5 +1,9 @@
-import pty, os, sys, signal
-from test.test_support import verbose, TestFailed, TestSkipped
+import pty
+import os
+import sys
+import signal
+from test.test_support import verbose, TestSkipped, run_unittest
+import unittest
TEST_STRING_1 = "I wish to buy a fish license.\n"
TEST_STRING_2 = "For my pet fish, Eric.\n"
@@ -11,6 +15,7 @@ else:
def debug(msg):
pass
+
def normalize_output(data):
# Some operating systems do conversions on newline. We could possibly
# fix that by doing the appropriate termios.tcsetattr()s. I couldn't
@@ -32,116 +37,141 @@ def normalize_output(data):
return data
+
# Marginal testing of pty suite. Cannot do extensive 'do or fail' testing
# because pty code is not too portable.
-
-def test_basic_pty():
- try:
- debug("Calling master_open()")
- master_fd, slave_name = pty.master_open()
- debug("Got master_fd '%d', slave_name '%s'"%(master_fd, slave_name))
- debug("Calling slave_open(%r)"%(slave_name,))
- slave_fd = pty.slave_open(slave_name)
- debug("Got slave_fd '%d'"%slave_fd)
- except OSError:
- # " An optional feature could not be imported " ... ?
- raise TestSkipped, "Pseudo-terminals (seemingly) not functional."
-
- if not os.isatty(slave_fd):
- raise TestFailed, "slave_fd is not a tty"
-
- debug("Writing to slave_fd")
- os.write(slave_fd, TEST_STRING_1)
- s1 = os.read(master_fd, 1024)
- sys.stdout.write(normalize_output(s1))
-
- debug("Writing chunked output")
- os.write(slave_fd, TEST_STRING_2[:5])
- os.write(slave_fd, TEST_STRING_2[5:])
- s2 = os.read(master_fd, 1024)
- sys.stdout.write(normalize_output(s2))
-
- os.close(slave_fd)
- os.close(master_fd)
-
-def handle_sig(sig, frame):
- raise TestFailed, "isatty hung"
-
-# isatty() and close() can hang on some platforms
-# set an alarm before running the test to make sure we don't hang forever
-old_alarm = signal.signal(signal.SIGALRM, handle_sig)
-signal.alarm(10)
-
-try:
- test_basic_pty()
-finally:
- # remove alarm, restore old alarm handler
- signal.alarm(0)
- signal.signal(signal.SIGALRM, old_alarm)
-
-# basic pty passed.
-
-debug("calling pty.fork()")
-pid, master_fd = pty.fork()
-if pid == pty.CHILD:
- # stdout should be connected to a tty.
- if not os.isatty(1):
- debug("Child's fd 1 is not a tty?!")
- os._exit(3)
-
- # After pty.fork(), the child should already be a session leader.
- # (on those systems that have that concept.)
- debug("In child, calling os.setsid()")
- try:
- os.setsid()
- except OSError:
- # Good, we already were session leader
- debug("Good: OSError was raised.")
- pass
- except AttributeError:
- # Have pty, but not setsid() ?
- debug("No setsid() available ?")
- pass
- except:
- # We don't want this error to propagate, escaping the call to
- # os._exit() and causing very peculiar behavior in the calling
- # regrtest.py !
- # Note: could add traceback printing here.
- debug("An unexpected error was raised.")
- os._exit(1)
- else:
- debug("os.setsid() succeeded! (bad!)")
- os._exit(2)
- os._exit(4)
-else:
- debug("Waiting for child (%d) to finish."%pid)
- ##line = os.read(master_fd, 80)
- ##lines = line.replace('\r\n', '\n').split('\n')
- ##if False and lines != ['In child, calling os.setsid()',
- ## 'Good: OSError was raised.', '']:
- ## raise TestFailed("Unexpected output from child: %r" % line)
-
- (pid, status) = os.waitpid(pid, 0)
- res = status >> 8
- debug("Child (%d) exited with status %d (%d)."%(pid, res, status))
- if res == 1:
- raise TestFailed, "Child raised an unexpected exception in os.setsid()"
- elif res == 2:
- raise TestFailed, "pty.fork() failed to make child a session leader."
- elif res == 3:
- raise TestFailed, "Child spawned by pty.fork() did not have a tty as stdout"
- elif res != 4:
- raise TestFailed, "pty.fork() failed for unknown reasons."
-
- ##debug("Reading from master_fd now that the child has exited")
- ##try:
- ## s1 = os.read(master_fd, 1024)
- ##except os.error:
- ## pass
- ##else:
- ## raise TestFailed("Read from master_fd did not raise exception")
-
-
-os.close(master_fd)
-
-# pty.fork() passed.
+class PtyTest(unittest.TestCase):
+ def setUp(self):
+ # isatty() and close() can hang on some platforms. Set an alarm
+ # before running the test to make sure we don't hang forever.
+ self.old_alarm = signal.signal(signal.SIGALRM, self.handle_sig)
+ signal.alarm(10)
+
+ def tearDown(self):
+ # remove alarm, restore old alarm handler
+ signal.alarm(0)
+ signal.signal(signal.SIGALRM, self.old_alarm)
+
+ def handle_sig(self, sig, frame):
+ self.fail("isatty hung")
+
+ def test_basic(self):
+ try:
+ debug("Calling master_open()")
+ master_fd, slave_name = pty.master_open()
+ debug("Got master_fd '%d', slave_name '%s'" %
+ (master_fd, slave_name))
+ debug("Calling slave_open(%r)" % (slave_name,))
+ slave_fd = pty.slave_open(slave_name)
+ debug("Got slave_fd '%d'" % slave_fd)
+ except OSError:
+ # " An optional feature could not be imported " ... ?
+ raise TestSkipped, "Pseudo-terminals (seemingly) not functional."
+
+ self.assertTrue(os.isatty(slave_fd), 'slave_fd is not a tty')
+
+ debug("Writing to slave_fd")
+ os.write(slave_fd, TEST_STRING_1)
+ s1 = os.read(master_fd, 1024)
+ self.assertEquals('I wish to buy a fish license.\n',
+ normalize_output(s1))
+
+ debug("Writing chunked output")
+ os.write(slave_fd, TEST_STRING_2[:5])
+ os.write(slave_fd, TEST_STRING_2[5:])
+ s2 = os.read(master_fd, 1024)
+ self.assertEquals('For my pet fish, Eric.\n', normalize_output(s2))
+
+ os.close(slave_fd)
+ os.close(master_fd)
+
+
+ def test_fork(self):
+ debug("calling pty.fork()")
+ pid, master_fd = pty.fork()
+ if pid == pty.CHILD:
+ # stdout should be connected to a tty.
+ if not os.isatty(1):
+ debug("Child's fd 1 is not a tty?!")
+ os._exit(3)
+
+ # After pty.fork(), the child should already be a session leader.
+ # (on those systems that have that concept.)
+ debug("In child, calling os.setsid()")
+ try:
+ os.setsid()
+ except OSError:
+ # Good, we already were session leader
+ debug("Good: OSError was raised.")
+ pass
+ except AttributeError:
+ # Have pty, but not setsid()?
+ debug("No setsid() available?")
+ pass
+ except:
+ # We don't want this error to propagate, escaping the call to
+ # os._exit() and causing very peculiar behavior in the calling
+ # regrtest.py !
+ # Note: could add traceback printing here.
+ debug("An unexpected error was raised.")
+ os._exit(1)
+ else:
+ debug("os.setsid() succeeded! (bad!)")
+ os._exit(2)
+ os._exit(4)
+ else:
+ debug("Waiting for child (%d) to finish." % pid)
+ # In verbose mode, we have to consume the debug output from the
+ # child or the child will block, causing this test to hang in the
+ # parent's waitpid() call. The child blocks after a
+ # platform-dependent amount of data is written to its fd. On
+ # Linux 2.6, it's 4000 bytes and the child won't block, but on OS
+ # X even the small writes in the child above will block it. Also
+ # on Linux, the read() will throw an OSError (input/output error)
+ # when it tries to read past the end of the buffer but the child's
+ # already exited, so catch and discard those exceptions. It's not
+ # worth checking for EIO.
+ while True:
+ try:
+ data = os.read(master_fd, 80)
+ except OSError:
+ break
+ if not data:
+ break
+ sys.stdout.write(data.replace('\r\n', '\n'))
+
+ ##line = os.read(master_fd, 80)
+ ##lines = line.replace('\r\n', '\n').split('\n')
+ ##if False and lines != ['In child, calling os.setsid()',
+ ## 'Good: OSError was raised.', '']:
+ ## raise TestFailed("Unexpected output from child: %r" % line)
+
+ (pid, status) = os.waitpid(pid, 0)
+ res = status >> 8
+ debug("Child (%d) exited with status %d (%d)." % (pid, res, status))
+ if res == 1:
+ self.fail("Child raised an unexpected exception in os.setsid()")
+ elif res == 2:
+ self.fail("pty.fork() failed to make child a session leader.")
+ elif res == 3:
+ self.fail("Child spawned by pty.fork() did not have a tty as stdout")
+ elif res != 4:
+ self.fail("pty.fork() failed for unknown reasons.")
+
+ ##debug("Reading from master_fd now that the child has exited")
+ ##try:
+ ## s1 = os.read(master_fd, 1024)
+ ##except os.error:
+ ## pass
+ ##else:
+ ## raise TestFailed("Read from master_fd did not raise exception")
+
+ os.close(master_fd)
+
+ # pty.fork() passed.
+
+def test_main(verbose=None):
+ run_unittest(PtyTest)
+
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_pyexpat.py b/Lib/test/test_pyexpat.py
index 9772a00..0900d1e 100644
--- a/Lib/test/test_pyexpat.py
+++ b/Lib/test/test_pyexpat.py
@@ -1,108 +1,40 @@
-# Very simple test - Parse a file and print what happens
-
# XXX TypeErrors on calling handlers, or on bad return values from a
# handler, are obscure and unhelpful.
+import StringIO
+import unittest
+
import pyexpat
from xml.parsers import expat
-from test.test_support import sortdict, TestFailed
+from test.test_support import sortdict, run_unittest
-class Outputter:
- def StartElementHandler(self, name, attrs):
- print('Start element:\n\t' + repr(name), sortdict(attrs))
- def EndElementHandler(self, name):
- print('End element:\n\t' + repr(name))
-
- def CharacterDataHandler(self, data):
- data = data.strip()
- if data:
- print('Character data:')
- print('\t' + repr(data))
-
- def ProcessingInstructionHandler(self, target, data):
- print('PI:\n\t' + repr(target), repr(data))
-
- def StartNamespaceDeclHandler(self, prefix, uri):
- print('NS decl:\n\t' + repr(prefix), repr(uri))
-
- def EndNamespaceDeclHandler(self, prefix):
- print('End of NS decl:\n\t' + repr(prefix))
-
- def StartCdataSectionHandler(self):
- print('Start of CDATA section')
-
- def EndCdataSectionHandler(self):
- print('End of CDATA section')
-
- def CommentHandler(self, text):
- print('Comment:\n\t' + repr(text))
-
- def NotationDeclHandler(self, *args):
- name, base, sysid, pubid = args
- print('Notation declared:', args)
-
- def UnparsedEntityDeclHandler(self, *args):
- entityName, base, systemId, publicId, notationName = args
- print('Unparsed entity decl:\n\t' + str(args))
-
- def NotStandaloneHandler(self, userData):
- print('Not standalone')
- return 1
-
- def ExternalEntityRefHandler(self, *args):
- context, base, sysId, pubId = args
- print('External entity ref:', args[1:])
- return 1
-
- def DefaultHandler(self, userData):
- pass
-
- def DefaultHandlerExpand(self, userData):
- pass
-
-
-def confirm(ok):
- if ok:
- print("OK.")
- else:
- print("Not OK.")
-
-out = Outputter()
-parser = expat.ParserCreate(namespace_separator='!')
-
-# Test getting/setting returns_unicode
-parser.returns_unicode = 0; confirm(parser.returns_unicode == 0)
-parser.returns_unicode = 1; confirm(parser.returns_unicode == 1)
-parser.returns_unicode = 2; confirm(parser.returns_unicode == 1)
-parser.returns_unicode = 0; confirm(parser.returns_unicode == 0)
-
-# Test getting/setting ordered_attributes
-parser.ordered_attributes = 0; confirm(parser.ordered_attributes == 0)
-parser.ordered_attributes = 1; confirm(parser.ordered_attributes == 1)
-parser.ordered_attributes = 2; confirm(parser.ordered_attributes == 1)
-parser.ordered_attributes = 0; confirm(parser.ordered_attributes == 0)
-
-# Test getting/setting specified_attributes
-parser.specified_attributes = 0; confirm(parser.specified_attributes == 0)
-parser.specified_attributes = 1; confirm(parser.specified_attributes == 1)
-parser.specified_attributes = 2; confirm(parser.specified_attributes == 1)
-parser.specified_attributes = 0; confirm(parser.specified_attributes == 0)
-
-HANDLER_NAMES = [
- 'StartElementHandler', 'EndElementHandler',
- 'CharacterDataHandler', 'ProcessingInstructionHandler',
- 'UnparsedEntityDeclHandler', 'NotationDeclHandler',
- 'StartNamespaceDeclHandler', 'EndNamespaceDeclHandler',
- 'CommentHandler', 'StartCdataSectionHandler',
- 'EndCdataSectionHandler',
- 'DefaultHandler', 'DefaultHandlerExpand',
- #'NotStandaloneHandler',
- 'ExternalEntityRefHandler'
- ]
-for name in HANDLER_NAMES:
- setattr(parser, name, getattr(out, name))
+class SetAttributeTest(unittest.TestCase):
+ def setUp(self):
+ self.parser = expat.ParserCreate(namespace_separator='!')
+ self.set_get_pairs = [
+ [0, 0],
+ [1, 1],
+ [2, 1],
+ [0, 0],
+ ]
+
+ def test_returns_unicode(self):
+ for x, y in self.set_get_pairs:
+ self.parser.returns_unicode = x
+ self.assertEquals(self.parser.returns_unicode, y)
+
+ def test_ordered_attributes(self):
+ for x, y in self.set_get_pairs:
+ self.parser.ordered_attributes = x
+ self.assertEquals(self.parser.ordered_attributes, y)
+
+ def test_specified_attributes(self):
+ for x, y in self.set_get_pairs:
+ self.parser.specified_attributes = x
+ self.assertEquals(self.parser.specified_attributes, y)
+
data = '''\
<?xml version="1.0" encoding="iso-8859-1" standalone="no"?>
@@ -126,108 +58,228 @@ data = '''\
</root>
'''
+
# Produce UTF-8 output
-parser.returns_unicode = 0
-try:
- parser.Parse(data, 1)
-except expat.error:
- print('** Error', parser.ErrorCode, expat.ErrorString(parser.ErrorCode))
- print('** Line', parser.ErrorLineNumber)
- print('** Column', parser.ErrorColumnNumber)
- print('** Byte', parser.ErrorByteIndex)
-
-# Try the parse again, this time producing Unicode output
-parser = expat.ParserCreate(namespace_separator='!')
-parser.returns_unicode = 1
-
-for name in HANDLER_NAMES:
- setattr(parser, name, getattr(out, name))
-try:
- parser.Parse(data, 1)
-except expat.error:
- print('** Error', parser.ErrorCode, expat.ErrorString(parser.ErrorCode))
- print('** Line', parser.ErrorLineNumber)
- print('** Column', parser.ErrorColumnNumber)
- print('** Byte', parser.ErrorByteIndex)
-
-# Try parsing a file
-parser = expat.ParserCreate(namespace_separator='!')
-parser.returns_unicode = 1
-
-for name in HANDLER_NAMES:
- setattr(parser, name, getattr(out, name))
-import StringIO
-file = StringIO.StringIO(data)
-try:
- parser.ParseFile(file)
-except expat.error:
- print('** Error', parser.ErrorCode, expat.ErrorString(parser.ErrorCode))
- print('** Line', parser.ErrorLineNumber)
- print('** Column', parser.ErrorColumnNumber)
- print('** Byte', parser.ErrorByteIndex)
-
-
-# Tests that make sure we get errors when the namespace_separator value
-# is illegal, and that we don't for good values:
-print()
-print("Testing constructor for proper handling of namespace_separator values:")
-expat.ParserCreate()
-expat.ParserCreate(namespace_separator=None)
-expat.ParserCreate(namespace_separator=' ')
-print("Legal values tested o.k.")
-try:
- expat.ParserCreate(namespace_separator=42)
-except TypeError as e:
- print("Caught expected TypeError:")
- print(e)
-else:
- print("Failed to catch expected TypeError.")
-
-try:
- expat.ParserCreate(namespace_separator='too long')
-except ValueError as e:
- print("Caught expected ValueError:")
- print(e)
-else:
- print("Failed to catch expected ValueError.")
-
-# ParserCreate() needs to accept a namespace_separator of zero length
-# to satisfy the requirements of RDF applications that are required
-# to simply glue together the namespace URI and the localname. Though
-# considered a wart of the RDF specifications, it needs to be supported.
-#
-# See XML-SIG mailing list thread starting with
-# http://mail.python.org/pipermail/xml-sig/2001-April/005202.html
-#
-expat.ParserCreate(namespace_separator='') # too short
-
-# Test the interning machinery.
-p = expat.ParserCreate()
-L = []
-def collector(name, *args):
- L.append(name)
-p.StartElementHandler = collector
-p.EndElementHandler = collector
-p.Parse("<e> <e/> <e></e> </e>", 1)
-tag = L[0]
-if len(L) != 6:
- print("L should only contain 6 entries; found", len(L))
-for entry in L:
- if tag is not entry:
- print("expected L to contain many references to the same string", end=' ')
- print("(it didn't)")
- print("L =", repr(L))
- break
-
-# Tests of the buffer_text attribute.
-import sys
-
-class TextCollector:
- def __init__(self, parser):
+class ParseTest(unittest.TestCase):
+ class Outputter:
+ def __init__(self):
+ self.out = []
+
+ def StartElementHandler(self, name, attrs):
+ self.out.append('Start element: ' + repr(name) + ' ' +
+ sortdict(attrs))
+
+ def EndElementHandler(self, name):
+ self.out.append('End element: ' + repr(name))
+
+ def CharacterDataHandler(self, data):
+ data = data.strip()
+ if data:
+ self.out.append('Character data: ' + repr(data))
+
+ def ProcessingInstructionHandler(self, target, data):
+ self.out.append('PI: ' + repr(target) + ' ' + repr(data))
+
+ def StartNamespaceDeclHandler(self, prefix, uri):
+ self.out.append('NS decl: ' + repr(prefix) + ' ' + repr(uri))
+
+ def EndNamespaceDeclHandler(self, prefix):
+ self.out.append('End of NS decl: ' + repr(prefix))
+
+ def StartCdataSectionHandler(self):
+ self.out.append('Start of CDATA section')
+
+ def EndCdataSectionHandler(self):
+ self.out.append('End of CDATA section')
+
+ def CommentHandler(self, text):
+ self.out.append('Comment: ' + repr(text))
+
+ def NotationDeclHandler(self, *args):
+ name, base, sysid, pubid = args
+ self.out.append('Notation declared: %s' %(args,))
+
+ def UnparsedEntityDeclHandler(self, *args):
+ entityName, base, systemId, publicId, notationName = args
+ self.out.append('Unparsed entity decl: %s' %(args,))
+
+ def NotStandaloneHandler(self, userData):
+ self.out.append('Not standalone')
+ return 1
+
+ def ExternalEntityRefHandler(self, *args):
+ context, base, sysId, pubId = args
+ self.out.append('External entity ref: %s' %(args[1:],))
+ return 1
+
+ def DefaultHandler(self, userData):
+ pass
+
+ def DefaultHandlerExpand(self, userData):
+ pass
+
+ handler_names = [
+ 'StartElementHandler', 'EndElementHandler',
+ 'CharacterDataHandler', 'ProcessingInstructionHandler',
+ 'UnparsedEntityDeclHandler', 'NotationDeclHandler',
+ 'StartNamespaceDeclHandler', 'EndNamespaceDeclHandler',
+ 'CommentHandler', 'StartCdataSectionHandler',
+ 'EndCdataSectionHandler',
+ 'DefaultHandler', 'DefaultHandlerExpand',
+ #'NotStandaloneHandler',
+ 'ExternalEntityRefHandler'
+ ]
+
+ def test_utf8(self):
+
+ out = self.Outputter()
+ parser = expat.ParserCreate(namespace_separator='!')
+ for name in self.handler_names:
+ setattr(parser, name, getattr(out, name))
+ parser.returns_unicode = 0
+ parser.Parse(data, 1)
+
+ # Verify output
+ op = out.out
+ self.assertEquals(op[0], 'PI: \'xml-stylesheet\' \'href="stylesheet.css"\'')
+ self.assertEquals(op[1], "Comment: ' comment data '")
+ self.assertEquals(op[2], "Notation declared: ('notation', None, 'notation.jpeg', None)")
+ self.assertEquals(op[3], "Unparsed entity decl: ('unparsed_entity', None, 'entity.file', None, 'notation')")
+ self.assertEquals(op[4], "Start element: 'root' {'attr1': 'value1', 'attr2': 'value2\\xe1\\xbd\\x80'}")
+ self.assertEquals(op[5], "NS decl: 'myns' 'http://www.python.org/namespace'")
+ self.assertEquals(op[6], "Start element: 'http://www.python.org/namespace!subelement' {}")
+ self.assertEquals(op[7], "Character data: 'Contents of subelements'")
+ self.assertEquals(op[8], "End element: 'http://www.python.org/namespace!subelement'")
+ self.assertEquals(op[9], "End of NS decl: 'myns'")
+ self.assertEquals(op[10], "Start element: 'sub2' {}")
+ self.assertEquals(op[11], 'Start of CDATA section')
+ self.assertEquals(op[12], "Character data: 'contents of CDATA section'")
+ self.assertEquals(op[13], 'End of CDATA section')
+ self.assertEquals(op[14], "End element: 'sub2'")
+ self.assertEquals(op[15], "External entity ref: (None, 'entity.file', None)")
+ self.assertEquals(op[16], "End element: 'root'")
+
+ def test_unicode(self):
+ # Try the parse again, this time producing Unicode output
+ out = self.Outputter()
+ parser = expat.ParserCreate(namespace_separator='!')
+ parser.returns_unicode = 1
+ for name in self.handler_names:
+ setattr(parser, name, getattr(out, name))
+
+ parser.Parse(data, 1)
+
+ op = out.out
+ self.assertEquals(op[0], 'PI: u\'xml-stylesheet\' u\'href="stylesheet.css"\'')
+ self.assertEquals(op[1], "Comment: u' comment data '")
+ self.assertEquals(op[2], "Notation declared: (u'notation', None, u'notation.jpeg', None)")
+ self.assertEquals(op[3], "Unparsed entity decl: (u'unparsed_entity', None, u'entity.file', None, u'notation')")
+ self.assertEquals(op[4], "Start element: u'root' {u'attr1': u'value1', u'attr2': u'value2\\u1f40'}")
+ self.assertEquals(op[5], "NS decl: u'myns' u'http://www.python.org/namespace'")
+ self.assertEquals(op[6], "Start element: u'http://www.python.org/namespace!subelement' {}")
+ self.assertEquals(op[7], "Character data: u'Contents of subelements'")
+ self.assertEquals(op[8], "End element: u'http://www.python.org/namespace!subelement'")
+ self.assertEquals(op[9], "End of NS decl: u'myns'")
+ self.assertEquals(op[10], "Start element: u'sub2' {}")
+ self.assertEquals(op[11], 'Start of CDATA section')
+ self.assertEquals(op[12], "Character data: u'contents of CDATA section'")
+ self.assertEquals(op[13], 'End of CDATA section')
+ self.assertEquals(op[14], "End element: u'sub2'")
+ self.assertEquals(op[15], "External entity ref: (None, u'entity.file', None)")
+ self.assertEquals(op[16], "End element: u'root'")
+
+ def test_parse_file(self):
+ # Try parsing a file
+ out = self.Outputter()
+ parser = expat.ParserCreate(namespace_separator='!')
+ parser.returns_unicode = 1
+ for name in self.handler_names:
+ setattr(parser, name, getattr(out, name))
+ file = StringIO.StringIO(data)
+
+ parser.ParseFile(file)
+
+ op = out.out
+ self.assertEquals(op[0], 'PI: u\'xml-stylesheet\' u\'href="stylesheet.css"\'')
+ self.assertEquals(op[1], "Comment: u' comment data '")
+ self.assertEquals(op[2], "Notation declared: (u'notation', None, u'notation.jpeg', None)")
+ self.assertEquals(op[3], "Unparsed entity decl: (u'unparsed_entity', None, u'entity.file', None, u'notation')")
+ self.assertEquals(op[4], "Start element: u'root' {u'attr1': u'value1', u'attr2': u'value2\\u1f40'}")
+ self.assertEquals(op[5], "NS decl: u'myns' u'http://www.python.org/namespace'")
+ self.assertEquals(op[6], "Start element: u'http://www.python.org/namespace!subelement' {}")
+ self.assertEquals(op[7], "Character data: u'Contents of subelements'")
+ self.assertEquals(op[8], "End element: u'http://www.python.org/namespace!subelement'")
+ self.assertEquals(op[9], "End of NS decl: u'myns'")
+ self.assertEquals(op[10], "Start element: u'sub2' {}")
+ self.assertEquals(op[11], 'Start of CDATA section')
+ self.assertEquals(op[12], "Character data: u'contents of CDATA section'")
+ self.assertEquals(op[13], 'End of CDATA section')
+ self.assertEquals(op[14], "End element: u'sub2'")
+ self.assertEquals(op[15], "External entity ref: (None, u'entity.file', None)")
+ self.assertEquals(op[16], "End element: u'root'")
+
+
+class NamespaceSeparatorTest(unittest.TestCase):
+ def test_legal(self):
+ # Tests that make sure we get errors when the namespace_separator value
+ # is illegal, and that we don't for good values:
+ expat.ParserCreate()
+ expat.ParserCreate(namespace_separator=None)
+ expat.ParserCreate(namespace_separator=' ')
+
+ def test_illegal(self):
+ try:
+ expat.ParserCreate(namespace_separator=42)
+ self.fail()
+ except TypeError as e:
+ self.assertEquals(str(e),
+ 'ParserCreate() argument 2 must be string or None, not int')
+
+ try:
+ expat.ParserCreate(namespace_separator='too long')
+ self.fail()
+ except ValueError as e:
+ self.assertEquals(str(e),
+ 'namespace_separator must be at most one character, omitted, or None')
+
+ def test_zero_length(self):
+ # ParserCreate() needs to accept a namespace_separator of zero length
+ # to satisfy the requirements of RDF applications that are required
+ # to simply glue together the namespace URI and the localname. Though
+ # considered a wart of the RDF specifications, it needs to be supported.
+ #
+ # See XML-SIG mailing list thread starting with
+ # http://mail.python.org/pipermail/xml-sig/2001-April/005202.html
+ #
+ expat.ParserCreate(namespace_separator='') # too short
+
+
+class InterningTest(unittest.TestCase):
+ def test(self):
+ # Test the interning machinery.
+ p = expat.ParserCreate()
+ L = []
+ def collector(name, *args):
+ L.append(name)
+ p.StartElementHandler = collector
+ p.EndElementHandler = collector
+ p.Parse("<e> <e/> <e></e> </e>", 1)
+ tag = L[0]
+ self.assertEquals(len(L), 6)
+ for entry in L:
+ # L should have the same string repeated over and over.
+ self.assertTrue(tag is entry)
+
+
+class BufferTextTest(unittest.TestCase):
+ def setUp(self):
self.stuff = []
+ self.parser = expat.ParserCreate()
+ self.parser.buffer_text = 1
+ self.parser.CharacterDataHandler = self.CharacterDataHandler
def check(self, expected, label):
- require(self.stuff == expected,
+ self.assertEquals(self.stuff, expected,
"%s\nstuff = %r\nexpected = %r"
% (label, self.stuff, map(unicode, expected)))
@@ -238,9 +290,9 @@ class TextCollector:
self.stuff.append("<%s>" % name)
bt = attrs.get("buffer-text")
if bt == "yes":
- parser.buffer_text = 1
+ self.parser.buffer_text = 1
elif bt == "no":
- parser.buffer_text = 0
+ self.parser.buffer_text = 0
def EndElementHandler(self, name):
self.stuff.append("</%s>" % name)
@@ -248,95 +300,91 @@ class TextCollector:
def CommentHandler(self, data):
self.stuff.append("<!--%s-->" % data)
-def require(cond, label):
- # similar to confirm(), but no extraneous output
- if not cond:
- raise TestFailed(label)
-
-def setup(handlers=[]):
- parser = expat.ParserCreate()
- require(not parser.buffer_text,
- "buffer_text not disabled by default")
- parser.buffer_text = 1
- handler = TextCollector(parser)
- parser.CharacterDataHandler = handler.CharacterDataHandler
- for name in handlers:
- setattr(parser, name, getattr(handler, name))
- return parser, handler
-
-parser, handler = setup()
-require(parser.buffer_text,
- "text buffering either not acknowledged or not enabled")
-parser.Parse("<a>1<b/>2<c/>3</a>", 1)
-handler.check(["123"],
- "buffered text not properly collapsed")
-
-# XXX This test exposes more detail of Expat's text chunking than we
-# XXX like, but it tests what we need to concisely.
-parser, handler = setup(["StartElementHandler"])
-parser.Parse("<a>1<b buffer-text='no'/>2\n3<c buffer-text='yes'/>4\n5</a>", 1)
-handler.check(["<a>", "1", "<b>", "2", "\n", "3", "<c>", "4\n5"],
- "buffering control not reacting as expected")
-
-parser, handler = setup()
-parser.Parse("<a>1<b/>&lt;2&gt;<c/>&#32;\n&#x20;3</a>", 1)
-handler.check(["1<2> \n 3"],
- "buffered text not properly collapsed")
-
-parser, handler = setup(["StartElementHandler"])
-parser.Parse("<a>1<b/>2<c/>3</a>", 1)
-handler.check(["<a>", "1", "<b>", "2", "<c>", "3"],
- "buffered text not properly split")
-
-parser, handler = setup(["StartElementHandler", "EndElementHandler"])
-parser.CharacterDataHandler = None
-parser.Parse("<a>1<b/>2<c/>3</a>", 1)
-handler.check(["<a>", "<b>", "</b>", "<c>", "</c>", "</a>"],
- "huh?")
-
-parser, handler = setup(["StartElementHandler", "EndElementHandler"])
-parser.Parse("<a>1<b></b>2<c/>3</a>", 1)
-handler.check(["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3", "</a>"],
- "huh?")
-
-parser, handler = setup(["CommentHandler", "EndElementHandler",
- "StartElementHandler"])
-parser.Parse("<a>1<b/>2<c></c>345</a> ", 1)
-handler.check(["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "345", "</a>"],
- "buffered text not properly split")
-
-parser, handler = setup(["CommentHandler", "EndElementHandler",
- "StartElementHandler"])
-parser.Parse("<a>1<b/>2<c></c>3<!--abc-->4<!--def-->5</a> ", 1)
-handler.check(["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3",
- "<!--abc-->", "4", "<!--def-->", "5", "</a>"],
- "buffered text not properly split")
+ def setHandlers(self, handlers=[]):
+ for name in handlers:
+ setattr(self.parser, name, getattr(self, name))
+
+ def test_default_to_disabled(self):
+ parser = expat.ParserCreate()
+ self.assertFalse(parser.buffer_text)
+
+ def test_buffering_enabled(self):
+ # Make sure buffering is turned on
+ self.assertTrue(self.parser.buffer_text)
+ self.parser.Parse("<a>1<b/>2<c/>3</a>", 1)
+ self.assertEquals(self.stuff, ['123'],
+ "buffered text not properly collapsed")
+
+ def test1(self):
+ # XXX This test exposes more detail of Expat's text chunking than we
+ # XXX like, but it tests what we need to concisely.
+ self.setHandlers(["StartElementHandler"])
+ self.parser.Parse("<a>1<b buffer-text='no'/>2\n3<c buffer-text='yes'/>4\n5</a>", 1)
+ self.assertEquals(self.stuff,
+ ["<a>", "1", "<b>", "2", "\n", "3", "<c>", "4\n5"],
+ "buffering control not reacting as expected")
+
+ def test2(self):
+ self.parser.Parse("<a>1<b/>&lt;2&gt;<c/>&#32;\n&#x20;3</a>", 1)
+ self.assertEquals(self.stuff, ["1<2> \n 3"],
+ "buffered text not properly collapsed")
+
+ def test3(self):
+ self.setHandlers(["StartElementHandler"])
+ self.parser.Parse("<a>1<b/>2<c/>3</a>", 1)
+ self.assertEquals(self.stuff, ["<a>", "1", "<b>", "2", "<c>", "3"],
+ "buffered text not properly split")
+
+ def test4(self):
+ self.setHandlers(["StartElementHandler", "EndElementHandler"])
+ self.parser.CharacterDataHandler = None
+ self.parser.Parse("<a>1<b/>2<c/>3</a>", 1)
+ self.assertEquals(self.stuff,
+ ["<a>", "<b>", "</b>", "<c>", "</c>", "</a>"])
+
+ def test5(self):
+ self.setHandlers(["StartElementHandler", "EndElementHandler"])
+ self.parser.Parse("<a>1<b></b>2<c/>3</a>", 1)
+ self.assertEquals(self.stuff,
+ ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3", "</a>"])
+
+ def test6(self):
+ self.setHandlers(["CommentHandler", "EndElementHandler",
+ "StartElementHandler"])
+ self.parser.Parse("<a>1<b/>2<c></c>345</a> ", 1)
+ self.assertEquals(self.stuff,
+ ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "345", "</a>"],
+ "buffered text not properly split")
+
+ def test7(self):
+ self.setHandlers(["CommentHandler", "EndElementHandler",
+ "StartElementHandler"])
+ self.parser.Parse("<a>1<b/>2<c></c>3<!--abc-->4<!--def-->5</a> ", 1)
+ self.assertEquals(self.stuff,
+ ["<a>", "1", "<b>", "</b>", "2", "<c>", "</c>", "3",
+ "<!--abc-->", "4", "<!--def-->", "5", "</a>"],
+ "buffered text not properly split")
+
# Test handling of exception from callback:
-def StartElementHandler(name, attrs):
- raise RuntimeError(name)
+class HandlerExceptionTest(unittest.TestCase):
+ def StartElementHandler(self, name, attrs):
+ raise RuntimeError(name)
-parser = expat.ParserCreate()
-parser.StartElementHandler = StartElementHandler
+ def test(self):
+ parser = expat.ParserCreate()
+ parser.StartElementHandler = self.StartElementHandler
+ try:
+ parser.Parse("<a><b><c/></b></a>", 1)
+ self.fail()
+ except RuntimeError as e:
+ self.assertEquals(e.args[0], 'a',
+ "Expected RuntimeError for element 'a', but" + \
+ " found %r" % e.args[0])
-try:
- parser.Parse("<a><b><c/></b></a>", 1)
-except RuntimeError as e:
- if e.args[0] != "a":
- print("Expected RuntimeError for element 'a'; found %r" % e.args[0])
-else:
- print("Expected RuntimeError for 'a'")
# Test Current* members:
-class PositionTest:
-
- def __init__(self, expected_list, parser):
- self.parser = parser
- self.parser.StartElementHandler = self.StartElementHandler
- self.parser.EndElementHandler = self.EndElementHandler
- self.expected_list = expected_list
- self.upto = 0
-
+class PositionTest(unittest.TestCase):
def StartElementHandler(self, name, attrs):
self.check_pos('s')
@@ -348,41 +396,54 @@ class PositionTest:
self.parser.CurrentByteIndex,
self.parser.CurrentLineNumber,
self.parser.CurrentColumnNumber)
- require(self.upto < len(self.expected_list),
- 'too many parser events')
+ self.assertTrue(self.upto < len(self.expected_list),
+ 'too many parser events')
expected = self.expected_list[self.upto]
- require(pos == expected,
- 'expected position %s, got %s' % (expected, pos))
+ self.assertEquals(pos, expected,
+ 'Expected position %s, got position %s' %(pos, expected))
self.upto += 1
+ def test(self):
+ self.parser = expat.ParserCreate()
+ self.parser.StartElementHandler = self.StartElementHandler
+ self.parser.EndElementHandler = self.EndElementHandler
+ self.upto = 0
+ self.expected_list = [('s', 0, 1, 0), ('s', 5, 2, 1), ('s', 11, 3, 2),
+ ('e', 15, 3, 6), ('e', 17, 4, 1), ('e', 22, 5, 0)]
+
+ xml = '<a>\n <b>\n <c/>\n </b>\n</a>'
+ self.parser.Parse(xml, 1)
+
+
+class sf1296433Test(unittest.TestCase):
+ def test_parse_only_xml_data(self):
+ # http://python.org/sf/1296433
+ #
+ xml = "<?xml version='1.0' encoding='iso8859'?><s>%s</s>" % ('a' * 1025)
+ # this one doesn't crash
+ #xml = "<?xml version='1.0'?><s>%s</s>" % ('a' * 10000)
-parser = expat.ParserCreate()
-handler = PositionTest([('s', 0, 1, 0), ('s', 5, 2, 1), ('s', 11, 3, 2),
- ('e', 15, 3, 6), ('e', 17, 4, 1), ('e', 22, 5, 0)],
- parser)
-parser.Parse('''<a>
- <b>
- <c/>
- </b>
-</a>''', 1)
+ class SpecificException(Exception):
+ pass
+ def handler(text):
+ raise SpecificException
-def test_parse_only_xml_data():
- # http://python.org/sf/1296433
- #
- xml = "<?xml version='1.0' encoding='iso8859'?><s>%s</s>" % ('a' * 1025)
- # this one doesn't crash
- #xml = "<?xml version='1.0'?><s>%s</s>" % ('a' * 10000)
+ parser = expat.ParserCreate()
+ parser.CharacterDataHandler = handler
- def handler(text):
- raise Exception
+ self.assertRaises(Exception, parser.Parse, xml)
- parser = expat.ParserCreate()
- parser.CharacterDataHandler = handler
- try:
- parser.Parse(xml)
- except:
- pass
+def test_main():
+ run_unittest(SetAttributeTest,
+ ParseTest,
+ NamespaceSeparatorTest,
+ InterningTest,
+ BufferTextTest,
+ HandlerExceptionTest,
+ PositionTest,
+ sf1296433Test)
-test_parse_only_xml_data()
+if __name__ == "__main__":
+ test_main()
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index bca2ecf..13fa413 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -1,7 +1,7 @@
import sys
sys.path = ['.'] + sys.path
-from test.test_support import verbose, run_unittest
+from test.test_support import verbose, run_unittest, guard_warnings_filter
import re
from re import Scanner
import sys, os, traceback
@@ -418,6 +418,12 @@ class ReTests(unittest.TestCase):
pass # cPickle not found -- skip it
else:
self.pickle_test(cPickle)
+ # old pickles expect the _compile() reconstructor in sre module
+ import warnings
+ with guard_warnings_filter():
+ warnings.filterwarnings("ignore", "The sre module is deprecated",
+ DeprecationWarning)
+ from sre import _compile
def pickle_test(self, pickle):
oldpat = re.compile('a(?:b|(c|e){1,2}?|d)+?(.)')
@@ -599,6 +605,13 @@ class ReTests(unittest.TestCase):
self.assertEqual(next(iter).span(), (4, 4))
self.assertRaises(StopIteration, next, iter)
+ def test_empty_array(self):
+ # SF buf 1647541
+ import array
+ for typecode in 'cbBuhHiIlLfd':
+ a = array.array(typecode)
+ self.assertEqual(re.compile("bla").match(a), None)
+ self.assertEqual(re.compile("").match(a).groups(), ())
def run_re_tests():
from test.re_tests import benchmarks, tests, SUCCEED, FAIL, SYNTAX_ERROR
diff --git a/Lib/test/test_robotparser.py b/Lib/test/test_robotparser.py
index 6a23b22..666d00a 100644
--- a/Lib/test/test_robotparser.py
+++ b/Lib/test/test_robotparser.py
@@ -135,8 +135,8 @@ bad = [] # Bug report says "/" should be denied, but that is not in the RFC
RobotTest(7, doc, good, bad)
def test_main():
- test_support.run_suite(tests)
+ test_support.run_unittest(tests)
if __name__=='__main__':
test_support.Verbose = 1
- test_support.run_suite(tests)
+ test_main()
diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py
index 5b071dd..0d76cbb 100644
--- a/Lib/test/test_sax.py
+++ b/Lib/test/test_sax.py
@@ -13,26 +13,66 @@ from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
from xml.sax.expatreader import create_parser
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
from cStringIO import StringIO
-from test.test_support import verify, verbose, TestFailed, findfile
+from test.test_support import findfile, run_unittest
+import unittest
import os
-# ===== Utilities
-
-tests = 0
-failures = []
-
-def confirm(outcome, name):
- global tests
-
- tests = tests + 1
- if outcome:
- if verbose:
- print("Passed", name)
- else:
- failures.append(name)
+ns_uri = "http://www.python.org/xml-ns/saxtest/"
-def test_make_parser2():
- try:
+class XmlTestBase(unittest.TestCase):
+ def verify_empty_attrs(self, attrs):
+ self.assertRaises(KeyError, attrs.getValue, "attr")
+ self.assertRaises(KeyError, attrs.getValueByQName, "attr")
+ self.assertRaises(KeyError, attrs.getNameByQName, "attr")
+ self.assertRaises(KeyError, attrs.getQNameByName, "attr")
+ self.assertRaises(KeyError, attrs.__getitem__, "attr")
+ self.assertEquals(attrs.getLength(), 0)
+ self.assertEquals(attrs.getNames(), [])
+ self.assertEquals(attrs.getQNames(), [])
+ self.assertEquals(len(attrs), 0)
+ self.assertFalse("attr" in attrs)
+ self.assertEquals(list(attrs.keys()), [])
+ self.assertEquals(attrs.get("attrs"), None)
+ self.assertEquals(attrs.get("attrs", 25), 25)
+ self.assertEquals(list(attrs.items()), [])
+ self.assertEquals(list(attrs.values()), [])
+
+ def verify_empty_nsattrs(self, attrs):
+ self.assertRaises(KeyError, attrs.getValue, (ns_uri, "attr"))
+ self.assertRaises(KeyError, attrs.getValueByQName, "ns:attr")
+ self.assertRaises(KeyError, attrs.getNameByQName, "ns:attr")
+ self.assertRaises(KeyError, attrs.getQNameByName, (ns_uri, "attr"))
+ self.assertRaises(KeyError, attrs.__getitem__, (ns_uri, "attr"))
+ self.assertEquals(attrs.getLength(), 0)
+ self.assertEquals(attrs.getNames(), [])
+ self.assertEquals(attrs.getQNames(), [])
+ self.assertEquals(len(attrs), 0)
+ self.assertFalse((ns_uri, "attr") in attrs)
+ self.assertEquals(list(attrs.keys()), [])
+ self.assertEquals(attrs.get((ns_uri, "attr")), None)
+ self.assertEquals(attrs.get((ns_uri, "attr"), 25), 25)
+ self.assertEquals(list(attrs.items()), [])
+ self.assertEquals(list(attrs.values()), [])
+
+ def verify_attrs_wattr(self, attrs):
+ self.assertEquals(attrs.getLength(), 1)
+ self.assertEquals(attrs.getNames(), ["attr"])
+ self.assertEquals(attrs.getQNames(), ["attr"])
+ self.assertEquals(len(attrs), 1)
+ self.assertTrue("attr" in attrs)
+ self.assertEquals(list(attrs.keys()), ["attr"])
+ self.assertEquals(attrs.get("attr"), "val")
+ self.assertEquals(attrs.get("attr", 25), "val")
+ self.assertEquals(list(attrs.items()), [("attr", "val")])
+ self.assertEquals(list(attrs.values()), ["val"])
+ self.assertEquals(attrs.getValue("attr"), "val")
+ self.assertEquals(attrs.getValueByQName("attr"), "val")
+ self.assertEquals(attrs.getNameByQName("attr"), "attr")
+ self.assertEquals(attrs["attr"], "val")
+ self.assertEquals(attrs.getQNameByName("attr"), "attr")
+
+class MakeParserTest(unittest.TestCase):
+ def test_make_parser2(self):
# Creating parsers several times in a row should succeed.
# Testing this because there have been failures of this kind
# before.
@@ -48,10 +88,6 @@ def test_make_parser2():
p = make_parser()
from xml.sax import make_parser
p = make_parser()
- except:
- return 0
- else:
- return p
# ===========================================================================
@@ -60,215 +96,214 @@ def test_make_parser2():
#
# ===========================================================================
-# ===== escape
-
-def test_escape_basic():
- return escape("Donald Duck & Co") == "Donald Duck &amp; Co"
+class SaxutilsTest(unittest.TestCase):
+ # ===== escape
+ def test_escape_basic(self):
+ self.assertEquals(escape("Donald Duck & Co"), "Donald Duck &amp; Co")
-def test_escape_all():
- return escape("<Donald Duck & Co>") == "&lt;Donald Duck &amp; Co&gt;"
+ def test_escape_all(self):
+ self.assertEquals(escape("<Donald Duck & Co>"),
+ "&lt;Donald Duck &amp; Co&gt;")
-def test_escape_extra():
- return escape("Hei på deg", {"å" : "&aring;"}) == "Hei p&aring; deg"
+ def test_escape_extra(self):
+ self.assertEquals(escape("Hei på deg", {"å" : "&aring;"}),
+ "Hei p&aring; deg")
-# ===== unescape
+ # ===== unescape
+ def test_unescape_basic(self):
+ self.assertEquals(unescape("Donald Duck &amp; Co"), "Donald Duck & Co")
-def test_unescape_basic():
- return unescape("Donald Duck &amp; Co") == "Donald Duck & Co"
+ def test_unescape_all(self):
+ self.assertEquals(unescape("&lt;Donald Duck &amp; Co&gt;"),
+ "<Donald Duck & Co>")
-def test_unescape_all():
- return unescape("&lt;Donald Duck &amp; Co&gt;") == "<Donald Duck & Co>"
+ def test_unescape_extra(self):
+ self.assertEquals(unescape("Hei på deg", {"å" : "&aring;"}),
+ "Hei p&aring; deg")
-def test_unescape_extra():
- return unescape("Hei på deg", {"å" : "&aring;"}) == "Hei p&aring; deg"
+ def test_unescape_amp_extra(self):
+ self.assertEquals(unescape("&amp;foo;", {"&foo;": "splat"}), "&foo;")
-def test_unescape_amp_extra():
- return unescape("&amp;foo;", {"&foo;": "splat"}) == "&foo;"
+ # ===== quoteattr
+ def test_quoteattr_basic(self):
+ self.assertEquals(quoteattr("Donald Duck & Co"),
+ '"Donald Duck &amp; Co"')
-# ===== quoteattr
+ def test_single_quoteattr(self):
+ self.assertEquals(quoteattr('Includes "double" quotes'),
+ '\'Includes "double" quotes\'')
-def test_quoteattr_basic():
- return quoteattr("Donald Duck & Co") == '"Donald Duck &amp; Co"'
+ def test_double_quoteattr(self):
+ self.assertEquals(quoteattr("Includes 'single' quotes"),
+ "\"Includes 'single' quotes\"")
-def test_single_quoteattr():
- return (quoteattr('Includes "double" quotes')
- == '\'Includes "double" quotes\'')
+ def test_single_double_quoteattr(self):
+ self.assertEquals(quoteattr("Includes 'single' and \"double\" quotes"),
+ "\"Includes 'single' and &quot;double&quot; quotes\"")
-def test_double_quoteattr():
- return (quoteattr("Includes 'single' quotes")
- == "\"Includes 'single' quotes\"")
-
-def test_single_double_quoteattr():
- return (quoteattr("Includes 'single' and \"double\" quotes")
- == "\"Includes 'single' and &quot;double&quot; quotes\"")
-
-# ===== make_parser
-
-def test_make_parser():
- try:
+ # ===== make_parser
+ def test_make_parser(self):
# Creating a parser should succeed - it should fall back
# to the expatreader
p = make_parser(['xml.parsers.no_such_parser'])
- except:
- return 0
- else:
- return p
# ===== XMLGenerator
start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'
-def test_xmlgen_basic():
- result = StringIO()
- gen = XMLGenerator(result)
- gen.startDocument()
- gen.startElement("doc", {})
- gen.endElement("doc")
- gen.endDocument()
-
- return result.getvalue() == start + "<doc></doc>"
-
-def test_xmlgen_content():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startElement("doc", {})
- gen.characters("huhei")
- gen.endElement("doc")
- gen.endDocument()
-
- return result.getvalue() == start + "<doc>huhei</doc>"
-
-def test_xmlgen_pi():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.processingInstruction("test", "data")
- gen.startElement("doc", {})
- gen.endElement("doc")
- gen.endDocument()
-
- return result.getvalue() == start + "<?test data?><doc></doc>"
-
-def test_xmlgen_content_escape():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startElement("doc", {})
- gen.characters("<huhei&")
- gen.endElement("doc")
- gen.endDocument()
-
- return result.getvalue() == start + "<doc>&lt;huhei&amp;</doc>"
-
-def test_xmlgen_attr_escape():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startElement("doc", {"a": '"'})
- gen.startElement("e", {"a": "'"})
- gen.endElement("e")
- gen.startElement("e", {"a": "'\""})
- gen.endElement("e")
- gen.startElement("e", {"a": "\n\r\t"})
- gen.endElement("e")
- gen.endElement("doc")
- gen.endDocument()
-
- return result.getvalue() == start + ("<doc a='\"'><e a=\"'\"></e>"
- "<e a=\"'&quot;\"></e>"
- "<e a=\"&#10;&#13;&#9;\"></e></doc>")
-
-def test_xmlgen_ignorable():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startElement("doc", {})
- gen.ignorableWhitespace(" ")
- gen.endElement("doc")
- gen.endDocument()
-
- return result.getvalue() == start + "<doc> </doc>"
+class XmlgenTest(unittest.TestCase):
+ def test_xmlgen_basic(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+ gen.startDocument()
+ gen.startElement("doc", {})
+ gen.endElement("doc")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start + "<doc></doc>")
+
+ def test_xmlgen_content(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startElement("doc", {})
+ gen.characters("huhei")
+ gen.endElement("doc")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start + "<doc>huhei</doc>")
+
+ def test_xmlgen_pi(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.processingInstruction("test", "data")
+ gen.startElement("doc", {})
+ gen.endElement("doc")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start + "<?test data?><doc></doc>")
+
+ def test_xmlgen_content_escape(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startElement("doc", {})
+ gen.characters("<huhei&")
+ gen.endElement("doc")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(),
+ start + "<doc>&lt;huhei&amp;</doc>")
+
+ def test_xmlgen_attr_escape(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startElement("doc", {"a": '"'})
+ gen.startElement("e", {"a": "'"})
+ gen.endElement("e")
+ gen.startElement("e", {"a": "'\""})
+ gen.endElement("e")
+ gen.startElement("e", {"a": "\n\r\t"})
+ gen.endElement("e")
+ gen.endElement("doc")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start +
+ ("<doc a='\"'><e a=\"'\"></e>"
+ "<e a=\"'&quot;\"></e>"
+ "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
+
+ def test_xmlgen_ignorable(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startElement("doc", {})
+ gen.ignorableWhitespace(" ")
+ gen.endElement("doc")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start + "<doc> </doc>")
+
+ def test_xmlgen_ns(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startPrefixMapping("ns1", ns_uri)
+ gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
+ # add an unqualified name
+ gen.startElementNS((None, "udoc"), None, {})
+ gen.endElementNS((None, "udoc"), None)
+ gen.endElementNS((ns_uri, "doc"), "ns1:doc")
+ gen.endPrefixMapping("ns1")
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start + \
+ ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
+ ns_uri))
-ns_uri = "http://www.python.org/xml-ns/saxtest/"
+ def test_1463026_1(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
-def test_xmlgen_ns():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startPrefixMapping("ns1", ns_uri)
- gen.startElementNS((ns_uri, "doc"), "ns1:doc", {})
- # add an unqualified name
- gen.startElementNS((None, "udoc"), None, {})
- gen.endElementNS((None, "udoc"), None)
- gen.endElementNS((ns_uri, "doc"), "ns1:doc")
- gen.endPrefixMapping("ns1")
- gen.endDocument()
-
- return result.getvalue() == start + \
- ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
- ns_uri)
-
-def test_1463026_1():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
- gen.endElementNS((None, 'a'), 'a')
- gen.endDocument()
-
- return result.getvalue() == start+'<a b="c"></a>'
-
-def test_1463026_2():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startPrefixMapping(None, 'qux')
- gen.startElementNS(('qux', 'a'), 'a', {})
- gen.endElementNS(('qux', 'a'), 'a')
- gen.endPrefixMapping(None)
- gen.endDocument()
-
- return result.getvalue() == start+'<a xmlns="qux"></a>'
-
-def test_1463026_3():
- result = StringIO()
- gen = XMLGenerator(result)
-
- gen.startDocument()
- gen.startPrefixMapping('my', 'qux')
- gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
- gen.endElementNS(('qux', 'a'), 'a')
- gen.endPrefixMapping('my')
- gen.endDocument()
-
- return result.getvalue() == start+'<my:a xmlns:my="qux" b="c"></my:a>'
-
-# ===== Xmlfilterbase
-
-def test_filter_basic():
- result = StringIO()
- gen = XMLGenerator(result)
- filter = XMLFilterBase()
- filter.setContentHandler(gen)
-
- filter.startDocument()
- filter.startElement("doc", {})
- filter.characters("content")
- filter.ignorableWhitespace(" ")
- filter.endElement("doc")
- filter.endDocument()
-
- return result.getvalue() == start + "<doc>content </doc>"
+ gen.startDocument()
+ gen.startElementNS((None, 'a'), 'a', {(None, 'b'):'c'})
+ gen.endElementNS((None, 'a'), 'a')
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start+'<a b="c"></a>')
+
+ def test_1463026_2(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startPrefixMapping(None, 'qux')
+ gen.startElementNS(('qux', 'a'), 'a', {})
+ gen.endElementNS(('qux', 'a'), 'a')
+ gen.endPrefixMapping(None)
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(), start+'<a xmlns="qux"></a>')
+
+ def test_1463026_3(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+
+ gen.startDocument()
+ gen.startPrefixMapping('my', 'qux')
+ gen.startElementNS(('qux', 'a'), 'a', {(None, 'b'):'c'})
+ gen.endElementNS(('qux', 'a'), 'a')
+ gen.endPrefixMapping('my')
+ gen.endDocument()
+
+ self.assertEquals(result.getvalue(),
+ start+'<my:a xmlns:my="qux" b="c"></my:a>')
+
+
+class XMLFilterBaseTest(unittest.TestCase):
+ def test_filter_basic(self):
+ result = StringIO()
+ gen = XMLGenerator(result)
+ filter = XMLFilterBase()
+ filter.setContentHandler(gen)
+
+ filter.startDocument()
+ filter.startElement("doc", {})
+ filter.characters("content")
+ filter.ignorableWhitespace(" ")
+ filter.endElement("doc")
+ filter.endDocument()
+
+ self.assertEquals(result.getvalue(), start + "<doc>content </doc>")
# ===========================================================================
#
@@ -276,229 +311,233 @@ def test_filter_basic():
#
# ===========================================================================
-# ===== XMLReader support
+xml_test_out = open(findfile("test"+os.extsep+"xml"+os.extsep+"out")).read()
-def test_expat_file():
- parser = create_parser()
- result = StringIO()
- xmlgen = XMLGenerator(result)
+class ExpatReaderTest(XmlTestBase):
- parser.setContentHandler(xmlgen)
- parser.parse(open(findfile("test"+os.extsep+"xml")))
+ # ===== XMLReader support
- return result.getvalue() == xml_test_out
+ def test_expat_file(self):
+ parser = create_parser()
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
-# ===== DTDHandler support
+ parser.setContentHandler(xmlgen)
+ parser.parse(open(findfile("test"+os.extsep+"xml")))
-class TestDTDHandler:
+ self.assertEquals(result.getvalue(), xml_test_out)
- def __init__(self):
- self._notations = []
- self._entities = []
+ # ===== DTDHandler support
- def notationDecl(self, name, publicId, systemId):
- self._notations.append((name, publicId, systemId))
+ class TestDTDHandler:
- def unparsedEntityDecl(self, name, publicId, systemId, ndata):
- self._entities.append((name, publicId, systemId, ndata))
+ def __init__(self):
+ self._notations = []
+ self._entities = []
-def test_expat_dtdhandler():
- parser = create_parser()
- handler = TestDTDHandler()
- parser.setDTDHandler(handler)
+ def notationDecl(self, name, publicId, systemId):
+ self._notations.append((name, publicId, systemId))
- parser.feed('<!DOCTYPE doc [\n')
- parser.feed(' <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
- parser.feed(' <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
- parser.feed(']>\n')
- parser.feed('<doc></doc>')
- parser.close()
+ def unparsedEntityDecl(self, name, publicId, systemId, ndata):
+ self._entities.append((name, publicId, systemId, ndata))
- return handler._notations == [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)] and \
- handler._entities == [("img", None, "expat.gif", "GIF")]
+ def test_expat_dtdhandler(self):
+ parser = create_parser()
+ handler = self.TestDTDHandler()
+ parser.setDTDHandler(handler)
-# ===== EntityResolver support
+ parser.feed('<!DOCTYPE doc [\n')
+ parser.feed(' <!ENTITY img SYSTEM "expat.gif" NDATA GIF>\n')
+ parser.feed(' <!NOTATION GIF PUBLIC "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN">\n')
+ parser.feed(']>\n')
+ parser.feed('<doc></doc>')
+ parser.close()
-class TestEntityResolver:
+ self.assertEquals(handler._notations,
+ [("GIF", "-//CompuServe//NOTATION Graphics Interchange Format 89a//EN", None)])
+ self.assertEquals(handler._entities, [("img", None, "expat.gif", "GIF")])
- def resolveEntity(self, publicId, systemId):
- inpsrc = InputSource()
- inpsrc.setByteStream(StringIO("<entity/>"))
- return inpsrc
+ # ===== EntityResolver support
-def test_expat_entityresolver():
- parser = create_parser()
- parser.setEntityResolver(TestEntityResolver())
- result = StringIO()
- parser.setContentHandler(XMLGenerator(result))
+ class TestEntityResolver:
- parser.feed('<!DOCTYPE doc [\n')
- parser.feed(' <!ENTITY test SYSTEM "whatever">\n')
- parser.feed(']>\n')
- parser.feed('<doc>&test;</doc>')
- parser.close()
+ def resolveEntity(self, publicId, systemId):
+ inpsrc = InputSource()
+ inpsrc.setByteStream(StringIO("<entity/>"))
+ return inpsrc
- return result.getvalue() == start + "<doc><entity></entity></doc>"
+ def test_expat_entityresolver(self):
+ parser = create_parser()
+ parser.setEntityResolver(self.TestEntityResolver())
+ result = StringIO()
+ parser.setContentHandler(XMLGenerator(result))
-# ===== Attributes support
+ parser.feed('<!DOCTYPE doc [\n')
+ parser.feed(' <!ENTITY test SYSTEM "whatever">\n')
+ parser.feed(']>\n')
+ parser.feed('<doc>&test;</doc>')
+ parser.close()
-class AttrGatherer(ContentHandler):
+ self.assertEquals(result.getvalue(), start +
+ "<doc><entity></entity></doc>")
- def startElement(self, name, attrs):
- self._attrs = attrs
+ # ===== Attributes support
- def startElementNS(self, name, qname, attrs):
- self._attrs = attrs
+ class AttrGatherer(ContentHandler):
-def test_expat_attrs_empty():
- parser = create_parser()
- gather = AttrGatherer()
- parser.setContentHandler(gather)
+ def startElement(self, name, attrs):
+ self._attrs = attrs
- parser.feed("<doc/>")
- parser.close()
+ def startElementNS(self, name, qname, attrs):
+ self._attrs = attrs
- return verify_empty_attrs(gather._attrs)
+ def test_expat_attrs_empty(self):
+ parser = create_parser()
+ gather = self.AttrGatherer()
+ parser.setContentHandler(gather)
-def test_expat_attrs_wattr():
- parser = create_parser()
- gather = AttrGatherer()
- parser.setContentHandler(gather)
+ parser.feed("<doc/>")
+ parser.close()
- parser.feed("<doc attr='val'/>")
- parser.close()
+ self.verify_empty_attrs(gather._attrs)
- return verify_attrs_wattr(gather._attrs)
+ def test_expat_attrs_wattr(self):
+ parser = create_parser()
+ gather = self.AttrGatherer()
+ parser.setContentHandler(gather)
-def test_expat_nsattrs_empty():
- parser = create_parser(1)
- gather = AttrGatherer()
- parser.setContentHandler(gather)
+ parser.feed("<doc attr='val'/>")
+ parser.close()
- parser.feed("<doc/>")
- parser.close()
+ self.verify_attrs_wattr(gather._attrs)
- return verify_empty_nsattrs(gather._attrs)
+ def test_expat_nsattrs_empty(self):
+ parser = create_parser(1)
+ gather = self.AttrGatherer()
+ parser.setContentHandler(gather)
-def test_expat_nsattrs_wattr():
- parser = create_parser(1)
- gather = AttrGatherer()
- parser.setContentHandler(gather)
+ parser.feed("<doc/>")
+ parser.close()
- parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
- parser.close()
+ self.verify_empty_nsattrs(gather._attrs)
- attrs = gather._attrs
+ def test_expat_nsattrs_wattr(self):
+ parser = create_parser(1)
+ gather = self.AttrGatherer()
+ parser.setContentHandler(gather)
- return attrs.getLength() == 1 and \
- attrs.getNames() == [(ns_uri, "attr")] and \
- (attrs.getQNames() == [] or attrs.getQNames() == ["ns:attr"]) and \
- len(attrs) == 1 and \
- (ns_uri, "attr") in attrs and \
- list(attrs.keys()) == [(ns_uri, "attr")] and \
- attrs.get((ns_uri, "attr")) == "val" and \
- attrs.get((ns_uri, "attr"), 25) == "val" and \
- list(attrs.items()) == [((ns_uri, "attr"), "val")] and \
- list(attrs.values()) == ["val"] and \
- attrs.getValue((ns_uri, "attr")) == "val" and \
- attrs[(ns_uri, "attr")] == "val"
+ parser.feed("<doc xmlns:ns='%s' ns:attr='val'/>" % ns_uri)
+ parser.close()
-# ===== InputSource support
+ attrs = gather._attrs
-xml_test_out = open(findfile("test"+os.extsep+"xml"+os.extsep+"out")).read()
+ self.assertEquals(attrs.getLength(), 1)
+ self.assertEquals(attrs.getNames(), [(ns_uri, "attr")])
+ self.assertTrue((attrs.getQNames() == [] or
+ attrs.getQNames() == ["ns:attr"]))
+ self.assertEquals(len(attrs), 1)
+ self.assertTrue((ns_uri, "attr") in attrs)
+ self.assertEquals(attrs.get((ns_uri, "attr")), "val")
+ self.assertEquals(attrs.get((ns_uri, "attr"), 25), "val")
+ self.assertEquals(list(attrs.items()), [((ns_uri, "attr"), "val")])
+ self.assertEquals(list(attrs.values()), ["val"])
+ self.assertEquals(attrs.getValue((ns_uri, "attr")), "val")
+ self.assertEquals(attrs[(ns_uri, "attr")], "val")
-def test_expat_inpsource_filename():
- parser = create_parser()
- result = StringIO()
- xmlgen = XMLGenerator(result)
+ # ===== InputSource support
- parser.setContentHandler(xmlgen)
- parser.parse(findfile("test"+os.extsep+"xml"))
+ def test_expat_inpsource_filename(self):
+ parser = create_parser()
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
- return result.getvalue() == xml_test_out
+ parser.setContentHandler(xmlgen)
+ parser.parse(findfile("test"+os.extsep+"xml"))
-def test_expat_inpsource_sysid():
- parser = create_parser()
- result = StringIO()
- xmlgen = XMLGenerator(result)
+ self.assertEquals(result.getvalue(), xml_test_out)
- parser.setContentHandler(xmlgen)
- parser.parse(InputSource(findfile("test"+os.extsep+"xml")))
+ def test_expat_inpsource_sysid(self):
+ parser = create_parser()
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
- return result.getvalue() == xml_test_out
+ parser.setContentHandler(xmlgen)
+ parser.parse(InputSource(findfile("test"+os.extsep+"xml")))
-def test_expat_inpsource_stream():
- parser = create_parser()
- result = StringIO()
- xmlgen = XMLGenerator(result)
+ self.assertEquals(result.getvalue(), xml_test_out)
- parser.setContentHandler(xmlgen)
- inpsrc = InputSource()
- inpsrc.setByteStream(open(findfile("test"+os.extsep+"xml")))
- parser.parse(inpsrc)
+ def test_expat_inpsource_stream(self):
+ parser = create_parser()
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
- return result.getvalue() == xml_test_out
+ parser.setContentHandler(xmlgen)
+ inpsrc = InputSource()
+ inpsrc.setByteStream(open(findfile("test"+os.extsep+"xml")))
+ parser.parse(inpsrc)
-# ===== IncrementalParser support
+ self.assertEquals(result.getvalue(), xml_test_out)
-def test_expat_incremental():
- result = StringIO()
- xmlgen = XMLGenerator(result)
- parser = create_parser()
- parser.setContentHandler(xmlgen)
+ # ===== IncrementalParser support
- parser.feed("<doc>")
- parser.feed("</doc>")
- parser.close()
+ def test_expat_incremental(self):
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
+ parser = create_parser()
+ parser.setContentHandler(xmlgen)
- return result.getvalue() == start + "<doc></doc>"
+ parser.feed("<doc>")
+ parser.feed("</doc>")
+ parser.close()
-def test_expat_incremental_reset():
- result = StringIO()
- xmlgen = XMLGenerator(result)
- parser = create_parser()
- parser.setContentHandler(xmlgen)
+ self.assertEquals(result.getvalue(), start + "<doc></doc>")
- parser.feed("<doc>")
- parser.feed("text")
+ def test_expat_incremental_reset(self):
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
+ parser = create_parser()
+ parser.setContentHandler(xmlgen)
- result = StringIO()
- xmlgen = XMLGenerator(result)
- parser.setContentHandler(xmlgen)
- parser.reset()
+ parser.feed("<doc>")
+ parser.feed("text")
- parser.feed("<doc>")
- parser.feed("text")
- parser.feed("</doc>")
- parser.close()
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
+ parser.setContentHandler(xmlgen)
+ parser.reset()
- return result.getvalue() == start + "<doc>text</doc>"
+ parser.feed("<doc>")
+ parser.feed("text")
+ parser.feed("</doc>")
+ parser.close()
-# ===== Locator support
+ self.assertEquals(result.getvalue(), start + "<doc>text</doc>")
-def test_expat_locator_noinfo():
- result = StringIO()
- xmlgen = XMLGenerator(result)
- parser = create_parser()
- parser.setContentHandler(xmlgen)
+ # ===== Locator support
- parser.feed("<doc>")
- parser.feed("</doc>")
- parser.close()
+ def test_expat_locator_noinfo(self):
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
+ parser = create_parser()
+ parser.setContentHandler(xmlgen)
- return parser.getSystemId() is None and \
- parser.getPublicId() is None and \
- parser.getLineNumber() == 1
+ parser.feed("<doc>")
+ parser.feed("</doc>")
+ parser.close()
-def test_expat_locator_withinfo():
- result = StringIO()
- xmlgen = XMLGenerator(result)
- parser = create_parser()
- parser.setContentHandler(xmlgen)
- parser.parse(findfile("test.xml"))
+ self.assertEquals(parser.getSystemId(), None)
+ self.assertEquals(parser.getPublicId(), None)
+ self.assertEquals(parser.getLineNumber(), 1)
- return parser.getSystemId() == findfile("test.xml") and \
- parser.getPublicId() is None
+ def test_expat_locator_withinfo(self):
+ result = StringIO()
+ xmlgen = XMLGenerator(result)
+ parser = create_parser()
+ parser.setContentHandler(xmlgen)
+ parser.parse(findfile("test.xml"))
+
+ self.assertEquals(parser.getSystemId(), findfile("test.xml"))
+ self.assertEquals(parser.getPublicId(), None)
# ===========================================================================
@@ -507,63 +546,59 @@ def test_expat_locator_withinfo():
#
# ===========================================================================
-def test_expat_inpsource_location():
- parser = create_parser()
- parser.setContentHandler(ContentHandler()) # do nothing
- source = InputSource()
- source.setByteStream(StringIO("<foo bar foobar>")) #ill-formed
- name = "a file name"
- source.setSystemId(name)
- try:
- parser.parse(source)
- except SAXException as e:
- return e.getSystemId() == name
-
-def test_expat_incomplete():
- parser = create_parser()
- parser.setContentHandler(ContentHandler()) # do nothing
- try:
- parser.parse(StringIO("<foo>"))
- except SAXParseException:
- return 1 # ok, error found
- else:
- return 0
-
-def test_sax_parse_exception_str():
- # pass various values from a locator to the SAXParseException to
- # make sure that the __str__() doesn't fall apart when None is
- # passed instead of an integer line and column number
- #
- # use "normal" values for the locator:
- str(SAXParseException("message", None,
- DummyLocator(1, 1)))
- # use None for the line number:
- str(SAXParseException("message", None,
- DummyLocator(None, 1)))
- # use None for the column number:
- str(SAXParseException("message", None,
- DummyLocator(1, None)))
- # use None for both:
- str(SAXParseException("message", None,
- DummyLocator(None, None)))
- return 1
-
-class DummyLocator:
- def __init__(self, lineno, colno):
- self._lineno = lineno
- self._colno = colno
-
- def getPublicId(self):
- return "pubid"
-
- def getSystemId(self):
- return "sysid"
-
- def getLineNumber(self):
- return self._lineno
-
- def getColumnNumber(self):
- return self._colno
+class ErrorReportingTest(unittest.TestCase):
+ def test_expat_inpsource_location(self):
+ parser = create_parser()
+ parser.setContentHandler(ContentHandler()) # do nothing
+ source = InputSource()
+ source.setByteStream(StringIO("<foo bar foobar>")) #ill-formed
+ name = "a file name"
+ source.setSystemId(name)
+ try:
+ parser.parse(source)
+ self.fail()
+ except SAXException as e:
+ self.assertEquals(e.getSystemId(), name)
+
+ def test_expat_incomplete(self):
+ parser = create_parser()
+ parser.setContentHandler(ContentHandler()) # do nothing
+ self.assertRaises(SAXParseException, parser.parse, StringIO("<foo>"))
+
+ def test_sax_parse_exception_str(self):
+ # pass various values from a locator to the SAXParseException to
+ # make sure that the __str__() doesn't fall apart when None is
+ # passed instead of an integer line and column number
+ #
+ # use "normal" values for the locator:
+ str(SAXParseException("message", None,
+ self.DummyLocator(1, 1)))
+ # use None for the line number:
+ str(SAXParseException("message", None,
+ self.DummyLocator(None, 1)))
+ # use None for the column number:
+ str(SAXParseException("message", None,
+ self.DummyLocator(1, None)))
+ # use None for both:
+ str(SAXParseException("message", None,
+ self.DummyLocator(None, None)))
+
+ class DummyLocator:
+ def __init__(self, lineno, colno):
+ self._lineno = lineno
+ self._colno = colno
+
+ def getPublicId(self):
+ return "pubid"
+
+ def getSystemId(self):
+ return "sysid"
+
+ def getLineNumber(self):
+ return self._lineno
+
+ def getColumnNumber(self):
+ return self._colno
# ===========================================================================
#
@@ -571,217 +606,91 @@ class DummyLocator:
#
# ===========================================================================
-# ===== AttributesImpl
-
-def verify_empty_attrs(attrs):
- try:
- attrs.getValue("attr")
- gvk = 0
- except KeyError:
- gvk = 1
-
- try:
- attrs.getValueByQName("attr")
- gvqk = 0
- except KeyError:
- gvqk = 1
-
- try:
- attrs.getNameByQName("attr")
- gnqk = 0
- except KeyError:
- gnqk = 1
-
- try:
- attrs.getQNameByName("attr")
- gqnk = 0
- except KeyError:
- gqnk = 1
-
- try:
- attrs["attr"]
- gik = 0
- except KeyError:
- gik = 1
-
- return attrs.getLength() == 0 and \
- attrs.getNames() == [] and \
- attrs.getQNames() == [] and \
- len(attrs) == 0 and \
- "attr" not in attrs and \
- attrs.keys() == [] and \
- attrs.get("attrs") is None and \
- attrs.get("attrs", 25) == 25 and \
- attrs.items() == [] and \
- attrs.values() == [] and \
- gvk and gvqk and gnqk and gik and gqnk
-
-def verify_attrs_wattr(attrs):
- return attrs.getLength() == 1 and \
- attrs.getNames() == ["attr"] and \
- attrs.getQNames() == ["attr"] and \
- len(attrs) == 1 and \
- "attr" in attrs and \
- attrs.keys() == ["attr"] and \
- attrs.get("attr") == "val" and \
- attrs.get("attr", 25) == "val" and \
- attrs.items() == [("attr", "val")] and \
- attrs.values() == ["val"] and \
- attrs.getValue("attr") == "val" and \
- attrs.getValueByQName("attr") == "val" and \
- attrs.getNameByQName("attr") == "attr" and \
- attrs["attr"] == "val" and \
- attrs.getQNameByName("attr") == "attr"
-
-def test_attrs_empty():
- return verify_empty_attrs(AttributesImpl({}))
-
-def test_attrs_wattr():
- return verify_attrs_wattr(AttributesImpl({"attr" : "val"}))
-
-# ===== AttributesImpl
-
-def verify_empty_nsattrs(attrs):
- try:
- attrs.getValue((ns_uri, "attr"))
- gvk = 0
- except KeyError:
- gvk = 1
-
- try:
- attrs.getValueByQName("ns:attr")
- gvqk = 0
- except KeyError:
- gvqk = 1
-
- try:
- attrs.getNameByQName("ns:attr")
- gnqk = 0
- except KeyError:
- gnqk = 1
-
- try:
- attrs.getQNameByName((ns_uri, "attr"))
- gqnk = 0
- except KeyError:
- gqnk = 1
-
- try:
- attrs[(ns_uri, "attr")]
- gik = 0
- except KeyError:
- gik = 1
-
- return attrs.getLength() == 0 and \
- attrs.getNames() == [] and \
- attrs.getQNames() == [] and \
- len(attrs) == 0 and \
- (ns_uri, "attr") not in attrs and \
- attrs.keys() == [] and \
- attrs.get((ns_uri, "attr")) is None and \
- attrs.get((ns_uri, "attr"), 25) == 25 and \
- attrs.items() == [] and \
- attrs.values() == [] and \
- gvk and gvqk and gnqk and gik and gqnk
-
-def test_nsattrs_empty():
- return verify_empty_nsattrs(AttributesNSImpl({}, {}))
-
-def test_nsattrs_wattr():
- attrs = AttributesNSImpl({(ns_uri, "attr") : "val"},
- {(ns_uri, "attr") : "ns:attr"})
-
- return attrs.getLength() == 1 and \
- attrs.getNames() == [(ns_uri, "attr")] and \
- attrs.getQNames() == ["ns:attr"] and \
- len(attrs) == 1 and \
- (ns_uri, "attr") in attrs and \
- attrs.keys() == [(ns_uri, "attr")] and \
- attrs.get((ns_uri, "attr")) == "val" and \
- attrs.get((ns_uri, "attr"), 25) == "val" and \
- attrs.items() == [((ns_uri, "attr"), "val")] and \
- attrs.values() == ["val"] and \
- attrs.getValue((ns_uri, "attr")) == "val" and \
- attrs.getValueByQName("ns:attr") == "val" and \
- attrs.getNameByQName("ns:attr") == (ns_uri, "attr") and \
- attrs[(ns_uri, "attr")] == "val" and \
- attrs.getQNameByName((ns_uri, "attr")) == "ns:attr"
-
-
-# During the development of Python 2.5, an attempt to move the "xml"
-# package implementation to a new package ("xmlcore") proved painful.
-# The goal of this change was to allow applications to be able to
-# obtain and rely on behavior in the standard library implementation
-# of the XML support without needing to be concerned about the
-# availability of the PyXML implementation.
-#
-# While the existing import hackery in Lib/xml/__init__.py can cause
-# PyXML's _xmlpus package to supplant the "xml" package, that only
-# works because either implementation uses the "xml" package name for
-# imports.
-#
-# The move resulted in a number of problems related to the fact that
-# the import machinery's "package context" is based on the name that's
-# being imported rather than the __name__ of the actual package
-# containment; it wasn't possible for the "xml" package to be replaced
-# by a simple module that indirected imports to the "xmlcore" package.
-#
-# The following two tests exercised bugs that were introduced in that
-# attempt. Keeping these tests around will help detect problems with
-# other attempts to provide reliable access to the standard library's
-# implementation of the XML support.
-
-def test_sf_1511497():
- # Bug report: http://www.python.org/sf/1511497
- import sys
- old_modules = sys.modules.copy()
- for modname in list(sys.modules.keys()):
- if modname.startswith("xml."):
- del sys.modules[modname]
- try:
- import xml.sax.expatreader
- module = xml.sax.expatreader
- return module.__name__ == "xml.sax.expatreader"
- finally:
- sys.modules.update(old_modules)
-
-def test_sf_1513611():
- # Bug report: http://www.python.org/sf/1513611
- sio = StringIO("invalid")
- parser = make_parser()
- from xml.sax import SAXParseException
- try:
- parser.parse(sio)
- except SAXParseException:
- return True
- else:
- return False
-
-# ===== Main program
-
-def make_test_output():
- parser = create_parser()
- result = StringIO()
- xmlgen = XMLGenerator(result)
-
- parser.setContentHandler(xmlgen)
- parser.parse(findfile("test"+os.extsep+"xml"))
-
- outf = open(findfile("test"+os.extsep+"xml"+os.extsep+"out"), "w")
- outf.write(result.getvalue())
- outf.close()
-
-items = sorted(locals().items())
-for (name, value) in items:
- if name[ : 5] == "test_":
- confirm(value(), name)
-# We delete the items variable so that the assignment to items above
-# doesn't pick up the old value of items (which messes with attempts
-# to find reference leaks).
-del items
-
-if verbose:
- print("%d tests, %d failures" % (tests, len(failures)))
-if failures:
- raise TestFailed("%d of %d tests failed: %s"
- % (len(failures), tests, ", ".join(failures)))
+class XmlReaderTest(XmlTestBase):
+
+ # ===== AttributesImpl
+ def test_attrs_empty(self):
+ self.verify_empty_attrs(AttributesImpl({}))
+
+ def test_attrs_wattr(self):
+ self.verify_attrs_wattr(AttributesImpl({"attr" : "val"}))
+
+ def test_nsattrs_empty(self):
+ self.verify_empty_nsattrs(AttributesNSImpl({}, {}))
+
+ def test_nsattrs_wattr(self):
+ attrs = AttributesNSImpl({(ns_uri, "attr") : "val"},
+ {(ns_uri, "attr") : "ns:attr"})
+
+ self.assertEquals(attrs.getLength(), 1)
+ self.assertEquals(attrs.getNames(), [(ns_uri, "attr")])
+ self.assertEquals(attrs.getQNames(), ["ns:attr"])
+ self.assertEquals(len(attrs), 1)
+ self.assertTrue((ns_uri, "attr") in attrs)
+ self.assertEquals(list(attrs.keys()), [(ns_uri, "attr")])
+ self.assertEquals(attrs.get((ns_uri, "attr")), "val")
+ self.assertEquals(attrs.get((ns_uri, "attr"), 25), "val")
+ self.assertEquals(list(attrs.items()), [((ns_uri, "attr"), "val")])
+ self.assertEquals(list(attrs.values()), ["val"])
+ self.assertEquals(attrs.getValue((ns_uri, "attr")), "val")
+ self.assertEquals(attrs.getValueByQName("ns:attr"), "val")
+ self.assertEquals(attrs.getNameByQName("ns:attr"), (ns_uri, "attr"))
+ self.assertEquals(attrs[(ns_uri, "attr")], "val")
+ self.assertEquals(attrs.getQNameByName((ns_uri, "attr")), "ns:attr")
+
+
+ # During the development of Python 2.5, an attempt to move the "xml"
+ # package implementation to a new package ("xmlcore") proved painful.
+ # The goal of this change was to allow applications to be able to
+ # obtain and rely on behavior in the standard library implementation
+ # of the XML support without needing to be concerned about the
+ # availability of the PyXML implementation.
+ #
+ # While the existing import hackery in Lib/xml/__init__.py can cause
+ # PyXML's _xmlpus package to supplant the "xml" package, that only
+ # works because either implementation uses the "xml" package name for
+ # imports.
+ #
+ # The move resulted in a number of problems related to the fact that
+ # the import machinery's "package context" is based on the name that's
+ # being imported rather than the __name__ of the actual package
+ # containment; it wasn't possible for the "xml" package to be replaced
+ # by a simple module that indirected imports to the "xmlcore" package.
+ #
+ # The following two tests exercised bugs that were introduced in that
+ # attempt. Keeping these tests around will help detect problems with
+ # other attempts to provide reliable access to the standard library's
+ # implementation of the XML support.
+
+ def test_sf_1511497(self):
+ # Bug report: http://www.python.org/sf/1511497
+ import sys
+ old_modules = sys.modules.copy()
+ for modname in list(sys.modules.keys()):
+ if modname.startswith("xml."):
+ del sys.modules[modname]
+ try:
+ import xml.sax.expatreader
+ module = xml.sax.expatreader
+ self.assertEquals(module.__name__, "xml.sax.expatreader")
+ finally:
+ sys.modules.update(old_modules)
+
+ def test_sf_1513611(self):
+ # Bug report: http://www.python.org/sf/1513611
+ sio = StringIO("invalid")
+ parser = make_parser()
+ from xml.sax import SAXParseException
+ self.assertRaises(SAXParseException, parser.parse, sio)
+
+
+def unittest_main():
+ run_unittest(MakeParserTest,
+ SaxutilsTest,
+ XmlgenTest,
+ ExpatReaderTest,
+ ErrorReportingTest,
+ XmlReaderTest)
+
+if __name__ == "__main__":
+ unittest_main()
diff --git a/Lib/test/test_scope.py b/Lib/test/test_scope.py
index f52ab91..f5c1462 100644
--- a/Lib/test/test_scope.py
+++ b/Lib/test/test_scope.py
@@ -477,6 +477,39 @@ self.assert_(X.passed)
del d['h']
self.assertEqual(d, {'x': 2, 'y': 7, 'w': 6})
+ def testLocalsClass(self):
+ # This test verifies that calling locals() does not pollute
+ # the local namespace of the class with free variables. Old
+ # versions of Python had a bug, where a free variable being
+ # passed through a class namespace would be inserted into
+ # locals() by locals() or exec or a trace function.
+ #
+ # The real bug lies in frame code that copies variables
+ # between fast locals and the locals dict, e.g. when executing
+ # a trace function.
+
+ def f(x):
+ class C:
+ x = 12
+ def m(self):
+ return x
+ locals()
+ return C
+
+ self.assertEqual(f(1).x, 12)
+
+ def f(x):
+ class C:
+ y = x
+ def m(self):
+ return x
+ z = list(locals())
+ return C
+
+ varnames = f(1).z
+ self.assert_("x" not in varnames)
+ self.assert_("y" in varnames)
+
def testBoundAndFree(self):
# var is bound and free in class
@@ -607,7 +640,7 @@ self.assert_(X.passed)
c = f(0)
self.assertEqual(c.get(), 1)
self.assert_("x" not in c.__class__.__dict__)
-
+
def testNonLocalGenerator(self):
diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py
index 997e17f..45bf32c 100644
--- a/Lib/test/test_set.py
+++ b/Lib/test/test_set.py
@@ -288,10 +288,17 @@ class TestJointOps(unittest.TestCase):
s = self.thetype(d)
self.assertEqual(sum(elem.hash_count for elem in d), n)
s.difference(d)
- self.assertEqual(sum(elem.hash_count for elem in d), n)
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
if hasattr(s, 'symmetric_difference_update'):
s.symmetric_difference_update(d)
- self.assertEqual(sum(elem.hash_count for elem in d), n)
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
+ d2 = dict.fromkeys(set(d))
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
+ d3 = dict.fromkeys(frozenset(d))
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
+ d3 = dict.fromkeys(frozenset(d), 123)
+ self.assertEqual(sum(elem.hash_count for elem in d), n)
+ self.assertEqual(d3, dict.fromkeys(d, 123))
class TestSet(TestJointOps):
thetype = set
diff --git a/Lib/test/test_slice.py b/Lib/test/test_slice.py
index f22f619..977b28a 100644
--- a/Lib/test/test_slice.py
+++ b/Lib/test/test_slice.py
@@ -2,6 +2,7 @@
import unittest
from test import test_support
+from cPickle import loads, dumps
import sys
@@ -92,6 +93,24 @@ class SliceTest(unittest.TestCase):
self.assertRaises(OverflowError, slice(None).indices, 1<<100)
+ def test_setslice_without_getslice(self):
+ tmp = []
+ class X(object):
+ def __setslice__(self, i, j, k):
+ tmp.append((i, j, k))
+
+ x = X()
+ x[1:2] = 42
+ self.assertEquals(tmp, [(1, 2, 42)])
+
+ def test_pickle(self):
+ s = slice(10, 20, 3)
+ for protocol in (0,1,2):
+ t = loads(dumps(s, protocol))
+ self.assertEqual(s, t)
+ self.assertEqual(s.indices(15), t.indices(15))
+ self.assertNotEqual(id(s), id(t))
+
def test_main():
test_support.run_unittest(SliceTest)
diff --git a/Lib/test/test_smtplib.py b/Lib/test/test_smtplib.py
new file mode 100644
index 0000000..3542ddb
--- /dev/null
+++ b/Lib/test/test_smtplib.py
@@ -0,0 +1,71 @@
+import socket
+import threading
+import smtplib
+import time
+
+from unittest import TestCase
+from test import test_support
+
+
+def server(evt):
+ serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ serv.settimeout(3)
+ serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ serv.bind(("", 9091))
+ serv.listen(5)
+ try:
+ conn, addr = serv.accept()
+ except socket.timeout:
+ pass
+ else:
+ conn.send("220 Hola mundo\n")
+ conn.close()
+ finally:
+ serv.close()
+ evt.set()
+
+class GeneralTests(TestCase):
+
+ def setUp(self):
+ self.evt = threading.Event()
+ threading.Thread(target=server, args=(self.evt,)).start()
+ time.sleep(.1)
+
+ def tearDown(self):
+ self.evt.wait()
+
+ def testBasic(self):
+ # connects
+ smtp = smtplib.SMTP("localhost", 9091)
+ smtp.sock.close()
+
+ def testTimeoutDefault(self):
+ # default
+ smtp = smtplib.SMTP("localhost", 9091)
+ self.assertTrue(smtp.sock.gettimeout() is None)
+ smtp.sock.close()
+
+ def testTimeoutValue(self):
+ # a value
+ smtp = smtplib.SMTP("localhost", 9091, timeout=30)
+ self.assertEqual(smtp.sock.gettimeout(), 30)
+ smtp.sock.close()
+
+ def testTimeoutNone(self):
+ # None, having other default
+ previous = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(30)
+ try:
+ smtp = smtplib.SMTP("localhost", 9091, timeout=None)
+ finally:
+ socket.setdefaulttimeout(previous)
+ self.assertEqual(smtp.sock.gettimeout(), 30)
+ smtp.sock.close()
+
+
+
+def test_main(verbose=None):
+ test_support.run_unittest(GeneralTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py
index 4f74186..ead3e4f 100644
--- a/Lib/test/test_socket.py
+++ b/Lib/test/test_socket.py
@@ -75,7 +75,7 @@ class ThreadableTest:
Note, the server setup function cannot call any blocking
functions that rely on the client thread during setup,
- unless serverExplicityReady() is called just before
+ unless serverExplicitReady() is called just before
the blocking call (such as in setting up a client/server
connection and performing the accept() in setUp().
"""
@@ -597,6 +597,13 @@ class BasicUDPTest(ThreadedUDPSocketTest):
def _testRecvFrom(self):
self.cli.sendto(MSG, 0, (HOST, PORT))
+ def testRecvFromNegative(self):
+ # Negative lengths passed to recvfrom should give ValueError.
+ self.assertRaises(ValueError, self.serv.recvfrom, -1)
+
+ def _testRecvFromNegative(self):
+ self.cli.sendto(MSG, 0, (HOST, PORT))
+
class TCPCloserTest(ThreadedTCPSocketTest):
def testClose(self):
@@ -810,6 +817,98 @@ class SmallBufferedFileObjectClassTestCase(FileObjectClassTestCase):
bufsize = 2 # Exercise the buffering code
+class NetworkConnectionTest(object):
+ """Prove network connection."""
+ def clientSetUp(self):
+ self.cli = socket.create_connection((HOST, PORT))
+ self.serv_conn = self.cli
+
+class BasicTCPTest2(NetworkConnectionTest, BasicTCPTest):
+ """Tests that NetworkConnection does not break existing TCP functionality.
+ """
+
+class NetworkConnectionNoServer(unittest.TestCase):
+ def testWithoutServer(self):
+ self.failUnlessRaises(socket.error, lambda: socket.create_connection((HOST, PORT)))
+
+class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketTCPTest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ pass
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+ def _justAccept(self):
+ conn, addr = self.serv.accept()
+
+ testFamily = _justAccept
+ def _testFamily(self):
+ self.cli = socket.create_connection((HOST, PORT), timeout=30)
+ self.assertEqual(self.cli.family, 2)
+
+ testTimeoutDefault = _justAccept
+ def _testTimeoutDefault(self):
+ self.cli = socket.create_connection((HOST, PORT))
+ self.assertTrue(self.cli.gettimeout() is None)
+
+ testTimeoutValueNamed = _justAccept
+ def _testTimeoutValueNamed(self):
+ self.cli = socket.create_connection((HOST, PORT), timeout=30)
+ self.assertEqual(self.cli.gettimeout(), 30)
+
+ testTimeoutValueNonamed = _justAccept
+ def _testTimeoutValueNonamed(self):
+ self.cli = socket.create_connection((HOST, PORT), 30)
+ self.assertEqual(self.cli.gettimeout(), 30)
+
+ testTimeoutNone = _justAccept
+ def _testTimeoutNone(self):
+ previous = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(30)
+ try:
+ self.cli = socket.create_connection((HOST, PORT), timeout=None)
+ finally:
+ socket.setdefaulttimeout(previous)
+ self.assertEqual(self.cli.gettimeout(), 30)
+
+
+class NetworkConnectionBehaviourTest(SocketTCPTest, ThreadableTest):
+
+ def __init__(self, methodName='runTest'):
+ SocketTCPTest.__init__(self, methodName=methodName)
+ ThreadableTest.__init__(self)
+
+ def clientSetUp(self):
+ pass
+
+ def clientTearDown(self):
+ self.cli.close()
+ self.cli = None
+ ThreadableTest.clientTearDown(self)
+
+ def testInsideTimeout(self):
+ conn, addr = self.serv.accept()
+ time.sleep(3)
+ conn.send("done!")
+ testOutsideTimeout = testInsideTimeout
+
+ def _testInsideTimeout(self):
+ self.cli = sock = socket.create_connection((HOST, PORT))
+ data = sock.recv(5)
+ self.assertEqual(data, "done!")
+
+ def _testOutsideTimeout(self):
+ self.cli = sock = socket.create_connection((HOST, PORT), timeout=1)
+ self.failUnlessRaises(socket.timeout, lambda: sock.recv(5))
+
+
class Urllib2FileobjectTest(unittest.TestCase):
# urllib2.HTTPHandler has "borrowed" socket._fileobject, and requires that
@@ -977,7 +1076,7 @@ class BufferIOTest(SocketConnectedTest):
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
- TestExceptions, BufferIOTest]
+ TestExceptions, BufferIOTest, BasicTCPTest2]
if sys.platform != 'mac':
tests.extend([ BasicUDPTest, UDPTimeoutTest ])
@@ -988,6 +1087,9 @@ def test_main():
LineBufferedFileObjectClassTestCase,
SmallBufferedFileObjectClassTestCase,
Urllib2FileobjectTest,
+ NetworkConnectionNoServer,
+ NetworkConnectionAttributesTest,
+ NetworkConnectionBehaviourTest,
])
if hasattr(socket, "socketpair"):
tests.append(BasicSocketPairTest)
diff --git a/Lib/test/test_socket_ssl.py b/Lib/test/test_socket_ssl.py
index b04effe..42efb6e 100644
--- a/Lib/test/test_socket_ssl.py
+++ b/Lib/test/test_socket_ssl.py
@@ -1,128 +1,221 @@
# Test just the SSL support in the socket module, in a moderately bogus way.
import sys
+import unittest
from test import test_support
import socket
import errno
+import threading
+import subprocess
+import time
+import os
+import urllib
-# Optionally test SSL support. This requires the 'network' resource as given
-# on the regrtest command line.
-skip_expected = not (test_support.is_resource_enabled('network') and
- hasattr(socket, "ssl"))
+# Optionally test SSL support, if we have it in the tested platform
+skip_expected = not hasattr(socket, "ssl")
-def test_basic():
- test_support.requires('network')
+class ConnectedTests(unittest.TestCase):
- import urllib
-
- if test_support.verbose:
- print("test_basic ...")
-
- socket.RAND_status()
- try:
- socket.RAND_egd(1)
- except TypeError:
- pass
- else:
- print("didn't raise TypeError")
- socket.RAND_add("this is a random string", 75.0)
-
- f = urllib.urlopen('https://sf.net')
- buf = f.read()
- f.close()
-
-def test_timeout():
- test_support.requires('network')
-
- def error_msg(extra_msg):
- print("""\
- WARNING: an attempt to connect to %r %s, in
- test_timeout. That may be legitimate, but is not the outcome we hoped
- for. If this message is seen often, test_timeout should be changed to
- use a more reliable address.""" % (ADDR, extra_msg), file=sys.stderr)
-
- if test_support.verbose:
- print("test_timeout ...")
-
- # A service which issues a welcome banner (without need to write
- # anything).
- # XXX ("gmail.org", 995) has been unreliable so far, from time to time
- # XXX non-responsive for hours on end (& across all buildbot slaves,
- # XXX so that's not just a local thing).
- ADDR = "gmail.org", 995
-
- s = socket.socket()
- s.settimeout(30.0)
- try:
- s.connect(ADDR)
- except socket.timeout:
- error_msg('timed out')
- return
- except socket.error as exc: # In case connection is refused.
- if exc.args[0] == errno.ECONNREFUSED:
- error_msg('was refused')
- return
+ def testBasic(self):
+ socket.RAND_status()
+ try:
+ socket.RAND_egd(1)
+ except TypeError:
+ pass
else:
- raise
+ print("didn't raise TypeError")
+ socket.RAND_add("this is a random string", 75.0)
+
+ with test_support.transient_internet():
+ f = urllib.urlopen('https://sf.net')
+ buf = f.read()
+ f.close()
+
+ def testTimeout(self):
+ def error_msg(extra_msg):
+ print("""\
+ WARNING: an attempt to connect to %r %s, in
+ test_timeout. That may be legitimate, but is not the outcome we
+ hoped for. If this message is seen often, test_timeout should be
+ changed to use a more reliable address.""" % (ADDR, extra_msg), file=sys.stderr)
+
+ # A service which issues a welcome banner (without need to write
+ # anything).
+ # XXX ("gmail.org", 995) has been unreliable so far, from time to
+ # XXX time non-responsive for hours on end (& across all buildbot
+ # XXX slaves, so that's not just a local thing).
+ ADDR = "gmail.org", 995
- ss = socket.ssl(s)
- # Read part of return welcome banner twice.
- ss.read(1)
- ss.read(1)
- s.close()
-
-def test_rude_shutdown():
- if test_support.verbose:
- print("test_rude_shutdown ...")
-
- try:
- import threading
- except ImportError:
- return
+ s = socket.socket()
+ s.settimeout(30.0)
+ try:
+ s.connect(ADDR)
+ except socket.timeout:
+ error_msg('timed out')
+ return
+ except socket.error as exc: # In case connection is refused.
+ if exc.args[0] == errno.ECONNREFUSED:
+ error_msg('was refused')
+ return
+ else:
+ raise
+
+ ss = socket.ssl(s)
+ # Read part of return welcome banner twice.
+ ss.read(1)
+ ss.read(1)
+ s.close()
+
+class BasicTests(unittest.TestCase):
+
+ def testRudeShutdown(self):
+ # Some random port to connect to.
+ PORT = [9934]
+
+ listener_ready = threading.Event()
+ listener_gone = threading.Event()
+
+ # `listener` runs in a thread. It opens a socket listening on
+ # PORT, and sits in an accept() until the main thread connects.
+ # Then it rudely closes the socket, and sets Event `listener_gone`
+ # to let the main thread know the socket is gone.
+ def listener():
+ s = socket.socket()
+ PORT[0] = test_support.bind_port(s, '', PORT[0])
+ s.listen(5)
+ listener_ready.set()
+ s.accept()
+ s = None # reclaim the socket object, which also closes it
+ listener_gone.set()
+
+ def connector():
+ listener_ready.wait()
+ s = socket.socket()
+ s.connect(('localhost', PORT[0]))
+ listener_gone.wait()
+ try:
+ ssl_sock = socket.ssl(s)
+ except socket.sslerror:
+ pass
+ else:
+ raise test_support.TestFailed(
+ 'connecting to closed SSL socket should have failed')
- # Some random port to connect to.
- PORT = [9934]
+ t = threading.Thread(target=listener)
+ t.start()
+ connector()
+ t.join()
- listener_ready = threading.Event()
- listener_gone = threading.Event()
+class OpenSSLTests(unittest.TestCase):
- # `listener` runs in a thread. It opens a socket listening on PORT, and
- # sits in an accept() until the main thread connects. Then it rudely
- # closes the socket, and sets Event `listener_gone` to let the main thread
- # know the socket is gone.
- def listener():
+ def testBasic(self):
s = socket.socket()
- PORT[0] = test_support.bind_port(s, '', PORT[0])
- s.listen(5)
- listener_ready.set()
- s.accept()
- s = None # reclaim the socket object, which also closes it
- listener_gone.set()
-
- def connector():
- listener_ready.wait()
+ s.connect(("localhost", 4433))
+ ss = socket.ssl(s)
+ ss.write("Foo\n")
+ i = ss.read(4)
+ self.assertEqual(i, "Foo\n")
+ s.close()
+
+ def testMethods(self):
+ # read & write is already tried in the Basic test
+ # now we'll try to get the server info about certificates
+ # this came from the certificate I used, one I found in /usr/share/openssl
+ info = "/C=PT/ST=Queensland/L=Lisboa/O=Neuronio, Lda./OU=Desenvolvimento/CN=brutus.neuronio.pt/emailAddress=sampo@iki.fi"
+
s = socket.socket()
- s.connect(('localhost', PORT[0]))
- listener_gone.wait()
+ s.connect(("localhost", 4433))
+ ss = socket.ssl(s)
+ cert = ss.server()
+ self.assertEqual(cert, info)
+ cert = ss.issuer()
+ self.assertEqual(cert, info)
+ s.close()
+
+
+class OpenSSLServer(threading.Thread):
+ def __init__(self):
+ self.s = None
+ self.keepServing = True
+ self._external()
+ if self.haveServer:
+ threading.Thread.__init__(self)
+
+ def _external(self):
+ # let's find the .pem files
+ curdir = os.path.dirname(__file__) or os.curdir
+ cert_file = os.path.join(curdir, "ssl_cert.pem")
+ if not os.access(cert_file, os.F_OK):
+ raise ValueError("No cert file found! (tried %r)" % cert_file)
+ key_file = os.path.join(curdir, "ssl_key.pem")
+ if not os.access(key_file, os.F_OK):
+ raise ValueError("No key file found! (tried %r)" % key_file)
+
try:
- ssl_sock = socket.ssl(s)
- except socket.sslerror:
- pass
+ cmd = "openssl s_server -cert %s -key %s -quiet" % (cert_file, key_file)
+ self.s = subprocess.Popen(cmd.split(), stdin=subprocess.PIPE,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT)
+ time.sleep(1)
+ except:
+ self.haveServer = False
else:
- raise test_support.TestFailed(
- 'connecting to closed SSL socket should have failed')
-
- t = threading.Thread(target=listener)
- t.start()
- connector()
- t.join()
+ # let's try if it is actually up
+ try:
+ s = socket.socket()
+ s.connect(("localhost", 4433))
+ s.close()
+ if self.s.stdout.readline() != "ERROR\n":
+ raise ValueError
+ except:
+ self.haveServer = False
+ else:
+ self.haveServer = True
+
+ def run(self):
+ while self.keepServing:
+ time.sleep(.5)
+ l = self.s.stdout.readline()
+ self.s.stdin.write(l)
+
+ def shutdown(self):
+ self.keepServing = False
+ if not self.s:
+ return
+ if sys.platform == "win32":
+ subprocess.TerminateProcess(int(self.s._handle), -1)
+ else:
+ os.kill(self.s.pid, 15)
def test_main():
if not hasattr(socket, "ssl"):
raise test_support.TestSkipped("socket module has no ssl support")
- test_rude_shutdown()
- test_basic()
- test_timeout()
+
+ tests = [BasicTests]
+
+ if test_support.is_resource_enabled('network'):
+ tests.append(ConnectedTests)
+
+ # in these platforms we can kill the openssl process
+ if sys.platform in ("sunos5", "darwin", "linux1",
+ "linux2", "win32", "hp-ux11"):
+
+ server = OpenSSLServer()
+ if server.haveServer:
+ tests.append(OpenSSLTests)
+ server.start()
+ else:
+ server = None
+
+ thread_info = test_support.threading_setup()
+
+ try:
+ test_support.run_unittest(*tests)
+ finally:
+ if server is not None and server.haveServer:
+ server.shutdown()
+
+ test_support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_socketserver.py b/Lib/test/test_socketserver.py
index 062be65..da936a4 100644
--- a/Lib/test/test_socketserver.py
+++ b/Lib/test/test_socketserver.py
@@ -74,6 +74,7 @@ class ServerThread(threading.Thread):
self.__addr = addr
self.__svrcls = svrcls
self.__hdlrcls = hdlrcls
+ self.ready = threading.Event()
def run(self):
class svrcls(MyMixinServer, self.__svrcls):
pass
@@ -81,9 +82,13 @@ class ServerThread(threading.Thread):
svr = svrcls(self.__addr, self.__hdlrcls)
# pull the address out of the server in case it changed
# this can happen if another process is using the port
- addr = getattr(svr, 'server_address')
+ addr = svr.server_address
if addr:
self.__addr = addr
+ if self.__addr != svr.socket.getsockname():
+ raise RuntimeError('server_address was %s, expected %s' %
+ (self.__addr, svr.socket.getsockname()))
+ self.ready.set()
if verbose: print("thread: serving three times")
svr.serve_a_few()
if verbose: print("thread: done")
@@ -136,7 +141,9 @@ def testloop(proto, servers, hdlrcls, testfunc):
t.start()
if verbose: print("server running")
for i in range(NREQ):
- time.sleep(DELAY)
+ t.ready.wait(10*DELAY)
+ if not t.ready.isSet():
+ raise RuntimeError("Server not ready within a reasonable time")
if verbose: print("test client", i)
testfunc(proto, addr)
if verbose: print("waiting for server")
diff --git a/Lib/test/test_stringprep.py b/Lib/test/test_stringprep.py
index 2baf4a5..60425dd 100644
--- a/Lib/test/test_stringprep.py
+++ b/Lib/test/test_stringprep.py
@@ -1,88 +1,96 @@
# To fully test this module, we would need a copy of the stringprep tables.
# Since we don't have them, this test checks only a few codepoints.
-from test.test_support import verify, vereq
+import unittest
+from test import test_support
-import stringprep
from stringprep import *
-verify(in_table_a1(u"\u0221"))
-verify(not in_table_a1(u"\u0222"))
+class StringprepTests(unittest.TestCase):
+ def test(self):
+ self.failUnless(in_table_a1(u"\u0221"))
+ self.failIf(in_table_a1(u"\u0222"))
-verify(in_table_b1(u"\u00ad"))
-verify(not in_table_b1(u"\u00ae"))
+ self.failUnless(in_table_b1(u"\u00ad"))
+ self.failIf(in_table_b1(u"\u00ae"))
-verify(map_table_b2(u"\u0041"), u"\u0061")
-verify(map_table_b2(u"\u0061"), u"\u0061")
+ self.failUnless(map_table_b2(u"\u0041"), u"\u0061")
+ self.failUnless(map_table_b2(u"\u0061"), u"\u0061")
-verify(map_table_b3(u"\u0041"), u"\u0061")
-verify(map_table_b3(u"\u0061"), u"\u0061")
+ self.failUnless(map_table_b3(u"\u0041"), u"\u0061")
+ self.failUnless(map_table_b3(u"\u0061"), u"\u0061")
-verify(in_table_c11(u"\u0020"))
-verify(not in_table_c11(u"\u0021"))
+ self.failUnless(in_table_c11(u"\u0020"))
+ self.failIf(in_table_c11(u"\u0021"))
-verify(in_table_c12(u"\u00a0"))
-verify(not in_table_c12(u"\u00a1"))
+ self.failUnless(in_table_c12(u"\u00a0"))
+ self.failIf(in_table_c12(u"\u00a1"))
-verify(in_table_c12(u"\u00a0"))
-verify(not in_table_c12(u"\u00a1"))
+ self.failUnless(in_table_c12(u"\u00a0"))
+ self.failIf(in_table_c12(u"\u00a1"))
-verify(in_table_c11_c12(u"\u00a0"))
-verify(not in_table_c11_c12(u"\u00a1"))
+ self.failUnless(in_table_c11_c12(u"\u00a0"))
+ self.failIf(in_table_c11_c12(u"\u00a1"))
-verify(in_table_c21(u"\u001f"))
-verify(not in_table_c21(u"\u0020"))
+ self.failUnless(in_table_c21(u"\u001f"))
+ self.failIf(in_table_c21(u"\u0020"))
-verify(in_table_c22(u"\u009f"))
-verify(not in_table_c22(u"\u00a0"))
+ self.failUnless(in_table_c22(u"\u009f"))
+ self.failIf(in_table_c22(u"\u00a0"))
-verify(in_table_c21_c22(u"\u009f"))
-verify(not in_table_c21_c22(u"\u00a0"))
+ self.failUnless(in_table_c21_c22(u"\u009f"))
+ self.failIf(in_table_c21_c22(u"\u00a0"))
-verify(in_table_c3(u"\ue000"))
-verify(not in_table_c3(u"\uf900"))
+ self.failUnless(in_table_c3(u"\ue000"))
+ self.failIf(in_table_c3(u"\uf900"))
-verify(in_table_c4(u"\uffff"))
-verify(not in_table_c4(u"\u0000"))
+ self.failUnless(in_table_c4(u"\uffff"))
+ self.failIf(in_table_c4(u"\u0000"))
-verify(in_table_c5(u"\ud800"))
-verify(not in_table_c5(u"\ud7ff"))
+ self.failUnless(in_table_c5(u"\ud800"))
+ self.failIf(in_table_c5(u"\ud7ff"))
-verify(in_table_c6(u"\ufff9"))
-verify(not in_table_c6(u"\ufffe"))
+ self.failUnless(in_table_c6(u"\ufff9"))
+ self.failIf(in_table_c6(u"\ufffe"))
-verify(in_table_c7(u"\u2ff0"))
-verify(not in_table_c7(u"\u2ffc"))
+ self.failUnless(in_table_c7(u"\u2ff0"))
+ self.failIf(in_table_c7(u"\u2ffc"))
-verify(in_table_c8(u"\u0340"))
-verify(not in_table_c8(u"\u0342"))
+ self.failUnless(in_table_c8(u"\u0340"))
+ self.failIf(in_table_c8(u"\u0342"))
-# C.9 is not in the bmp
-# verify(in_table_c9(u"\U000E0001"))
-# verify(not in_table_c8(u"\U000E0002"))
+ # C.9 is not in the bmp
+ # self.failUnless(in_table_c9(u"\U000E0001"))
+ # self.failIf(in_table_c8(u"\U000E0002"))
-verify(in_table_d1(u"\u05be"))
-verify(not in_table_d1(u"\u05bf"))
+ self.failUnless(in_table_d1(u"\u05be"))
+ self.failIf(in_table_d1(u"\u05bf"))
-verify(in_table_d2(u"\u0041"))
-verify(not in_table_d2(u"\u0040"))
+ self.failUnless(in_table_d2(u"\u0041"))
+ self.failIf(in_table_d2(u"\u0040"))
-# This would generate a hash of all predicates. However, running
-# it is quite expensive, and only serves to detect changes in the
-# unicode database. Instead, stringprep.py asserts the version of
-# the database.
+ # This would generate a hash of all predicates. However, running
+ # it is quite expensive, and only serves to detect changes in the
+ # unicode database. Instead, stringprep.py asserts the version of
+ # the database.
-# import hashlib
-# predicates = [k for k in dir(stringprep) if k.startswith("in_table")]
-# predicates.sort()
-# for p in predicates:
-# f = getattr(stringprep, p)
-# # Collect all BMP code points
-# data = ["0"] * 0x10000
-# for i in range(0x10000):
-# if f(unichr(i)):
-# data[i] = "1"
-# data = "".join(data)
-# h = hashlib.sha1()
-# h.update(data)
-# print p, h.hexdigest()
+ # import hashlib
+ # predicates = [k for k in dir(stringprep) if k.startswith("in_table")]
+ # predicates.sort()
+ # for p in predicates:
+ # f = getattr(stringprep, p)
+ # # Collect all BMP code points
+ # data = ["0"] * 0x10000
+ # for i in range(0x10000):
+ # if f(unichr(i)):
+ # data[i] = "1"
+ # data = "".join(data)
+ # h = hashlib.sha1()
+ # h.update(data)
+ # print p, h.hexdigest()
+
+def test_main():
+ test_support.run_unittest(StringprepTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_strptime.py b/Lib/test/test_strptime.py
index c1af281..0e1909e 100644
--- a/Lib/test/test_strptime.py
+++ b/Lib/test/test_strptime.py
@@ -505,6 +505,35 @@ class CacheTests(unittest.TestCase):
self.failIfEqual(locale_time_id,
id(_strptime._TimeRE_cache.locale_time))
+ def test_TimeRE_recreation(self):
+ # The TimeRE instance should be recreated upon changing the locale.
+ locale_info = locale.getlocale(locale.LC_TIME)
+ try:
+ locale.setlocale(locale.LC_TIME, ('en_US', 'UTF8'))
+ except locale.Error:
+ return
+ try:
+ _strptime.strptime('10', '%d')
+ # Get id of current cache object.
+ first_time_re_id = id(_strptime._TimeRE_cache)
+ try:
+ # Change the locale and force a recreation of the cache.
+ locale.setlocale(locale.LC_TIME, ('de_DE', 'UTF8'))
+ _strptime.strptime('10', '%d')
+ # Get the new cache object's id.
+ second_time_re_id = id(_strptime._TimeRE_cache)
+ # They should not be equal.
+ self.failIfEqual(first_time_re_id, second_time_re_id)
+ # Possible test locale is not supported while initial locale is.
+ # If this is the case just suppress the exception and fall-through
+ # to the reseting to the original locale.
+ except locale.Error:
+ pass
+ # Make sure we don't trample on the locale setting once we leave the
+ # test.
+ finally:
+ locale.setlocale(locale.LC_TIME, locale_info)
+
def test_main():
test_support.run_unittest(
diff --git a/Lib/test/test_struct.py b/Lib/test/test_struct.py
index 0678761..b62d74c 100644
--- a/Lib/test/test_struct.py
+++ b/Lib/test/test_struct.py
@@ -614,53 +614,61 @@ def test_pack_into_fn():
assertRaises(struct.error, pack_into, small_buf, 0, test_string)
assertRaises(struct.error, pack_into, small_buf, 2, test_string)
+def test_unpack_with_buffer():
+ # SF bug 1563759: struct.unpack doens't support buffer protocol objects
+ data1 = array.array('B', '\x12\x34\x56\x78')
+ data2 = buffer('......\x12\x34\x56\x78......', 6, 4)
+ for data in [data1, data2]:
+ value, = struct.unpack('>I', data)
+ vereq(value, 0x12345678)
# Test methods to pack and unpack from buffers rather than strings.
test_unpack_from()
test_pack_into()
test_pack_into_fn()
+test_unpack_with_buffer()
def test_bool():
for prefix in tuple("<>!=")+('',):
false = (), [], [], '', 0
true = [1], 'test', 5, -1, 0xffffffff+1, 0xffffffff/2
-
+
falseFormat = prefix + 't' * len(false)
if verbose:
print('trying bool pack/unpack on', false, 'using format', falseFormat)
packedFalse = struct.pack(falseFormat, *false)
unpackedFalse = struct.unpack(falseFormat, packedFalse)
-
+
trueFormat = prefix + 't' * len(true)
if verbose:
print('trying bool pack/unpack on', true, 'using format', trueFormat)
packedTrue = struct.pack(trueFormat, *true)
unpackedTrue = struct.unpack(trueFormat, packedTrue)
-
+
if len(true) != len(unpackedTrue):
raise TestFailed('unpacked true array is not of same size as input')
if len(false) != len(unpackedFalse):
raise TestFailed('unpacked false array is not of same size as input')
-
+
for t in unpackedFalse:
if t is not False:
raise TestFailed('%r did not unpack as False' % t)
for t in unpackedTrue:
if t is not True:
raise TestFailed('%r did not unpack as false' % t)
-
+
if prefix and verbose:
print('trying size of bool with format %r' % (prefix+'t'))
packed = struct.pack(prefix+'t', 1)
-
+
if len(packed) != struct.calcsize(prefix+'t'):
raise TestFailed('packed length is not equal to calculated size')
-
+
if len(packed) != 1 and prefix:
raise TestFailed('encoded bool is not one byte: %r' % packed)
elif not prefix and verbose:
print('size of bool in native format is %i' % (len(packed)))
-
+
for c in '\x01\x7f\xff\x0f\xf0':
if struct.unpack('>t', c)[0] is not True:
raise TestFailed('%c did not unpack as True' % c)
diff --git a/Lib/test/test_structmembers.py b/Lib/test/test_structmembers.py
index 0713b87..599c6fb 100644
--- a/Lib/test/test_structmembers.py
+++ b/Lib/test/test_structmembers.py
@@ -4,7 +4,7 @@ from _testcapi import test_structmembersType, \
INT_MAX, INT_MIN, UINT_MAX, \
LONG_MAX, LONG_MIN, ULONG_MAX
-import warnings, unittest, test.test_warnings
+import warnings, unittest
from test import test_support
ts=test_structmembersType(1,2,3,4,5,6,7,8,9.99999,10.1010101010)
@@ -39,34 +39,39 @@ class ReadWriteTests(unittest.TestCase):
ts.T_ULONG=ULONG_MAX
self.assertEquals(ts.T_ULONG, ULONG_MAX)
-class TestWarnings(test.test_warnings.TestModule):
- def has_warned(self):
- self.assertEqual(test.test_warnings.msg.category,
- RuntimeWarning.__name__)
+class TestWarnings(unittest.TestCase):
+ def has_warned(self, w):
+ self.assert_(w.category is RuntimeWarning)
def test_byte_max(self):
- ts.T_BYTE=CHAR_MAX+1
- self.has_warned()
+ with test_support.catch_warning() as w:
+ ts.T_BYTE=CHAR_MAX+1
+ self.has_warned(w)
def test_byte_min(self):
- ts.T_BYTE=CHAR_MIN-1
- self.has_warned()
+ with test_support.catch_warning() as w:
+ ts.T_BYTE=CHAR_MIN-1
+ self.has_warned(w)
def test_ubyte_max(self):
- ts.T_UBYTE=UCHAR_MAX+1
- self.has_warned()
+ with test_support.catch_warning() as w:
+ ts.T_UBYTE=UCHAR_MAX+1
+ self.has_warned(w)
def test_short_max(self):
- ts.T_SHORT=SHRT_MAX+1
- self.has_warned()
+ with test_support.catch_warning() as w:
+ ts.T_SHORT=SHRT_MAX+1
+ self.has_warned(w)
def test_short_min(self):
- ts.T_SHORT=SHRT_MIN-1
- self.has_warned()
+ with test_support.catch_warning() as w:
+ ts.T_SHORT=SHRT_MIN-1
+ self.has_warned(w)
def test_ushort_max(self):
- ts.T_USHORT=USHRT_MAX+1
- self.has_warned()
+ with test_support.catch_warning() as w:
+ ts.T_USHORT=USHRT_MAX+1
+ self.has_warned(w)
diff --git a/Lib/test/test_support.py b/Lib/test/test_support.py
index 6fbb3cc..1ff0e4d 100644
--- a/Lib/test/test_support.py
+++ b/Lib/test/test_support.py
@@ -1,11 +1,17 @@
"""Supporting definitions for the Python regression tests."""
if __name__ != 'test.test_support':
- raise ImportError, 'test_support must be imported from the test package'
+ raise ImportError('test_support must be imported from the test package')
-from contextlib import contextmanager
+import contextlib
+import errno
+import socket
import sys
+import os
+import os.path
import warnings
+import types
+import unittest
class Error(Exception):
"""Base class for regression test exceptions."""
@@ -54,7 +60,6 @@ def unload(name):
pass
def unlink(filename):
- import os
try:
os.unlink(filename)
except OSError:
@@ -64,7 +69,6 @@ def forget(modname):
'''"Forget" a module was ever imported by removing it from sys.modules and
deleting any .pyc and .pyo files.'''
unload(modname)
- import os
for dirname in sys.path:
unlink(os.path.join(dirname, modname + os.extsep + 'pyc'))
# Deleting the .pyo file cannot be within the 'try' for the .pyc since
@@ -96,7 +100,6 @@ def bind_port(sock, host='', preferred_port=54321):
tests and we don't try multiple ports, the test can fails. This
makes the test more robust."""
- import socket, errno
# some random ports that hopefully no one is listening on.
for port in [preferred_port, 9907, 10243, 32999]:
try:
@@ -107,7 +110,7 @@ def bind_port(sock, host='', preferred_port=54321):
if err != errno.EADDRINUSE:
raise
print(' WARNING: failed to listen on port %d, trying another' % port, file=sys.__stderr__)
- raise TestFailed, 'unable to find port to listen on'
+ raise TestFailed('unable to find port to listen on')
FUZZ = 1e-6
@@ -135,7 +138,6 @@ except NameError:
is_jython = sys.platform.startswith('java')
-import os
# Filename used for testing
if os.name == 'java':
# Jython disallows @ in module names
@@ -197,13 +199,12 @@ except IOError:
if fp is not None:
fp.close()
unlink(TESTFN)
-del os, fp
+del fp
def findfile(file, here=__file__):
"""Try to find a file on sys.path and the working directory. If it is not
found the argument passed to the function is returned (this does not
necessarily signal failure; could still be the legitimate path)."""
- import os
if os.path.isabs(file):
return file
path = sys.path
@@ -235,7 +236,7 @@ def vereq(a, b):
"""
if not (a == b):
- raise TestFailed, "%r == %r" % (a, b)
+ raise TestFailed("%r == %r" % (a, b))
def sortdict(dict):
"Like repr(dict), but in sorted order."
@@ -254,7 +255,6 @@ def check_syntax_error(testcase, statement):
def open_urlresource(url):
import urllib, urlparse
- import os.path
filename = urlparse.urlparse(url)[2].split('/')[-1] # '/': it's URL!
@@ -268,7 +268,7 @@ def open_urlresource(url):
fn, _ = urllib.urlretrieve(url, filename)
return open(fn)
-@contextmanager
+@contextlib.contextmanager
def guard_warnings_filter():
"""Guard the warnings filter from being permanently changed."""
original_filters = warnings.filters[:]
@@ -277,14 +277,49 @@ def guard_warnings_filter():
finally:
warnings.filters = original_filters
+class WarningMessage(object):
+ "Holds the result of the latest showwarning() call"
+ def __init__(self):
+ self.message = None
+ self.category = None
+ self.filename = None
+ self.lineno = None
+
+ def _showwarning(self, message, category, filename, lineno, file=None):
+ self.message = message
+ self.category = category
+ self.filename = filename
+ self.lineno = lineno
+
+@contextlib.contextmanager
+def catch_warning():
+ """
+ Guard the warnings filter from being permanently changed and record the
+ data of the last warning that has been issued.
+
+ Use like this:
+
+ with catch_warning as w:
+ warnings.warn("foo")
+ assert str(w.message) == "foo"
+ """
+ warning = WarningMessage()
+ original_filters = warnings.filters[:]
+ original_showwarning = warnings.showwarning
+ warnings.showwarning = warning._showwarning
+ try:
+ yield warning
+ finally:
+ warnings.showwarning = original_showwarning
+ warnings.filters = original_filters
+
class EnvironmentVarGuard(object):
"""Class to help protect the environment variable properly. Can be used as
a context manager."""
def __init__(self):
- from os import environ
- self._environ = environ
+ self._environ = os.environ
self._unset = set()
self._reset = dict()
@@ -309,6 +344,40 @@ class EnvironmentVarGuard(object):
for unset in self._unset:
del self._environ[unset]
+class TransientResource(object):
+
+ """Raise ResourceDenied if an exception is raised while the context manager
+ is in effect that matches the specified exception and attributes."""
+
+ def __init__(self, exc, **kwargs):
+ self.exc = exc
+ self.attrs = kwargs
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type_=None, value=None, traceback=None):
+ """If type_ is a subclass of self.exc and value has attributes matching
+ self.attrs, raise ResourceDenied. Otherwise let the exception
+ propagate (if any)."""
+ if type_ is not None and issubclass(self.exc, type_):
+ for attr, attr_value in self.attrs.items():
+ if not hasattr(value, attr):
+ break
+ if getattr(value, attr) != attr_value:
+ break
+ else:
+ raise ResourceDenied("an optional resource is not available")
+
+
+def transient_internet():
+ """Return a context manager that raises ResourceDenied when various issues
+ with the Internet connection manifest themselves as exceptions."""
+ time_out = TransientResource(IOError, errno=errno.ETIMEDOUT)
+ socket_peer_reset = TransientResource(socket.error, errno=errno.ECONNRESET)
+ ioerror_peer_reset = TransientResource(IOError, errno=errno.ECONNRESET)
+ return contextlib.nested(time_out, socket_peer_reset, ioerror_peer_reset)
+
#=======================================================================
# Decorator for running a function in a different locale, correctly resetting
@@ -432,10 +501,7 @@ def bigaddrspacetest(f):
return wrapper
#=======================================================================
-# Preliminary PyUNIT integration.
-
-import unittest
-
+# unittest integration.
class BasicTestRunner:
def run(self, test):
@@ -444,7 +510,7 @@ class BasicTestRunner:
return result
-def run_suite(suite, testclass=None):
+def _run_suite(suite):
"""Run tests from a unittest.TestSuite-derived class."""
if verbose:
runner = unittest.TextTestRunner(sys.stdout, verbosity=2)
@@ -458,28 +524,26 @@ def run_suite(suite, testclass=None):
elif len(result.failures) == 1 and not result.errors:
err = result.failures[0][1]
else:
- if testclass is None:
- msg = "errors occurred; run in verbose mode for details"
- else:
- msg = "errors occurred in %s.%s" \
- % (testclass.__module__, testclass.__name__)
+ msg = "errors occurred; run in verbose mode for details"
raise TestFailed(msg)
raise TestFailed(err)
def run_unittest(*classes):
"""Run tests from unittest.TestCase-derived classes."""
+ valid_types = (unittest.TestSuite, unittest.TestCase)
suite = unittest.TestSuite()
for cls in classes:
- if isinstance(cls, (unittest.TestSuite, unittest.TestCase)):
+ if isinstance(cls, str):
+ if cls in sys.modules:
+ suite.addTest(unittest.findTestCases(sys.modules[cls]))
+ else:
+ raise ValueError("str arguments must be keys in sys.modules")
+ elif isinstance(cls, valid_types):
suite.addTest(cls)
else:
suite.addTest(unittest.makeSuite(cls))
- if len(classes)==1:
- testclass = classes[0]
- else:
- testclass = None
- run_suite(suite, testclass)
+ _run_suite(suite)
#=======================================================================
@@ -545,7 +609,6 @@ def reap_children():
# Reap all our dead child processes so we don't leave zombies around.
# These hog resources and might be causing some of the buildbots to die.
- import os
if hasattr(os, 'waitpid'):
any_process = -1
while True:
diff --git a/Lib/test/test_syntax.py b/Lib/test/test_syntax.py
index b5a5c5d..2b48ea6 100644
--- a/Lib/test/test_syntax.py
+++ b/Lib/test/test_syntax.py
@@ -374,7 +374,7 @@ Misuse of the nonlocal statement can lead to a few unique syntax errors.
Traceback (most recent call last):
...
SyntaxError: name 'x' is parameter and nonlocal
-
+
>>> def f():
... global x
... nonlocal x
@@ -403,7 +403,7 @@ TODO(jhylton): Figure out how to test SyntaxWarning with doctest.
## Traceback (most recent call last):
## ...
## SyntaxWarning: name 'x' is assigned to before nonlocal declaration
-
+
## >>> def f():
## ... x = 1
## ... nonlocal x
@@ -411,7 +411,56 @@ TODO(jhylton): Figure out how to test SyntaxWarning with doctest.
## ...
## SyntaxWarning: name 'x' is assigned to before nonlocal declaration
-
+
+This tests assignment-context; there was a bug in Python 2.5 where compiling
+a complex 'if' (one with 'elif') would fail to notice an invalid suite,
+leading to spurious errors.
+
+ >>> if 1:
+ ... x() = 1
+ ... elif 1:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ SyntaxError: can't assign to function call (<doctest test.test_syntax[48]>, line 2)
+
+ >>> if 1:
+ ... pass
+ ... elif 1:
+ ... x() = 1
+ Traceback (most recent call last):
+ ...
+ SyntaxError: can't assign to function call (<doctest test.test_syntax[49]>, line 4)
+
+ >>> if 1:
+ ... x() = 1
+ ... elif 1:
+ ... pass
+ ... else:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ SyntaxError: can't assign to function call (<doctest test.test_syntax[50]>, line 2)
+
+ >>> if 1:
+ ... pass
+ ... elif 1:
+ ... x() = 1
+ ... else:
+ ... pass
+ Traceback (most recent call last):
+ ...
+ SyntaxError: can't assign to function call (<doctest test.test_syntax[51]>, line 4)
+
+ >>> if 1:
+ ... pass
+ ... elif 1:
+ ... pass
+ ... else:
+ ... x() = 1
+ Traceback (most recent call last):
+ ...
+ SyntaxError: can't assign to function call (<doctest test.test_syntax[52]>, line 6)
"""
diff --git a/Lib/test/test_tarfile.py b/Lib/test/test_tarfile.py
index e9bf497..ac7dca3 100644
--- a/Lib/test/test_tarfile.py
+++ b/Lib/test/test_tarfile.py
@@ -1,8 +1,12 @@
+# encoding: iso8859-1
+
import sys
import os
import shutil
import tempfile
import StringIO
+import md5
+import errno
import unittest
import tarfile
@@ -20,452 +24,547 @@ try:
except ImportError:
bz2 = None
+def md5sum(data):
+ return md5.new(data).hexdigest()
+
def path(path):
return test_support.findfile(path)
-testtar = path("testtar.tar")
-tempdir = os.path.join(tempfile.gettempdir(), "testtar" + os.extsep + "dir")
-tempname = test_support.TESTFN
-membercount = 12
-
-def tarname(comp=""):
- if not comp:
- return testtar
- return os.path.join(tempdir, "%s%s%s" % (testtar, os.extsep, comp))
+TEMPDIR = os.path.join(tempfile.gettempdir(), "test_tarfile_tmp")
+tarname = path("testtar.tar")
+gzipname = os.path.join(TEMPDIR, "testtar.tar.gz")
+bz2name = os.path.join(TEMPDIR, "testtar.tar.bz2")
+tmpname = os.path.join(TEMPDIR, "tmp.tar")
-def dirname():
- if not os.path.exists(tempdir):
- os.mkdir(tempdir)
- return tempdir
+md5_regtype = "65f477c818ad9e15f7feab0c6d37742f"
+md5_sparse = "a54fbc4ca4f4399a90e1b27164012fc6"
-def tmpname():
- return tempname
+class ReadTest(unittest.TestCase):
-class BaseTest(unittest.TestCase):
- comp = ''
- mode = 'r'
- sep = ':'
+ tarname = tarname
+ mode = "r:"
def setUp(self):
- mode = self.mode + self.sep + self.comp
- self.tar = tarfile.open(tarname(self.comp), mode)
+ self.tar = tarfile.open(self.tarname, mode=self.mode, encoding="iso8859-1")
def tearDown(self):
self.tar.close()
-class ReadTest(BaseTest):
- def test(self):
- """Test member extraction.
- """
- members = 0
+class UstarReadTest(ReadTest):
+
+ def test_fileobj_regular_file(self):
+ tarinfo = self.tar.getmember("ustar/regtype")
+ fobj = self.tar.extractfile(tarinfo)
+ data = fobj.read()
+ self.assert_((len(data), md5sum(data)) == (tarinfo.size, md5_regtype),
+ "regular file extraction failed")
+
+ def test_fileobj_readlines(self):
+ self.tar.extract("ustar/regtype", TEMPDIR)
+ tarinfo = self.tar.getmember("ustar/regtype")
+ fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
+ fobj2 = self.tar.extractfile(tarinfo)
+
+ lines1 = fobj1.readlines()
+ lines2 = fobj2.readlines()
+ self.assert_(lines1 == lines2,
+ "fileobj.readlines() failed")
+ self.assert_(len(lines2) == 114,
+ "fileobj.readlines() failed")
+ self.assert_(lines2[83] == \
+ "I will gladly admit that Python is not the fastest running scripting language.\n",
+ "fileobj.readlines() failed")
+
+ def test_fileobj_iter(self):
+ self.tar.extract("ustar/regtype", TEMPDIR)
+ tarinfo = self.tar.getmember("ustar/regtype")
+ fobj1 = open(os.path.join(TEMPDIR, "ustar/regtype"), "rU")
+ fobj2 = self.tar.extractfile(tarinfo)
+ lines1 = fobj1.readlines()
+ lines2 = [line for line in fobj2]
+ self.assert_(lines1 == lines2,
+ "fileobj.__iter__() failed")
+
+ def test_fileobj_seek(self):
+ self.tar.extract("ustar/regtype", TEMPDIR)
+ fobj = open(os.path.join(TEMPDIR, "ustar/regtype"), "rb")
+ data = fobj.read()
+ fobj.close()
+
+ tarinfo = self.tar.getmember("ustar/regtype")
+ fobj = self.tar.extractfile(tarinfo)
+
+ text = fobj.read()
+ fobj.seek(0)
+ self.assert_(0 == fobj.tell(),
+ "seek() to file's start failed")
+ fobj.seek(2048, 0)
+ self.assert_(2048 == fobj.tell(),
+ "seek() to absolute position failed")
+ fobj.seek(-1024, 1)
+ self.assert_(1024 == fobj.tell(),
+ "seek() to negative relative position failed")
+ fobj.seek(1024, 1)
+ self.assert_(2048 == fobj.tell(),
+ "seek() to positive relative position failed")
+ s = fobj.read(10)
+ self.assert_(s == data[2048:2058],
+ "read() after seek failed")
+ fobj.seek(0, 2)
+ self.assert_(tarinfo.size == fobj.tell(),
+ "seek() to file's end failed")
+ self.assert_(fobj.read() == "",
+ "read() at file's end did not return empty string")
+ fobj.seek(-tarinfo.size, 2)
+ self.assert_(0 == fobj.tell(),
+ "relative seek() to file's start failed")
+ fobj.seek(512)
+ s1 = fobj.readlines()
+ fobj.seek(512)
+ s2 = fobj.readlines()
+ self.assert_(s1 == s2,
+ "readlines() after seek failed")
+ fobj.seek(0)
+ self.assert_(len(fobj.readline()) == fobj.tell(),
+ "tell() after readline() failed")
+ fobj.seek(512)
+ self.assert_(len(fobj.readline()) + 512 == fobj.tell(),
+ "tell() after seek() and readline() failed")
+ fobj.seek(0)
+ line = fobj.readline()
+ self.assert_(fobj.read() == data[len(line):],
+ "read() after readline() failed")
+ fobj.close()
+
+
+class MiscReadTest(ReadTest):
+
+ def test_no_filename(self):
+ fobj = open(self.tarname, "rb")
+ tar = tarfile.open(fileobj=fobj, mode=self.mode)
+ self.assertEqual(tar.name, os.path.abspath(fobj.name))
+
+ def test_fail_comp(self):
+ # For Gzip and Bz2 Tests: fail with a ReadError on an uncompressed file.
+ if self.mode == "r:":
+ return
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, self.mode)
+ fobj = open(tarname, "rb")
+ self.assertRaises(tarfile.ReadError, tarfile.open, fileobj=fobj, mode=self.mode)
+
+ def test_v7_dirtype(self):
+ # Test old style dirtype member (bug #1336623):
+ # Old V7 tars create directory members using an AREGTYPE
+ # header with a "/" appended to the filename field.
+ tarinfo = self.tar.getmember("misc/dirtype-old-v7")
+ self.assert_(tarinfo.type == tarfile.DIRTYPE,
+ "v7 dirtype failed")
+
+ def test_check_members(self):
for tarinfo in self.tar:
- members += 1
- if not tarinfo.isreg():
+ self.assert_(int(tarinfo.mtime) == 07606136617,
+ "wrong mtime for %s" % tarinfo.name)
+ if not tarinfo.name.startswith("ustar/"):
continue
- f = self.tar.extractfile(tarinfo)
- self.assert_(len(f.read()) == tarinfo.size,
- "size read does not match expected size")
- f.close()
-
- self.assert_(members == membercount,
- "could not find all members")
-
- def test_sparse(self):
- """Test sparse member extraction.
- """
- if self.sep != "|":
- f1 = self.tar.extractfile("S-SPARSE")
- f2 = self.tar.extractfile("S-SPARSE-WITH-NULLS")
- self.assert_(f1.read() == f2.read(),
- "_FileObject failed on sparse file member")
-
- def test_readlines(self):
- """Test readlines() method of _FileObject.
- """
- if self.sep != "|":
- filename = "0-REGTYPE-TEXT"
- self.tar.extract(filename, dirname())
- f = open(os.path.join(dirname(), filename), "rU")
- lines1 = f.readlines()
- f.close()
- lines2 = self.tar.extractfile(filename).readlines()
- self.assert_(lines1 == lines2,
- "_FileObject.readline() does not work correctly")
-
- def test_iter(self):
- # Test iteration over ExFileObject.
- if self.sep != "|":
- filename = "0-REGTYPE-TEXT"
- self.tar.extract(filename, dirname())
- f = open(os.path.join(dirname(), filename), "rU")
- lines1 = f.readlines()
- f.close()
- lines2 = [line for line in self.tar.extractfile(filename)]
- self.assert_(lines1 == lines2,
- "ExFileObject iteration does not work correctly")
-
- def test_seek(self):
- """Test seek() method of _FileObject, incl. random reading.
- """
- if self.sep != "|":
- filename = "0-REGTYPE-TEXT"
- self.tar.extract(filename, dirname())
- f = open(os.path.join(dirname(), filename), "rb")
- data = f.read()
- f.close()
-
- tarinfo = self.tar.getmember(filename)
- fobj = self.tar.extractfile(tarinfo)
-
- text = fobj.read()
- fobj.seek(0)
- self.assert_(0 == fobj.tell(),
- "seek() to file's start failed")
- fobj.seek(2048, 0)
- self.assert_(2048 == fobj.tell(),
- "seek() to absolute position failed")
- fobj.seek(-1024, 1)
- self.assert_(1024 == fobj.tell(),
- "seek() to negative relative position failed")
- fobj.seek(1024, 1)
- self.assert_(2048 == fobj.tell(),
- "seek() to positive relative position failed")
- s = fobj.read(10)
- self.assert_(s == data[2048:2058],
- "read() after seek failed")
- fobj.seek(0, 2)
- self.assert_(tarinfo.size == fobj.tell(),
- "seek() to file's end failed")
- self.assert_(fobj.read() == "",
- "read() at file's end did not return empty string")
- fobj.seek(-tarinfo.size, 2)
- self.assert_(0 == fobj.tell(),
- "relative seek() to file's start failed")
- fobj.seek(512)
- s1 = fobj.readlines()
- fobj.seek(512)
- s2 = fobj.readlines()
- self.assert_(s1 == s2,
- "readlines() after seek failed")
- fobj.seek(0)
- self.assert_(len(fobj.readline()) == fobj.tell(),
- "tell() after readline() failed")
- fobj.seek(512)
- self.assert_(len(fobj.readline()) + 512 == fobj.tell(),
- "tell() after seek() and readline() failed")
- fobj.seek(0)
- line = fobj.readline()
- self.assert_(fobj.read() == data[len(line):],
- "read() after readline() failed")
- fobj.close()
+ self.assert_(tarinfo.uname == "tarfile",
+ "wrong uname for %s" % tarinfo.name)
- def test_old_dirtype(self):
- """Test old style dirtype member (bug #1336623).
- """
- # Old tars create directory members using a REGTYPE
- # header with a "/" appended to the filename field.
+ def test_find_members(self):
+ self.assert_(self.tar.getmembers()[-1].name == "misc/eof",
+ "could not find all members")
- # Create an old tar style directory entry.
- filename = tmpname()
- tarinfo = tarfile.TarInfo("directory/")
- tarinfo.type = tarfile.REGTYPE
+ def test_extract_hardlink(self):
+ # Test hardlink extraction (e.g. bug #857297).
+ tar = tarfile.open(tarname, errorlevel=1, encoding="iso8859-1")
- fobj = open(filename, "w")
- fobj.write(tarinfo.tobuf())
- fobj.close()
+ tar.extract("ustar/regtype", TEMPDIR)
+ try:
+ tar.extract("ustar/lnktype", TEMPDIR)
+ except EnvironmentError as e:
+ if e.errno == errno.ENOENT:
+ self.fail("hardlink not extracted properly")
+
+ data = open(os.path.join(TEMPDIR, "ustar/lnktype"), "rb").read()
+ self.assertEqual(md5sum(data), md5_regtype)
try:
- # Test if it is still a directory entry when
- # read back.
- tar = tarfile.open(filename)
- tarinfo = tar.getmembers()[0]
- tar.close()
-
- self.assert_(tarinfo.type == tarfile.DIRTYPE)
- self.assert_(tarinfo.name.endswith("/"))
- finally:
- try:
- os.unlink(filename)
- except:
- pass
-
-class ReadStreamTest(ReadTest):
- sep = "|"
-
- def test(self):
- """Test member extraction, and for StreamError when
- seeking backwards.
- """
- ReadTest.test(self)
- tarinfo = self.tar.getmembers()[0]
- f = self.tar.extractfile(tarinfo)
+ tar.extract("ustar/symtype", TEMPDIR)
+ except EnvironmentError as e:
+ if e.errno == errno.ENOENT:
+ self.fail("symlink not extracted properly")
+
+ data = open(os.path.join(TEMPDIR, "ustar/symtype"), "rb").read()
+ self.assertEqual(md5sum(data), md5_regtype)
+
+
+class StreamReadTest(ReadTest):
+
+ mode="r|"
+
+ def test_fileobj_regular_file(self):
+ tarinfo = self.tar.next() # get "regtype" (can't use getmember)
+ fobj = self.tar.extractfile(tarinfo)
+ data = fobj.read()
+ self.assert_((len(data), md5sum(data)) == (tarinfo.size, md5_regtype),
+ "regular file extraction failed")
+
+ def test_provoke_stream_error(self):
+ tarinfos = self.tar.getmembers()
+ f = self.tar.extractfile(tarinfos[0]) # read the first member
self.assertRaises(tarfile.StreamError, f.read)
- def test_stream(self):
- """Compare the normal tar and the stream tar.
- """
- stream = self.tar
- tar = tarfile.open(tarname(), 'r')
+ def test_compare_members(self):
+ tar1 = tarfile.open(tarname, encoding="iso8859-1")
+ tar2 = self.tar
- while 1:
- t1 = tar.next()
- t2 = stream.next()
+ while True:
+ t1 = tar1.next()
+ t2 = tar2.next()
if t1 is None:
break
self.assert_(t2 is not None, "stream.next() failed.")
if t2.islnk() or t2.issym():
- self.assertRaises(tarfile.StreamError, stream.extractfile, t2)
+ self.assertRaises(tarfile.StreamError, tar2.extractfile, t2)
continue
- v1 = tar.extractfile(t1)
- v2 = stream.extractfile(t2)
+
+ v1 = tar1.extractfile(t1)
+ v2 = tar2.extractfile(t2)
if v1 is None:
continue
self.assert_(v2 is not None, "stream.extractfile() failed")
self.assert_(v1.read() == v2.read(), "stream extraction failed")
- tar.close()
- stream.close()
+ tar1.close()
-class ReadDetectTest(ReadTest):
- def setUp(self):
- self.tar = tarfile.open(tarname(self.comp), self.mode)
+class DetectReadTest(unittest.TestCase):
-class ReadDetectFileobjTest(ReadTest):
+ def _testfunc_file(self, name, mode):
+ try:
+ tarfile.open(name, mode)
+ except tarfile.ReadError:
+ self.fail()
- def setUp(self):
- name = tarname(self.comp)
- self.tar = tarfile.open(name, mode=self.mode,
- fileobj=open(name, "rb"))
+ def _testfunc_fileobj(self, name, mode):
+ try:
+ tarfile.open(name, mode, fileobj=open(name, "rb"))
+ except tarfile.ReadError:
+ self.fail()
-class ReadAsteriskTest(ReadTest):
+ def _test_modes(self, testfunc):
+ testfunc(tarname, "r")
+ testfunc(tarname, "r:")
+ testfunc(tarname, "r:*")
+ testfunc(tarname, "r|")
+ testfunc(tarname, "r|*")
- def setUp(self):
- mode = self.mode + self.sep + "*"
- self.tar = tarfile.open(tarname(self.comp), mode)
+ if gzip:
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r:gz")
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r|gz")
+ self.assertRaises(tarfile.ReadError, tarfile.open, gzipname, mode="r:")
+ self.assertRaises(tarfile.ReadError, tarfile.open, gzipname, mode="r|")
-class ReadStreamAsteriskTest(ReadStreamTest):
+ testfunc(gzipname, "r")
+ testfunc(gzipname, "r:*")
+ testfunc(gzipname, "r:gz")
+ testfunc(gzipname, "r|*")
+ testfunc(gzipname, "r|gz")
- def setUp(self):
- mode = self.mode + self.sep + "*"
- self.tar = tarfile.open(tarname(self.comp), mode)
+ if bz2:
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r:bz2")
+ self.assertRaises(tarfile.ReadError, tarfile.open, tarname, mode="r|bz2")
+ self.assertRaises(tarfile.ReadError, tarfile.open, bz2name, mode="r:")
+ self.assertRaises(tarfile.ReadError, tarfile.open, bz2name, mode="r|")
-class WriteTest(BaseTest):
- mode = 'w'
+ testfunc(bz2name, "r")
+ testfunc(bz2name, "r:*")
+ testfunc(bz2name, "r:bz2")
+ testfunc(bz2name, "r|*")
+ testfunc(bz2name, "r|bz2")
- def setUp(self):
- mode = self.mode + self.sep + self.comp
- self.src = tarfile.open(tarname(self.comp), 'r')
- self.dstname = tmpname()
- self.dst = tarfile.open(self.dstname, mode)
+ def test_detect_file(self):
+ self._test_modes(self._testfunc_file)
- def tearDown(self):
- self.src.close()
- self.dst.close()
+ def test_detect_fileobj(self):
+ self._test_modes(self._testfunc_fileobj)
- def test_posix(self):
- self.dst.posix = 1
- self._test()
- def test_nonposix(self):
- self.dst.posix = 0
- self._test()
+class MemberReadTest(ReadTest):
- def test_small(self):
- self.dst.add(os.path.join(os.path.dirname(__file__),"cfgparser.1"))
- self.dst.close()
- self.assertNotEqual(os.stat(self.dstname).st_size, 0)
+ def _test_member(self, tarinfo, chksum=None, **kwargs):
+ if chksum is not None:
+ self.assert_(md5sum(self.tar.extractfile(tarinfo).read()) == chksum,
+ "wrong md5sum for %s" % tarinfo.name)
- def _test(self):
- for tarinfo in self.src:
- if not tarinfo.isreg():
- continue
- f = self.src.extractfile(tarinfo)
- if self.dst.posix and len(tarinfo.name) > tarfile.LENGTH_NAME and "/" not in tarinfo.name:
- self.assertRaises(ValueError, self.dst.addfile,
- tarinfo, f)
- else:
- self.dst.addfile(tarinfo, f)
+ kwargs["mtime"] = 07606136617
+ kwargs["uid"] = 1000
+ kwargs["gid"] = 100
+ if "old-v7" not in tarinfo.name:
+ # V7 tar can't handle alphabetic owners.
+ kwargs["uname"] = "tarfile"
+ kwargs["gname"] = "tarfile"
+ for k, v in kwargs.items():
+ self.assert_(getattr(tarinfo, k) == v,
+ "wrong value in %s field of %s" % (k, tarinfo.name))
- def test_add_self(self):
- dstname = os.path.abspath(self.dstname)
+ def test_find_regtype(self):
+ tarinfo = self.tar.getmember("ustar/regtype")
+ self._test_member(tarinfo, size=7011, chksum=md5_regtype)
- self.assertEqual(self.dst.name, dstname, "archive name must be absolute")
+ def test_find_conttype(self):
+ tarinfo = self.tar.getmember("ustar/conttype")
+ self._test_member(tarinfo, size=7011, chksum=md5_regtype)
- self.dst.add(dstname)
- self.assertEqual(self.dst.getnames(), [], "added the archive to itself")
+ def test_find_dirtype(self):
+ tarinfo = self.tar.getmember("ustar/dirtype")
+ self._test_member(tarinfo, size=0)
- cwd = os.getcwd()
- os.chdir(dirname())
- self.dst.add(dstname)
- os.chdir(cwd)
- self.assertEqual(self.dst.getnames(), [], "added the archive to itself")
+ def test_find_dirtype_with_size(self):
+ tarinfo = self.tar.getmember("ustar/dirtype-with-size")
+ self._test_member(tarinfo, size=255)
+ def test_find_lnktype(self):
+ tarinfo = self.tar.getmember("ustar/lnktype")
+ self._test_member(tarinfo, size=0, linkname="ustar/regtype")
-class AppendTest(unittest.TestCase):
- # Test append mode (cp. patch #1652681).
+ def test_find_symtype(self):
+ tarinfo = self.tar.getmember("ustar/symtype")
+ self._test_member(tarinfo, size=0, linkname="regtype")
- def setUp(self):
- self.tarname = tmpname()
- if os.path.exists(self.tarname):
- os.remove(self.tarname)
+ def test_find_blktype(self):
+ tarinfo = self.tar.getmember("ustar/blktype")
+ self._test_member(tarinfo, size=0, devmajor=3, devminor=0)
- def _add_testfile(self, fileobj=None):
- tar = tarfile.open(self.tarname, "a", fileobj=fileobj)
- tar.addfile(tarfile.TarInfo("bar"))
- tar.close()
+ def test_find_chrtype(self):
+ tarinfo = self.tar.getmember("ustar/chrtype")
+ self._test_member(tarinfo, size=0, devmajor=1, devminor=3)
- def _create_testtar(self):
- src = tarfile.open(tarname())
- t = src.getmember("0-REGTYPE")
- t.name = "foo"
- f = src.extractfile(t)
- tar = tarfile.open(self.tarname, "w")
- tar.addfile(t, f)
- tar.close()
+ def test_find_fifotype(self):
+ tarinfo = self.tar.getmember("ustar/fifotype")
+ self._test_member(tarinfo, size=0)
- def _test(self, names=["bar"], fileobj=None):
- tar = tarfile.open(self.tarname, fileobj=fileobj)
- self.assert_(tar.getnames() == names)
+ def test_find_sparse(self):
+ tarinfo = self.tar.getmember("ustar/sparse")
+ self._test_member(tarinfo, size=86016, chksum=md5_sparse)
- def test_non_existing(self):
- self._add_testfile()
- self._test()
+ def test_find_umlauts(self):
+ tarinfo = self.tar.getmember("ustar/umlauts-ÄÖÜäöüß")
+ self._test_member(tarinfo, size=7011, chksum=md5_regtype)
- def test_empty(self):
- open(self.tarname, "wb").close()
- self._add_testfile()
- self._test()
+ def test_find_ustar_longname(self):
+ name = "ustar/" + "12345/" * 39 + "1234567/longname"
+ self.assert_(name in self.tar.getnames())
- def test_empty_fileobj(self):
- fobj = StringIO.StringIO()
- self._add_testfile(fobj)
- fobj.seek(0)
- self._test(fileobj=fobj)
+ def test_find_regtype_oldv7(self):
+ tarinfo = self.tar.getmember("misc/regtype-old-v7")
+ self._test_member(tarinfo, size=7011, chksum=md5_regtype)
- def test_fileobj(self):
- self._create_testtar()
- data = open(self.tarname, "rb").read()
- fobj = StringIO.StringIO(data)
- self._add_testfile(fobj)
- fobj.seek(0)
- self._test(names=["foo", "bar"], fileobj=fobj)
+ def test_find_pax_umlauts(self):
+ self.tar = tarfile.open(self.tarname, mode=self.mode, encoding="iso8859-1")
+ tarinfo = self.tar.getmember("pax/umlauts-ÄÖÜäöüß")
+ self._test_member(tarinfo, size=7011, chksum=md5_regtype)
- def test_existing(self):
- self._create_testtar()
- self._add_testfile()
- self._test(names=["foo", "bar"])
+class LongnameTest(ReadTest):
-class Write100Test(BaseTest):
- # The name field in a tar header stores strings of at most 100 chars.
- # If a string is shorter than 100 chars it has to be padded with '\0',
- # which implies that a string of exactly 100 chars is stored without
- # a trailing '\0'.
+ def test_read_longname(self):
+ # Test reading of longname (bug #1471427).
+ name = self.subdir + "/" + "123/" * 125 + "longname"
+ try:
+ tarinfo = self.tar.getmember(name)
+ except KeyError:
+ self.fail("longname not found")
+ self.assert_(tarinfo.type != tarfile.DIRTYPE, "read longname as dirtype")
- def setUp(self):
- self.name = "01234567890123456789012345678901234567890123456789"
- self.name += "01234567890123456789012345678901234567890123456789"
+ def test_read_longlink(self):
+ longname = self.subdir + "/" + "123/" * 125 + "longname"
+ longlink = self.subdir + "/" + "123/" * 125 + "longlink"
+ try:
+ tarinfo = self.tar.getmember(longlink)
+ except KeyError:
+ self.fail("longlink not found")
+ self.assert_(tarinfo.linkname == longname, "linkname wrong")
- self.tar = tarfile.open(tmpname(), "w")
- t = tarfile.TarInfo(self.name)
- self.tar.addfile(t)
- self.tar.close()
+ def test_truncated_longname(self):
+ longname = self.subdir + "/" + "123/" * 125 + "longname"
+ tarinfo = self.tar.getmember(longname)
+ offset = tarinfo.offset
+ self.tar.fileobj.seek(offset)
+ fobj = StringIO.StringIO(self.tar.fileobj.read(1536))
+ self.assertRaises(tarfile.ReadError, tarfile.open, name="foo.tar", fileobj=fobj)
- self.tar = tarfile.open(tmpname())
- def tearDown(self):
- self.tar.close()
+class GNUReadTest(LongnameTest):
- def test(self):
- self.assertEqual(self.tar.getnames()[0], self.name,
- "failed to store 100 char filename")
+ subdir = "gnu"
+ def test_sparse_file(self):
+ tarinfo1 = self.tar.getmember("ustar/sparse")
+ fobj1 = self.tar.extractfile(tarinfo1)
+ tarinfo2 = self.tar.getmember("gnu/sparse")
+ fobj2 = self.tar.extractfile(tarinfo2)
+ self.assert_(fobj1.read() == fobj2.read(),
+ "sparse file extraction failed")
-class WriteSize0Test(BaseTest):
- mode = 'w'
- def setUp(self):
- self.tmpdir = dirname()
- self.dstname = tmpname()
- self.dst = tarfile.open(self.dstname, "w")
+class PaxReadTest(ReadTest):
- def tearDown(self):
- self.dst.close()
+ subdir = "pax"
+
+ def test_pax_globheaders(self):
+ tar = tarfile.open(tarname, encoding="iso8859-1")
+ tarinfo = tar.getmember("pax/regtype1")
+ self.assertEqual(tarinfo.uname, "foo")
+ self.assertEqual(tarinfo.gname, "bar")
+ self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "ÄÖÜäöüß")
+
+ tarinfo = tar.getmember("pax/regtype2")
+ self.assertEqual(tarinfo.uname, "")
+ self.assertEqual(tarinfo.gname, "bar")
+ self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "ÄÖÜäöüß")
+
+ tarinfo = tar.getmember("pax/regtype3")
+ self.assertEqual(tarinfo.uname, "tarfile")
+ self.assertEqual(tarinfo.gname, "tarfile")
+ self.assertEqual(tarinfo.pax_headers.get("VENDOR.umlauts"), "ÄÖÜäöüß")
+
+
+class WriteTest(unittest.TestCase):
+
+ mode = "w:"
+
+ def test_100_char_name(self):
+ # The name field in a tar header stores strings of at most 100 chars.
+ # If a string is shorter than 100 chars it has to be padded with '\0',
+ # which implies that a string of exactly 100 chars is stored without
+ # a trailing '\0'.
+ name = "0123456789" * 10
+ tar = tarfile.open(tmpname, self.mode)
+ t = tarfile.TarInfo(name)
+ tar.addfile(t)
+ tar.close()
+
+ tar = tarfile.open(tmpname)
+ self.assert_(tar.getnames()[0] == name,
+ "failed to store 100 char filename")
+ tar.close()
+
+ def test_tar_size(self):
+ # Test for bug #1013882.
+ tar = tarfile.open(tmpname, self.mode)
+ path = os.path.join(TEMPDIR, "file")
+ fobj = open(path, "wb")
+ fobj.write("aaa")
+ fobj.close()
+ tar.add(path)
+ tar.close()
+ self.assert_(os.path.getsize(tmpname) > 0,
+ "tarfile is empty")
+
+ # The test_*_size tests test for bug #1167128.
+ def test_file_size(self):
+ tar = tarfile.open(tmpname, self.mode)
- def test_file(self):
- path = os.path.join(self.tmpdir, "file")
- f = open(path, "w")
- f.close()
- tarinfo = self.dst.gettarinfo(path)
+ path = os.path.join(TEMPDIR, "file")
+ fobj = open(path, "wb")
+ fobj.close()
+ tarinfo = tar.gettarinfo(path)
self.assertEqual(tarinfo.size, 0)
- f = open(path, "w")
- f.write("aaa")
- f.close()
- tarinfo = self.dst.gettarinfo(path)
+
+ fobj = open(path, "wb")
+ fobj.write("aaa")
+ fobj.close()
+ tarinfo = tar.gettarinfo(path)
self.assertEqual(tarinfo.size, 3)
- def test_directory(self):
- path = os.path.join(self.tmpdir, "directory")
- if os.path.exists(path):
- # This shouldn't be necessary, but is <wink> if a previous
- # run was killed in mid-stream.
- shutil.rmtree(path)
- os.mkdir(path)
- tarinfo = self.dst.gettarinfo(path)
- self.assertEqual(tarinfo.size, 0)
+ tar.close()
- def test_symlink(self):
+ def test_directory_size(self):
+ path = os.path.join(TEMPDIR, "directory")
+ os.mkdir(path)
+ try:
+ tar = tarfile.open(tmpname, self.mode)
+ tarinfo = tar.gettarinfo(path)
+ self.assertEqual(tarinfo.size, 0)
+ finally:
+ os.rmdir(path)
+
+ def test_link_size(self):
+ if hasattr(os, "link"):
+ link = os.path.join(TEMPDIR, "link")
+ target = os.path.join(TEMPDIR, "link_target")
+ open(target, "wb").close()
+ os.link(target, link)
+ try:
+ tar = tarfile.open(tmpname, self.mode)
+ tarinfo = tar.gettarinfo(link)
+ self.assertEqual(tarinfo.size, 0)
+ finally:
+ os.remove(target)
+ os.remove(link)
+
+ def test_symlink_size(self):
if hasattr(os, "symlink"):
- path = os.path.join(self.tmpdir, "symlink")
+ path = os.path.join(TEMPDIR, "symlink")
os.symlink("link_target", path)
- tarinfo = self.dst.gettarinfo(path)
- self.assertEqual(tarinfo.size, 0)
+ try:
+ tar = tarfile.open(tmpname, self.mode)
+ tarinfo = tar.gettarinfo(path)
+ self.assertEqual(tarinfo.size, 0)
+ finally:
+ os.remove(path)
+ def test_add_self(self):
+ # Test for #1257255.
+ dstname = os.path.abspath(tmpname)
-class WriteStreamTest(WriteTest):
- sep = '|'
+ tar = tarfile.open(tmpname, self.mode)
+ self.assert_(tar.name == dstname, "archive name must be absolute")
- def test_padding(self):
- self.dst.close()
+ tar.add(dstname)
+ self.assert_(tar.getnames() == [], "added the archive to itself")
- if self.comp == "gz":
- f = gzip.GzipFile(self.dstname)
- s = f.read()
- f.close()
- elif self.comp == "bz2":
- f = bz2.BZ2Decompressor()
- s = open(self.dstname).read()
- s = f.decompress(s)
- self.assertEqual(len(f.unused_data), 0, "trailing data")
- else:
- f = open(self.dstname)
- s = f.read()
- f.close()
+ cwd = os.getcwd()
+ os.chdir(TEMPDIR)
+ tar.add(dstname)
+ os.chdir(cwd)
+ self.assert_(tar.getnames() == [], "added the archive to itself")
- self.assertEqual(s.count("\0"), tarfile.RECORDSIZE,
- "incorrect zero padding")
+class StreamWriteTest(unittest.TestCase):
-class WriteGNULongTest(unittest.TestCase):
- """This testcase checks for correct creation of GNU Longname
- and Longlink extensions.
+ mode = "w|"
- It creates a tarfile and adds empty members with either
- long names, long linknames or both and compares the size
- of the tarfile with the expected size.
+ def test_stream_padding(self):
+ # Test for bug #1543303.
+ tar = tarfile.open(tmpname, self.mode)
+ tar.close()
- It checks for SF bug #812325 in TarFile._create_gnulong().
+ if self.mode.endswith("gz"):
+ fobj = gzip.GzipFile(tmpname)
+ data = fobj.read()
+ fobj.close()
+ elif self.mode.endswith("bz2"):
+ dec = bz2.BZ2Decompressor()
+ data = open(tmpname, "rb").read()
+ data = dec.decompress(data)
+ self.assert_(len(dec.unused_data) == 0,
+ "found trailing data")
+ else:
+ fobj = open(tmpname, "rb")
+ data = fobj.read()
+ fobj.close()
+
+ self.assert_(data.count("\0") == tarfile.RECORDSIZE,
+ "incorrect zero padding")
- While I was writing this testcase, I noticed a second bug
- in the same method:
- Long{names,links} weren't null-terminated which lead to
- bad tarfiles when their length was a multiple of 512. This
- is tested as well.
- """
+
+class GNUWriteTest(unittest.TestCase):
+ # This testcase checks for correct creation of GNU Longname
+ # and Longlink extended headers (cp. bug #812325).
def _length(self, s):
blocks, remainder = divmod(len(s) + 1, 512)
@@ -474,19 +573,17 @@ class WriteGNULongTest(unittest.TestCase):
return blocks * 512
def _calc_size(self, name, link=None):
- # initial tar header
+ # Initial tar header
count = 512
if len(name) > tarfile.LENGTH_NAME:
- # gnu longname extended header + longname
+ # GNU longname extended header + longname
count += 512
count += self._length(name)
-
if link is not None and len(link) > tarfile.LENGTH_LINK:
- # gnu longlink extended header + longlink
+ # GNU longlink extended header + longlink
count += 512
count += self._length(link)
-
return count
def _test(self, name, link=None):
@@ -495,17 +592,17 @@ class WriteGNULongTest(unittest.TestCase):
tarinfo.linkname = link
tarinfo.type = tarfile.LNKTYPE
- tar = tarfile.open(tmpname(), "w")
- tar.posix = False
+ tar = tarfile.open(tmpname, "w")
+ tar.format = tarfile.GNU_FORMAT
tar.addfile(tarinfo)
v1 = self._calc_size(name, link)
v2 = tar.offset
- self.assertEqual(v1, v2, "GNU longname/longlink creation failed")
+ self.assert_(v1 == v2, "GNU longname/longlink creation failed")
tar.close()
- tar = tarfile.open(tmpname())
+ tar = tarfile.open(tmpname)
member = tar.next()
self.failIf(member is None, "unable to read longname member")
self.assert_(tarinfo.name == member.name and \
@@ -542,268 +639,351 @@ class WriteGNULongTest(unittest.TestCase):
self._test(("longnam/" * 127) + "longname_",
("longlnk/" * 127) + "longlink_")
-class ReadGNULongTest(unittest.TestCase):
+
+class HardlinkTest(unittest.TestCase):
+ # Test the creation of LNKTYPE (hardlink) members in an archive.
def setUp(self):
- self.tar = tarfile.open(tarname())
+ self.foo = os.path.join(TEMPDIR, "foo")
+ self.bar = os.path.join(TEMPDIR, "bar")
+
+ fobj = open(self.foo, "wb")
+ fobj.write("foo")
+ fobj.close()
+
+ os.link(self.foo, self.bar)
+
+ self.tar = tarfile.open(tmpname, "w")
+ self.tar.add(self.foo)
def tearDown(self):
- self.tar.close()
+ os.remove(self.foo)
+ os.remove(self.bar)
- def test_1471427(self):
- """Test reading of longname (bug #1471427).
- """
- name = "test/" * 20 + "0-REGTYPE"
- try:
- tarinfo = self.tar.getmember(name)
- except KeyError:
- tarinfo = None
- self.assert_(tarinfo is not None, "longname not found")
- self.assert_(tarinfo.type != tarfile.DIRTYPE, "read longname as dirtype")
+ def test_add_twice(self):
+ # The same name will be added as a REGTYPE every
+ # time regardless of st_nlink.
+ tarinfo = self.tar.gettarinfo(self.foo)
+ self.assert_(tarinfo.type == tarfile.REGTYPE,
+ "add file as regular failed")
- def test_read_name(self):
- name = ("0-LONGNAME-" * 10)[:101]
- try:
- tarinfo = self.tar.getmember(name)
- except KeyError:
- tarinfo = None
- self.assert_(tarinfo is not None, "longname not found")
+ def test_add_hardlink(self):
+ tarinfo = self.tar.gettarinfo(self.bar)
+ self.assert_(tarinfo.type == tarfile.LNKTYPE,
+ "add file as hardlink failed")
- def test_read_link(self):
- link = ("1-LONGLINK-" * 10)[:101]
- name = ("0-LONGNAME-" * 10)[:101]
- try:
- tarinfo = self.tar.getmember(link)
- except KeyError:
- tarinfo = None
- self.assert_(tarinfo is not None, "longlink not found")
- self.assert_(tarinfo.linkname == name, "linkname wrong")
+ def test_dereference_hardlink(self):
+ self.tar.dereference = True
+ tarinfo = self.tar.gettarinfo(self.bar)
+ self.assert_(tarinfo.type == tarfile.REGTYPE,
+ "dereferencing hardlink failed")
- def test_truncated_longname(self):
- f = open(tarname())
- fobj = StringIO.StringIO(f.read(1024))
- f.close()
- tar = tarfile.open(name="foo.tar", fileobj=fobj)
- self.assert_(len(tar.getmembers()) == 0, "")
+
+class PaxWriteTest(GNUWriteTest):
+
+ def _test(self, name, link=None):
+ # See GNUWriteTest.
+ tarinfo = tarfile.TarInfo(name)
+ if link:
+ tarinfo.linkname = link
+ tarinfo.type = tarfile.LNKTYPE
+
+ tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT)
+ tar.addfile(tarinfo)
tar.close()
+ tar = tarfile.open(tmpname)
+ if link:
+ l = tar.getmembers()[0].linkname
+ self.assert_(link == l, "PAX longlink creation failed")
+ else:
+ n = tar.getmembers()[0].name
+ self.assert_(name == n, "PAX longname creation failed")
-class ExtractHardlinkTest(BaseTest):
+ def test_iso8859_15_filename(self):
+ self._test_unicode_filename("iso8859-15")
- def test_hardlink(self):
- """Test hardlink extraction (bug #857297)
- """
- # Prevent errors from being caught
- self.tar.errorlevel = 1
+ def test_utf8_filename(self):
+ self._test_unicode_filename("utf8")
- self.tar.extract("0-REGTYPE", dirname())
- try:
- # Extract 1-LNKTYPE which is a hardlink to 0-REGTYPE
- self.tar.extract("1-LNKTYPE", dirname())
- except EnvironmentError as e:
- import errno
- if e.errno == errno.ENOENT:
- self.fail("hardlink not extracted properly")
+ def test_utf16_filename(self):
+ self._test_unicode_filename("utf16")
-class CreateHardlinkTest(BaseTest):
- """Test the creation of LNKTYPE (hardlink) members in an archive.
- In this respect tarfile.py mimics the behaviour of GNU tar: If
- a file has a st_nlink > 1, it will be added a REGTYPE member
- only the first time.
- """
+ def _test_unicode_filename(self, encoding):
+ tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT)
+ name = u"\u20ac".encode(encoding) # Euro sign
+ tar.encoding = encoding
+ tar.addfile(tarfile.TarInfo(name))
+ tar.close()
- def setUp(self):
- self.tar = tarfile.open(tmpname(), "w")
+ tar = tarfile.open(tmpname, encoding=encoding)
+ self.assertEqual(tar.getmembers()[0].name, name)
+ tar.close()
- self.foo = os.path.join(dirname(), "foo")
- self.bar = os.path.join(dirname(), "bar")
+ def test_unicode_filename_error(self):
+ # The euro sign filename cannot be translated to iso8859-1 encoding.
+ tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, encoding="utf8")
+ name = u"\u20ac".encode("utf8") # Euro sign
+ tar.addfile(tarfile.TarInfo(name))
+ tar.close()
- if os.path.exists(self.foo):
- os.remove(self.foo)
- if os.path.exists(self.bar):
- os.remove(self.bar)
+ self.assertRaises(UnicodeError, tarfile.open, tmpname, encoding="iso8859-1")
- f = open(self.foo, "w")
- f.write("foo")
- f.close()
- self.tar.add(self.foo)
+ def test_pax_headers(self):
+ self._test_pax_headers({"foo": "bar", "uid": 0, "mtime": 1.23})
- def test_add_twice(self):
- # If st_nlink == 1 then the same file will be added as
- # REGTYPE every time.
- tarinfo = self.tar.gettarinfo(self.foo)
- self.assertEqual(tarinfo.type, tarfile.REGTYPE,
- "add file as regular failed")
+ self._test_pax_headers({"euro": u"\u20ac".encode("utf8")})
- def test_add_hardlink(self):
- # If st_nlink > 1 then the same file will be added as
- # LNKTYPE.
- os.link(self.foo, self.bar)
- tarinfo = self.tar.gettarinfo(self.foo)
- self.assertEqual(tarinfo.type, tarfile.LNKTYPE,
- "add file as hardlink failed")
+ self._test_pax_headers({"euro": u"\u20ac"},
+ {"euro": u"\u20ac".encode("utf8")})
- tarinfo = self.tar.gettarinfo(self.bar)
- self.assertEqual(tarinfo.type, tarfile.LNKTYPE,
- "add file as hardlink failed")
+ self._test_pax_headers({u"\u20ac": "euro"},
+ {u"\u20ac".encode("utf8"): "euro"})
- def test_dereference_hardlink(self):
- self.tar.dereference = True
- os.link(self.foo, self.bar)
- tarinfo = self.tar.gettarinfo(self.bar)
- self.assertEqual(tarinfo.type, tarfile.REGTYPE,
- "dereferencing hardlink failed")
+ def _test_pax_headers(self, pax_headers, cmp_headers=None):
+ if cmp_headers is None:
+ cmp_headers = pax_headers
+ tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT, \
+ pax_headers=pax_headers, encoding="utf8")
+ tar.addfile(tarfile.TarInfo("test"))
+ tar.close()
-# Gzip TestCases
-class ReadTestGzip(ReadTest):
- comp = "gz"
-class ReadStreamTestGzip(ReadStreamTest):
- comp = "gz"
-class WriteTestGzip(WriteTest):
- comp = "gz"
-class WriteStreamTestGzip(WriteStreamTest):
- comp = "gz"
-class ReadDetectTestGzip(ReadDetectTest):
- comp = "gz"
-class ReadDetectFileobjTestGzip(ReadDetectFileobjTest):
- comp = "gz"
-class ReadAsteriskTestGzip(ReadAsteriskTest):
- comp = "gz"
-class ReadStreamAsteriskTestGzip(ReadStreamAsteriskTest):
- comp = "gz"
-
-# Filemode test cases
-
-class FileModeTest(unittest.TestCase):
- def test_modes(self):
- self.assertEqual(tarfile.filemode(0755), '-rwxr-xr-x')
- self.assertEqual(tarfile.filemode(07111), '---s--s--t')
-
-class HeaderErrorTest(unittest.TestCase):
+ tar = tarfile.open(tmpname, encoding="utf8")
+ self.assertEqual(tar.pax_headers, cmp_headers)
def test_truncated_header(self):
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, "")
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, "filename\0")
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, "\0" * 511)
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, "\0" * 513)
-
- def test_empty_header(self):
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, "\0" * 512)
-
- def test_invalid_header(self):
- buf = tarfile.TarInfo("filename").tobuf()
- buf = buf[:148] + "foo\0\0\0\0\0" + buf[156:] # invalid number field.
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, buf)
-
- def test_bad_checksum(self):
- buf = tarfile.TarInfo("filename").tobuf()
- b = buf[:148] + " " + buf[156:] # clear the checksum field.
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, b)
- b = "a" + buf[1:] # manipulate the buffer, so checksum won't match.
- self.assertRaises(tarfile.HeaderError, tarfile.TarInfo.frombuf, b)
-
-class OpenFileobjTest(BaseTest):
- # Test for SF bug #1496501.
-
- def test_opener(self):
- fobj = StringIO.StringIO("foo\n")
- try:
- tarfile.open("", "r", fileobj=fobj)
- except tarfile.ReadError:
- self.assertEqual(fobj.tell(), 0, "fileobj's position has moved")
-
-if bz2:
- # Bzip2 TestCases
- class ReadTestBzip2(ReadTestGzip):
- comp = "bz2"
- class ReadStreamTestBzip2(ReadStreamTestGzip):
- comp = "bz2"
- class WriteTestBzip2(WriteTest):
- comp = "bz2"
- class WriteStreamTestBzip2(WriteStreamTestGzip):
- comp = "bz2"
- class ReadDetectTestBzip2(ReadDetectTest):
- comp = "bz2"
- class ReadDetectFileobjTestBzip2(ReadDetectFileobjTest):
- comp = "bz2"
- class ReadAsteriskTestBzip2(ReadAsteriskTest):
- comp = "bz2"
- class ReadStreamAsteriskTestBzip2(ReadStreamAsteriskTest):
- comp = "bz2"
-
-# If importing gzip failed, discard the Gzip TestCases.
-if not gzip:
- del ReadTestGzip
- del ReadStreamTestGzip
- del WriteTestGzip
- del WriteStreamTestGzip
+ tar = tarfile.open(tmpname, "w", format=tarfile.PAX_FORMAT)
+ tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
+ tar.addfile(tarinfo)
+ tar.close()
-def test_main():
- # Create archive.
- f = open(tarname(), "rb")
- fguts = f.read()
- f.close()
- if gzip:
- # create testtar.tar.gz
- tar = gzip.open(tarname("gz"), "wb")
- tar.write(fguts)
+ # Simulate a premature EOF.
+ open(tmpname, "rb+").truncate(1536)
+ tar = tarfile.open(tmpname)
+ self.assertEqual(tar.getmembers(), [])
+
+
+class AppendTest(unittest.TestCase):
+ # Test append mode (cp. patch #1652681).
+
+ def setUp(self):
+ self.tarname = tmpname
+ if os.path.exists(self.tarname):
+ os.remove(self.tarname)
+
+ def _add_testfile(self, fileobj=None):
+ tar = tarfile.open(self.tarname, "a", fileobj=fileobj)
+ tar.addfile(tarfile.TarInfo("bar"))
tar.close()
- if bz2:
- # create testtar.tar.bz2
- tar = bz2.BZ2File(tarname("bz2"), "wb")
- tar.write(fguts)
+
+ def _create_testtar(self, mode="w:"):
+ src = tarfile.open(tarname, encoding="iso8859-1")
+ t = src.getmember("ustar/regtype")
+ t.name = "foo"
+ f = src.extractfile(t)
+ tar = tarfile.open(self.tarname, mode)
+ tar.addfile(t, f)
tar.close()
+ def _test(self, names=["bar"], fileobj=None):
+ tar = tarfile.open(self.tarname, fileobj=fileobj)
+ self.assertEqual(tar.getnames(), names)
+
+ def test_non_existing(self):
+ self._add_testfile()
+ self._test()
+
+ def test_empty(self):
+ open(self.tarname, "w").close()
+ self._add_testfile()
+ self._test()
+
+ def test_empty_fileobj(self):
+ fobj = StringIO.StringIO()
+ self._add_testfile(fobj)
+ fobj.seek(0)
+ self._test(fileobj=fobj)
+
+ def test_fileobj(self):
+ self._create_testtar()
+ data = open(self.tarname).read()
+ fobj = StringIO.StringIO(data)
+ self._add_testfile(fobj)
+ fobj.seek(0)
+ self._test(names=["foo", "bar"], fileobj=fobj)
+
+ def test_existing(self):
+ self._create_testtar()
+ self._add_testfile()
+ self._test(names=["foo", "bar"])
+
+ def test_append_gz(self):
+ if gzip is None:
+ return
+ self._create_testtar("w:gz")
+ self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a")
+
+ def test_append_bz2(self):
+ if bz2 is None:
+ return
+ self._create_testtar("w:bz2")
+ self.assertRaises(tarfile.ReadError, tarfile.open, tmpname, "a")
+
+
+class LimitsTest(unittest.TestCase):
+
+ def test_ustar_limits(self):
+ # 100 char name
+ tarinfo = tarfile.TarInfo("0123456789" * 10)
+ tarinfo.create_ustar_header()
+
+ # 101 char name that cannot be stored
+ tarinfo = tarfile.TarInfo("0123456789" * 10 + "0")
+ self.assertRaises(ValueError, tarinfo.create_ustar_header)
+
+ # 256 char name with a slash at pos 156
+ tarinfo = tarfile.TarInfo("123/" * 62 + "longname")
+ tarinfo.create_ustar_header()
+
+ # 256 char name that cannot be stored
+ tarinfo = tarfile.TarInfo("1234567/" * 31 + "longname")
+ self.assertRaises(ValueError, tarinfo.create_ustar_header)
+
+ # 512 char name
+ tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
+ self.assertRaises(ValueError, tarinfo.create_ustar_header)
+
+ # 512 char linkname
+ tarinfo = tarfile.TarInfo("longlink")
+ tarinfo.linkname = "123/" * 126 + "longname"
+ self.assertRaises(ValueError, tarinfo.create_ustar_header)
+
+ # uid > 8 digits
+ tarinfo = tarfile.TarInfo("name")
+ tarinfo.uid = 010000000
+ self.assertRaises(ValueError, tarinfo.create_ustar_header)
+
+ def test_gnu_limits(self):
+ tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
+ tarinfo.create_gnu_header()
+
+ tarinfo = tarfile.TarInfo("longlink")
+ tarinfo.linkname = "123/" * 126 + "longname"
+ tarinfo.create_gnu_header()
+
+ # uid >= 256 ** 7
+ tarinfo = tarfile.TarInfo("name")
+ tarinfo.uid = 04000000000000000000
+ self.assertRaises(ValueError, tarinfo.create_gnu_header)
+
+ def test_pax_limits(self):
+ # A 256 char name that can be stored without an extended header.
+ tarinfo = tarfile.TarInfo("123/" * 62 + "longname")
+ self.assert_(len(tarinfo.create_pax_header("utf8")) == 512,
+ "create_pax_header attached superfluous extended header")
+
+ tarinfo = tarfile.TarInfo("123/" * 126 + "longname")
+ tarinfo.create_pax_header("utf8")
+
+ tarinfo = tarfile.TarInfo("longlink")
+ tarinfo.linkname = "123/" * 126 + "longname"
+ tarinfo.create_pax_header("utf8")
+
+ tarinfo = tarfile.TarInfo("name")
+ tarinfo.uid = 04000000000000000000
+ tarinfo.create_pax_header("utf8")
+
+
+class GzipMiscReadTest(MiscReadTest):
+ tarname = gzipname
+ mode = "r:gz"
+class GzipUstarReadTest(UstarReadTest):
+ tarname = gzipname
+ mode = "r:gz"
+class GzipStreamReadTest(StreamReadTest):
+ tarname = gzipname
+ mode = "r|gz"
+class GzipWriteTest(WriteTest):
+ mode = "w:gz"
+class GzipStreamWriteTest(StreamWriteTest):
+ mode = "w|gz"
+
+
+class Bz2MiscReadTest(MiscReadTest):
+ tarname = bz2name
+ mode = "r:bz2"
+class Bz2UstarReadTest(UstarReadTest):
+ tarname = bz2name
+ mode = "r:bz2"
+class Bz2StreamReadTest(StreamReadTest):
+ tarname = bz2name
+ mode = "r|bz2"
+class Bz2WriteTest(WriteTest):
+ mode = "w:bz2"
+class Bz2StreamWriteTest(StreamWriteTest):
+ mode = "w|bz2"
+
+def test_main():
+ if not os.path.exists(TEMPDIR):
+ os.mkdir(TEMPDIR)
+
tests = [
- FileModeTest,
- HeaderErrorTest,
- OpenFileobjTest,
- ReadTest,
- ReadStreamTest,
- ReadDetectTest,
- ReadDetectFileobjTest,
- ReadAsteriskTest,
- ReadStreamAsteriskTest,
+ UstarReadTest,
+ MiscReadTest,
+ StreamReadTest,
+ DetectReadTest,
+ MemberReadTest,
+ GNUReadTest,
+ PaxReadTest,
WriteTest,
+ StreamWriteTest,
+ GNUWriteTest,
+ PaxWriteTest,
AppendTest,
- Write100Test,
- WriteSize0Test,
- WriteStreamTest,
- WriteGNULongTest,
- ReadGNULongTest,
+ LimitsTest,
]
if hasattr(os, "link"):
- tests.append(ExtractHardlinkTest)
- tests.append(CreateHardlinkTest)
+ tests.append(HardlinkTest)
+
+ fobj = open(tarname, "rb")
+ data = fobj.read()
+ fobj.close()
if gzip:
- tests.extend([
- ReadTestGzip, ReadStreamTestGzip,
- WriteTestGzip, WriteStreamTestGzip,
- ReadDetectTestGzip, ReadDetectFileobjTestGzip,
- ReadAsteriskTestGzip, ReadStreamAsteriskTestGzip
- ])
+ # Create testtar.tar.gz and add gzip-specific tests.
+ tar = gzip.open(gzipname, "wb")
+ tar.write(data)
+ tar.close()
+
+ tests += [
+ GzipMiscReadTest,
+ GzipUstarReadTest,
+ GzipStreamReadTest,
+ GzipWriteTest,
+ GzipStreamWriteTest,
+ ]
if bz2:
- tests.extend([
- ReadTestBzip2, ReadStreamTestBzip2,
- WriteTestBzip2, WriteStreamTestBzip2,
- ReadDetectTestBzip2, ReadDetectFileobjTestBzip2,
- ReadAsteriskTestBzip2, ReadStreamAsteriskTestBzip2
- ])
+ # Create testtar.tar.bz2 and add bz2-specific tests.
+ tar = bz2.BZ2File(bz2name, "wb")
+ tar.write(data)
+ tar.close()
+
+ tests += [
+ Bz2MiscReadTest,
+ Bz2UstarReadTest,
+ Bz2StreamReadTest,
+ Bz2WriteTest,
+ Bz2StreamWriteTest,
+ ]
+
try:
test_support.run_unittest(*tests)
finally:
- if gzip:
- os.remove(tarname("gz"))
- if bz2:
- os.remove(tarname("bz2"))
- if os.path.exists(dirname()):
- shutil.rmtree(dirname())
- if os.path.exists(tmpname()):
- os.remove(tmpname())
+ if os.path.exists(TEMPDIR):
+ shutil.rmtree(TEMPDIR)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_telnetlib.py b/Lib/test/test_telnetlib.py
new file mode 100644
index 0000000..0a3604e
--- /dev/null
+++ b/Lib/test/test_telnetlib.py
@@ -0,0 +1,74 @@
+import socket
+import threading
+import telnetlib
+import time
+
+from unittest import TestCase
+from test import test_support
+
+
+def server(evt):
+ serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ serv.settimeout(3)
+ serv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ serv.bind(("", 9091))
+ serv.listen(5)
+ try:
+ conn, addr = serv.accept()
+ except socket.timeout:
+ pass
+ finally:
+ serv.close()
+ evt.set()
+
+class GeneralTests(TestCase):
+
+ def setUp(self):
+ self.evt = threading.Event()
+ threading.Thread(target=server, args=(self.evt,)).start()
+ time.sleep(.1)
+
+ def tearDown(self):
+ self.evt.wait()
+
+ def testBasic(self):
+ # connects
+ telnet = telnetlib.Telnet("localhost", 9091)
+ telnet.sock.close()
+
+ def testTimeoutDefault(self):
+ # default
+ telnet = telnetlib.Telnet("localhost", 9091)
+ self.assertTrue(telnet.sock.gettimeout() is None)
+ telnet.sock.close()
+
+ def testTimeoutValue(self):
+ # a value
+ telnet = telnetlib.Telnet("localhost", 9091, timeout=30)
+ self.assertEqual(telnet.sock.gettimeout(), 30)
+ telnet.sock.close()
+
+ def testTimeoutDifferentOrder(self):
+ telnet = telnetlib.Telnet(timeout=30)
+ telnet.open("localhost", 9091)
+ self.assertEqual(telnet.sock.gettimeout(), 30)
+ telnet.sock.close()
+
+ def testTimeoutNone(self):
+ # None, having other default
+ previous = socket.getdefaulttimeout()
+ socket.setdefaulttimeout(30)
+ try:
+ telnet = telnetlib.Telnet("localhost", 9091, timeout=None)
+ finally:
+ socket.setdefaulttimeout(previous)
+ self.assertEqual(telnet.sock.gettimeout(), 30)
+ telnet.sock.close()
+
+
+
+def test_main(verbose=None):
+ test_support.run_unittest(GeneralTests)
+
+if __name__ == '__main__':
+ test_main()
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index a398d37..20f22ed 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -81,7 +81,8 @@ class test_exports(TC):
"gettempprefix" : 1,
"gettempdir" : 1,
"tempdir" : 1,
- "template" : 1
+ "template" : 1,
+ "SpooledTemporaryFile" : 1
}
unexp = []
@@ -561,11 +562,12 @@ test_classes.append(test_mktemp)
class test_NamedTemporaryFile(TC):
"""Test NamedTemporaryFile()."""
- def do_create(self, dir=None, pre="", suf=""):
+ def do_create(self, dir=None, pre="", suf="", delete=True):
if dir is None:
dir = tempfile.gettempdir()
try:
- file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf)
+ file = tempfile.NamedTemporaryFile(dir=dir, prefix=pre, suffix=suf,
+ delete=delete)
except:
self.failOnException("NamedTemporaryFile")
@@ -599,6 +601,22 @@ class test_NamedTemporaryFile(TC):
finally:
os.rmdir(dir)
+ def test_dis_del_on_close(self):
+ # Tests that delete-on-close can be disabled
+ dir = tempfile.mkdtemp()
+ tmp = None
+ try:
+ f = tempfile.NamedTemporaryFile(dir=dir, delete=False)
+ tmp = f.name
+ f.write('blat')
+ f.close()
+ self.failUnless(os.path.exists(f.name),
+ "NamedTemporaryFile %s missing after close" % f.name)
+ finally:
+ if tmp is not None:
+ os.unlink(tmp)
+ os.rmdir(dir)
+
def test_multiple_close(self):
# A NamedTemporaryFile can be closed many times without error
@@ -615,6 +633,107 @@ class test_NamedTemporaryFile(TC):
test_classes.append(test_NamedTemporaryFile)
+class test_SpooledTemporaryFile(TC):
+ """Test SpooledTemporaryFile()."""
+
+ def do_create(self, max_size=0, dir=None, pre="", suf=""):
+ if dir is None:
+ dir = tempfile.gettempdir()
+ try:
+ file = tempfile.SpooledTemporaryFile(max_size=max_size, dir=dir, prefix=pre, suffix=suf)
+ except:
+ self.failOnException("SpooledTemporaryFile")
+
+ return file
+
+
+ def test_basic(self):
+ # SpooledTemporaryFile can create files
+ f = self.do_create()
+ self.failIf(f._rolled)
+ f = self.do_create(max_size=100, pre="a", suf=".txt")
+ self.failIf(f._rolled)
+
+ def test_del_on_close(self):
+ # A SpooledTemporaryFile is deleted when closed
+ dir = tempfile.mkdtemp()
+ try:
+ f = tempfile.SpooledTemporaryFile(max_size=10, dir=dir)
+ self.failIf(f._rolled)
+ f.write('blat ' * 5)
+ self.failUnless(f._rolled)
+ filename = f.name
+ f.close()
+ self.failIf(os.path.exists(filename),
+ "SpooledTemporaryFile %s exists after close" % filename)
+ finally:
+ os.rmdir(dir)
+
+ def test_rewrite_small(self):
+ # A SpooledTemporaryFile can be written to multiple within the max_size
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ for i in range(5):
+ f.seek(0, 0)
+ f.write('x' * 20)
+ self.failIf(f._rolled)
+
+ def test_write_sequential(self):
+ # A SpooledTemporaryFile should hold exactly max_size bytes, and roll
+ # over afterward
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ f.write('x' * 20)
+ self.failIf(f._rolled)
+ f.write('x' * 10)
+ self.failIf(f._rolled)
+ f.write('x')
+ self.failUnless(f._rolled)
+
+ def test_sparse(self):
+ # A SpooledTemporaryFile that is written late in the file will extend
+ # when that occurs
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ f.seek(100, 0)
+ self.failIf(f._rolled)
+ f.write('x')
+ self.failUnless(f._rolled)
+
+ def test_fileno(self):
+ # A SpooledTemporaryFile should roll over to a real file on fileno()
+ f = self.do_create(max_size=30)
+ self.failIf(f._rolled)
+ self.failUnless(f.fileno() > 0)
+ self.failUnless(f._rolled)
+
+ def test_multiple_close(self):
+ # A SpooledTemporaryFile can be closed many times without error
+ f = tempfile.SpooledTemporaryFile()
+ f.write('abc\n')
+ f.close()
+ try:
+ f.close()
+ f.close()
+ except:
+ self.failOnException("close")
+
+ def test_bound_methods(self):
+ # It should be OK to steal a bound method from a SpooledTemporaryFile
+ # and use it independently; when the file rolls over, those bound
+ # methods should continue to function
+ f = self.do_create(max_size=30)
+ read = f.read
+ write = f.write
+ seek = f.seek
+
+ write("a" * 35)
+ write("b" * 35)
+ seek(0, 0)
+ self.failUnless(read(70) == 'a'*35 + 'b'*35)
+
+test_classes.append(test_SpooledTemporaryFile)
+
class test_TemporaryFile(TC):
"""Test TemporaryFile()."""
diff --git a/Lib/test/test_textwrap.py b/Lib/test/test_textwrap.py
index 500eceb..5f0b51b 100644
--- a/Lib/test/test_textwrap.py
+++ b/Lib/test/test_textwrap.py
@@ -328,6 +328,14 @@ What a mess!
self.check_wrap(text, 30,
[" This is a sentence with", "leading whitespace."])
+ def test_no_drop_whitespace(self):
+ # SF patch #1581073
+ text = " This is a sentence with much whitespace."
+ self.check_wrap(text, 10,
+ [" This is a", " ", "sentence ",
+ "with ", "much white", "space."],
+ drop_whitespace=False)
+
if test_support.have_unicode:
def test_unicode(self):
# *Very* simple test of wrapping Unicode strings. I'm sure
diff --git a/Lib/test/test_threadedtempfile.py b/Lib/test/test_threadedtempfile.py
index 75d719d..753f388 100644
--- a/Lib/test/test_threadedtempfile.py
+++ b/Lib/test/test_threadedtempfile.py
@@ -10,22 +10,20 @@ failures. A failure is a bug in tempfile, and may be due to:
By default, NUM_THREADS == 20 and FILES_PER_THREAD == 50. This is enough to
create about 150 failures per run under Win98SE in 2.0, and runs pretty
quickly. Guido reports needing to boost FILES_PER_THREAD to 500 before
-provoking a 2.0 failure under Linux. Run the test alone to boost either
-via cmdline switches:
-
--f FILES_PER_THREAD (int)
--t NUM_THREADS (int)
+provoking a 2.0 failure under Linux.
"""
-NUM_THREADS = 20 # change w/ -t option
-FILES_PER_THREAD = 50 # change w/ -f option
+NUM_THREADS = 20
+FILES_PER_THREAD = 50
import thread # If this fails, we can't test this module
import threading
-from test.test_support import TestFailed, threading_setup, threading_cleanup
+import tempfile
+
+from test.test_support import threading_setup, threading_cleanup, run_unittest
+import unittest
import StringIO
from traceback import print_exc
-import tempfile
startEvent = threading.Event()
@@ -46,41 +44,36 @@ class TempFileGreedy(threading.Thread):
else:
self.ok_count += 1
+
+class ThreadedTempFileTest(unittest.TestCase):
+ def test_main(self):
+ threads = []
+ thread_info = threading_setup()
+
+ for i in range(NUM_THREADS):
+ t = TempFileGreedy()
+ threads.append(t)
+ t.start()
+
+ startEvent.set()
+
+ ok = 0
+ errors = []
+ for t in threads:
+ t.join()
+ ok += t.ok_count
+ if t.error_count:
+ errors.append(str(t.getName()) + str(t.errors.getvalue()))
+
+ threading_cleanup(*thread_info)
+
+ msg = "Errors: errors %d ok %d\n%s" % (len(errors), ok,
+ '\n'.join(errors))
+ self.assertEquals(errors, [], msg)
+ self.assertEquals(ok, NUM_THREADS * FILES_PER_THREAD)
+
def test_main():
- threads = []
- thread_info = threading_setup()
-
- print("Creating")
- for i in range(NUM_THREADS):
- t = TempFileGreedy()
- threads.append(t)
- t.start()
-
- print("Starting")
- startEvent.set()
-
- print("Reaping")
- ok = errors = 0
- for t in threads:
- t.join()
- ok += t.ok_count
- errors += t.error_count
- if t.error_count:
- print('%s errors:\n%s' % (t.getName(), t.errors.getvalue()))
-
- msg = "Done: errors %d ok %d" % (errors, ok)
- print(msg)
- if errors:
- raise TestFailed(msg)
-
- threading_cleanup(*thread_info)
+ run_unittest(ThreadedTempFileTest)
if __name__ == "__main__":
- import sys, getopt
- opts, args = getopt.getopt(sys.argv[1:], "t:f:")
- for o, v in opts:
- if o == "-f":
- FILES_PER_THREAD = int(v)
- elif o == "-t":
- NUM_THREADS = int(v)
test_main()
diff --git a/Lib/test/test_threading_local.py b/Lib/test/test_threading_local.py
index 56fbedd..0aaedbc 100644
--- a/Lib/test/test_threading_local.py
+++ b/Lib/test/test_threading_local.py
@@ -20,7 +20,7 @@ def test_main():
setUp=setUp, tearDown=tearDown)
)
- test_support.run_suite(suite)
+ test_support.run_unittest(suite)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py
index cb7586c..a704cc9 100644
--- a/Lib/test/test_unicode.py
+++ b/Lib/test/test_unicode.py
@@ -829,7 +829,7 @@ class UnicodeTest(
def test_main():
- test_support.run_unittest(UnicodeTest)
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_unicode_file.py b/Lib/test/test_unicode_file.py
index 0058d98..328b5b6 100644
--- a/Lib/test/test_unicode_file.py
+++ b/Lib/test/test_unicode_file.py
@@ -5,7 +5,7 @@ import os, glob, time, shutil
import unicodedata
import unittest
-from test.test_support import run_suite, TestSkipped, TESTFN_UNICODE
+from test.test_support import run_unittest, TestSkipped, TESTFN_UNICODE
from test.test_support import TESTFN_ENCODING, TESTFN_UNICODE_UNENCODEABLE
try:
TESTFN_ENCODED = TESTFN_UNICODE.encode(TESTFN_ENCODING)
@@ -205,9 +205,7 @@ class TestUnicodeFiles(unittest.TestCase):
False)
def test_main():
- suite = unittest.TestSuite()
- suite.addTest(unittest.makeSuite(TestUnicodeFiles))
- run_suite(suite)
+ run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py
index 9151166..70f12d2 100644
--- a/Lib/test/test_unittest.py
+++ b/Lib/test/test_unittest.py
@@ -1,31 +1,2302 @@
"""Test script for unittest.
-This just includes tests for new features. We really need a
-full set of tests.
+By Collin Winter <collinw at gmail.com>
+
+Still need testing:
+ TestCase.{assert,fail}* methods (some are tested implicitly)
"""
+from test import test_support
import unittest
+from unittest import TestCase
+
+### Support code
+################################################################
+
+class LoggingResult(unittest.TestResult):
+ def __init__(self, log):
+ self._events = log
+ super(LoggingResult, self).__init__()
+
+ def startTest(self, test):
+ self._events.append('startTest')
+ super(LoggingResult, self).startTest(test)
+
+ def stopTest(self, test):
+ self._events.append('stopTest')
+ super(LoggingResult, self).stopTest(test)
+
+ def addFailure(self, *args):
+ self._events.append('addFailure')
+ super(LoggingResult, self).addFailure(*args)
+
+ def addError(self, *args):
+ self._events.append('addError')
+ super(LoggingResult, self).addError(*args)
+
+class TestEquality(object):
+ # Check for a valid __eq__ implementation
+ def test_eq(self):
+ for obj_1, obj_2 in self.eq_pairs:
+ self.assertEqual(obj_1, obj_2)
+ self.assertEqual(obj_2, obj_1)
+
+ # Check for a valid __ne__ implementation
+ def test_ne(self):
+ for obj_1, obj_2 in self.ne_pairs:
+ self.failIfEqual(obj_1, obj_2)
+ self.failIfEqual(obj_2, obj_1)
+
+class TestHashing(object):
+ # Check for a valid __hash__ implementation
+ def test_hash(self):
+ for obj_1, obj_2 in self.eq_pairs:
+ try:
+ assert hash(obj_1) == hash(obj_2)
+ except KeyboardInterrupt:
+ raise
+ except AssertionError:
+ self.fail("%s and %s do not hash equal" % (obj_1, obj_2))
+ except Exception as e:
+ self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e))
+
+ for obj_1, obj_2 in self.ne_pairs:
+ try:
+ assert hash(obj_1) != hash(obj_2)
+ except KeyboardInterrupt:
+ raise
+ except AssertionError:
+ self.fail("%s and %s hash equal, but shouldn't" % (obj_1, obj_2))
+ except Exception as e:
+ self.fail("Problem hashing %s and %s: %s" % (obj_1, obj_2, e))
+
+
+################################################################
+### /Support code
+
+class Test_TestLoader(TestCase):
+
+ ### Tests for TestLoader.loadTestsFromTestCase
+ ################################################################
+
+ # "Return a suite of all tests cases contained in the TestCase-derived
+ # class testCaseClass"
+ def test_loadTestsFromTestCase(self):
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+
+ tests = unittest.TestSuite([Foo('test_1'), Foo('test_2')])
+
+ loader = unittest.TestLoader()
+ self.assertEqual(loader.loadTestsFromTestCase(Foo), tests)
+
+ # "Return a suite of all tests cases contained in the TestCase-derived
+ # class testCaseClass"
+ #
+ # Make sure it does the right thing even if no tests were found
+ def test_loadTestsFromTestCase__no_matches(self):
+ class Foo(unittest.TestCase):
+ def foo_bar(self): pass
+
+ empty_suite = unittest.TestSuite()
+
+ loader = unittest.TestLoader()
+ self.assertEqual(loader.loadTestsFromTestCase(Foo), empty_suite)
+
+ # "Return a suite of all tests cases contained in the TestCase-derived
+ # class testCaseClass"
+ #
+ # What happens if loadTestsFromTestCase() is given an object
+ # that isn't a subclass of TestCase? Specifically, what happens
+ # if testCaseClass is a subclass of TestSuite?
+ #
+ # This is checked for specifically in the code, so we better add a
+ # test for it.
+ def test_loadTestsFromTestCase__TestSuite_subclass(self):
+ class NotATestCase(unittest.TestSuite):
+ pass
+
+ loader = unittest.TestLoader()
+ try:
+ loader.loadTestsFromTestCase(NotATestCase)
+ except TypeError:
+ pass
+ else:
+ self.fail('Should raise TypeError')
+
+ # "Return a suite of all tests cases contained in the TestCase-derived
+ # class testCaseClass"
+ #
+ # Make sure loadTestsFromTestCase() picks up the default test method
+ # name (as specified by TestCase), even though the method name does
+ # not match the default TestLoader.testMethodPrefix string
+ def test_loadTestsFromTestCase__default_method_name(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ pass
+
+ loader = unittest.TestLoader()
+ # This has to be false for the test to succeed
+ self.failIf('runTest'.startswith(loader.testMethodPrefix))
+
+ suite = loader.loadTestsFromTestCase(Foo)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [Foo('runTest')])
+
+ ################################################################
+ ### /Tests for TestLoader.loadTestsFromTestCase
+
+ ### Tests for TestLoader.loadTestsFromModule
+ ################################################################
+
+ # "This method searches `module` for classes derived from TestCase"
+ def test_loadTestsFromModule__TestCase_subclass(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromModule(m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ expected = [loader.suiteClass([MyTestCase('test')])]
+ self.assertEqual(list(suite), expected)
+
+ # "This method searches `module` for classes derived from TestCase"
+ #
+ # What happens if no tests are found (no TestCase instances)?
+ def test_loadTestsFromModule__no_TestCase_instances(self):
+ import new
+ m = new.module('m')
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromModule(m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [])
+
+ # "This method searches `module` for classes derived from TestCase"
+ #
+ # What happens if no tests are found (TestCases instances, but no tests)?
+ def test_loadTestsFromModule__no_TestCase_tests(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromModule(m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ self.assertEqual(list(suite), [loader.suiteClass()])
+
+ # "This method searches `module` for classes derived from TestCase"s
+ #
+ # What happens if loadTestsFromModule() is given something other
+ # than a module?
+ #
+ # XXX Currently, it succeeds anyway. This flexibility
+ # should either be documented or loadTestsFromModule() should
+ # raise a TypeError
+ #
+ # XXX Certain people are using this behaviour. We'll add a test for it
+ def test_loadTestsFromModule__not_a_module(self):
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+
+ class NotAModule(object):
+ test_2 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromModule(NotAModule)
+
+ reference = [unittest.TestSuite([MyTestCase('test')])]
+ self.assertEqual(list(suite), reference)
+
+ ################################################################
+ ### /Tests for TestLoader.loadTestsFromModule()
+
+ ### Tests for TestLoader.loadTestsFromName()
+ ################################################################
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # Is ValueError raised in response to an empty name?
+ def test_loadTestsFromName__empty_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromName('')
+ except ValueError as e:
+ self.assertEqual(str(e), "Empty module name")
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise ValueError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # What happens when the name contains invalid characters?
+ def test_loadTestsFromName__malformed_name(self):
+ loader = unittest.TestLoader()
+
+ # XXX Should this raise ValueError or ImportError?
+ try:
+ loader.loadTestsFromName('abc () //')
+ except ValueError:
+ pass
+ except ImportError:
+ pass
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise ValueError")
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to a
+ # module"
+ #
+ # What happens when a module by that name can't be found?
+ def test_loadTestsFromName__unknown_module_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromName('sdasfasfasdf')
+ except ImportError as e:
+ self.assertEqual(str(e), "No module named sdasfasfasdf")
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise ImportError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # What happens when the module is found, but the attribute can't?
+ def test_loadTestsFromName__unknown_attr_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromName('unittest.sdasfasfasdf')
+ except AttributeError as e:
+ self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'")
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # What happens when we provide the module, but the attribute can't be
+ # found?
+ def test_loadTestsFromName__relative_unknown_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromName('sdasfasfasdf', unittest)
+ except AttributeError as e:
+ self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'")
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # Does loadTestsFromName raise ValueError when passed an empty
+ # name relative to a provided module?
+ #
+ # XXX Should probably raise a ValueError instead of an AttributeError
+ def test_loadTestsFromName__relative_empty_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromName('', unittest)
+ except AttributeError as e:
+ pass
+ else:
+ self.fail("Failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # What happens when an impossible name is given, relative to the provided
+ # `module`?
+ def test_loadTestsFromName__relative_malformed_name(self):
+ loader = unittest.TestLoader()
+
+ # XXX Should this raise AttributeError or ValueError?
+ try:
+ loader.loadTestsFromName('abc () //', unittest)
+ except ValueError:
+ pass
+ except AttributeError:
+ pass
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise ValueError")
+
+ # "The method optionally resolves name relative to the given module"
+ #
+ # Does loadTestsFromName raise TypeError when the `module` argument
+ # isn't a module object?
+ #
+ # XXX Accepts the not-a-module object, ignorning the object's type
+ # This should raise an exception or the method name should be changed
+ #
+ # XXX Some people are relying on this, so keep it for now
+ def test_loadTestsFromName__relative_not_a_module(self):
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+
+ class NotAModule(object):
+ test_2 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromName('test_2', NotAModule)
+
+ reference = [MyTestCase('test')]
+ self.assertEqual(list(suite), reference)
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # Does it raise an exception if the name resolves to an invalid
+ # object?
+ def test_loadTestsFromName__relative_bad_object(self):
+ import new
+ m = new.module('m')
+ m.testcase_1 = object()
+
+ loader = unittest.TestLoader()
+ try:
+ loader.loadTestsFromName('testcase_1', m)
+ except TypeError:
+ pass
+ else:
+ self.fail("Should have raised TypeError")
+
+ # "The specifier name is a ``dotted name'' that may
+ # resolve either to ... a test case class"
+ def test_loadTestsFromName__relative_TestCase_subclass(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromName('testcase_1', m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [MyTestCase('test')])
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ def test_loadTestsFromName__relative_TestSuite(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testsuite = unittest.TestSuite([MyTestCase('test')])
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromName('testsuite', m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ self.assertEqual(list(suite), [MyTestCase('test')])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a test method within a test case class"
+ def test_loadTestsFromName__relative_testmethod(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromName('testcase_1.test', m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ self.assertEqual(list(suite), [MyTestCase('test')])
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # Does loadTestsFromName() raise the proper exception when trying to
+ # resolve "a test method within a test case class" that doesn't exist
+ # for the given name (relative to a provided module)?
+ def test_loadTestsFromName__relative_invalid_testmethod(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ try:
+ loader.loadTestsFromName('testcase_1.testfoo', m)
+ except AttributeError as e:
+ self.assertEqual(str(e), "type object 'MyTestCase' has no attribute 'testfoo'")
+ else:
+ self.fail("Failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a ... TestSuite instance"
+ def test_loadTestsFromName__callable__TestSuite(self):
+ import new
+ m = new.module('m')
+ testcase_1 = unittest.FunctionTestCase(lambda: None)
+ testcase_2 = unittest.FunctionTestCase(lambda: None)
+ def return_TestSuite():
+ return unittest.TestSuite([testcase_1, testcase_2])
+ m.return_TestSuite = return_TestSuite
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromName('return_TestSuite', m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [testcase_1, testcase_2])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a TestCase ... instance"
+ def test_loadTestsFromName__callable__TestCase_instance(self):
+ import new
+ m = new.module('m')
+ testcase_1 = unittest.FunctionTestCase(lambda: None)
+ def return_TestCase():
+ return testcase_1
+ m.return_TestCase = return_TestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromName('return_TestCase', m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [testcase_1])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a TestCase or TestSuite instance"
+ #
+ # What happens if the callable returns something else?
+ def test_loadTestsFromName__callable__wrong_type(self):
+ import new
+ m = new.module('m')
+ def return_wrong():
+ return 6
+ m.return_wrong = return_wrong
+
+ loader = unittest.TestLoader()
+ try:
+ suite = loader.loadTestsFromName('return_wrong', m)
+ except TypeError:
+ pass
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise TypeError")
+
+ # "The specifier can refer to modules and packages which have not been
+ # imported; they will be imported as a side-effect"
+ def test_loadTestsFromName__module_not_loaded(self):
+ # We're going to try to load this module as a side-effect, so it
+ # better not be loaded before we try.
+ #
+ # Why pick audioop? Google shows it isn't used very often, so there's
+ # a good chance that it won't be imported when this test is run
+ module_name = 'audioop'
+
+ import sys
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+ loader = unittest.TestLoader()
+ try:
+ suite = loader.loadTestsFromName(module_name)
+
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [])
+
+ # audioop should now be loaded, thanks to loadTestsFromName()
+ self.failUnless(module_name in sys.modules)
+ finally:
+ del sys.modules[module_name]
+
+ ################################################################
+ ### Tests for TestLoader.loadTestsFromName()
+
+ ### Tests for TestLoader.loadTestsFromNames()
+ ################################################################
+
+ # "Similar to loadTestsFromName(), but takes a sequence of names rather
+ # than a single name."
+ #
+ # What happens if that sequence of names is empty?
+ def test_loadTestsFromNames__empty_name_list(self):
+ loader = unittest.TestLoader()
+
+ suite = loader.loadTestsFromNames([])
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [])
+
+ # "Similar to loadTestsFromName(), but takes a sequence of names rather
+ # than a single name."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # What happens if that sequence of names is empty?
+ #
+ # XXX Should this raise a ValueError or just return an empty TestSuite?
+ def test_loadTestsFromNames__relative_empty_name_list(self):
+ loader = unittest.TestLoader()
+
+ suite = loader.loadTestsFromNames([], unittest)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [])
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # Is ValueError raised in response to an empty name?
+ def test_loadTestsFromNames__empty_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromNames([''])
+ except ValueError as e:
+ self.assertEqual(str(e), "Empty module name")
+ else:
+ self.fail("TestLoader.loadTestsFromNames failed to raise ValueError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # What happens when presented with an impossible module name?
+ def test_loadTestsFromNames__malformed_name(self):
+ loader = unittest.TestLoader()
+
+ # XXX Should this raise ValueError or ImportError?
+ try:
+ loader.loadTestsFromNames(['abc () //'])
+ except ValueError:
+ pass
+ except ImportError:
+ pass
+ else:
+ self.fail("TestLoader.loadTestsFromNames failed to raise ValueError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # What happens when no module can be found for the given name?
+ def test_loadTestsFromNames__unknown_module_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromNames(['sdasfasfasdf'])
+ except ImportError as e:
+ self.assertEqual(str(e), "No module named sdasfasfasdf")
+ else:
+ self.fail("TestLoader.loadTestsFromNames failed to raise ImportError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # What happens when the module can be found, but not the attribute?
+ def test_loadTestsFromNames__unknown_attr_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromNames(['unittest.sdasfasfasdf', 'unittest'])
+ except AttributeError as e:
+ self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'")
+ else:
+ self.fail("TestLoader.loadTestsFromNames failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # What happens when given an unknown attribute on a specified `module`
+ # argument?
+ def test_loadTestsFromNames__unknown_name_relative_1(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromNames(['sdasfasfasdf'], unittest)
+ except AttributeError as e:
+ self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'")
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # Do unknown attributes (relative to a provided module) still raise an
+ # exception even in the presence of valid attribute names?
+ def test_loadTestsFromNames__unknown_name_relative_2(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromNames(['TestCase', 'sdasfasfasdf'], unittest)
+ except AttributeError as e:
+ self.assertEqual(str(e), "'module' object has no attribute 'sdasfasfasdf'")
+ else:
+ self.fail("TestLoader.loadTestsFromName failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # What happens when faced with the empty string?
+ #
+ # XXX This currently raises AttributeError, though ValueError is probably
+ # more appropriate
+ def test_loadTestsFromNames__relative_empty_name(self):
+ loader = unittest.TestLoader()
+
+ try:
+ loader.loadTestsFromNames([''], unittest)
+ except AttributeError:
+ pass
+ else:
+ self.fail("Failed to raise ValueError")
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ # ...
+ # "The method optionally resolves name relative to the given module"
+ #
+ # What happens when presented with an impossible attribute name?
+ def test_loadTestsFromNames__relative_malformed_name(self):
+ loader = unittest.TestLoader()
+
+ # XXX Should this raise AttributeError or ValueError?
+ try:
+ loader.loadTestsFromNames(['abc () //'], unittest)
+ except AttributeError:
+ pass
+ except ValueError:
+ pass
+ else:
+ self.fail("TestLoader.loadTestsFromNames failed to raise ValueError")
+
+ # "The method optionally resolves name relative to the given module"
+ #
+ # Does loadTestsFromNames() make sure the provided `module` is in fact
+ # a module?
+ #
+ # XXX This validation is currently not done. This flexibility should
+ # either be documented or a TypeError should be raised.
+ def test_loadTestsFromNames__relative_not_a_module(self):
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+
+ class NotAModule(object):
+ test_2 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['test_2'], NotAModule)
+
+ reference = [unittest.TestSuite([MyTestCase('test')])]
+ self.assertEqual(list(suite), reference)
+
+ # "The specifier name is a ``dotted name'' that may resolve either to
+ # a module, a test case class, a TestSuite instance, a test method
+ # within a test case class, or a callable object which returns a
+ # TestCase or TestSuite instance."
+ #
+ # Does it raise an exception if the name resolves to an invalid
+ # object?
+ def test_loadTestsFromNames__relative_bad_object(self):
+ import new
+ m = new.module('m')
+ m.testcase_1 = object()
+
+ loader = unittest.TestLoader()
+ try:
+ loader.loadTestsFromNames(['testcase_1'], m)
+ except TypeError:
+ pass
+ else:
+ self.fail("Should have raised TypeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a test case class"
+ def test_loadTestsFromNames__relative_TestCase_subclass(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['testcase_1'], m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ expected = loader.suiteClass([MyTestCase('test')])
+ self.assertEqual(list(suite), [expected])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a TestSuite instance"
+ def test_loadTestsFromNames__relative_TestSuite(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testsuite = unittest.TestSuite([MyTestCase('test')])
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['testsuite'], m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ self.assertEqual(list(suite), [m.testsuite])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to ... a
+ # test method within a test case class"
+ def test_loadTestsFromNames__relative_testmethod(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['testcase_1.test'], m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ ref_suite = unittest.TestSuite([MyTestCase('test')])
+ self.assertEqual(list(suite), [ref_suite])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to ... a
+ # test method within a test case class"
+ #
+ # Does the method gracefully handle names that initially look like they
+ # resolve to "a test method within a test case class" but don't?
+ def test_loadTestsFromNames__relative_invalid_testmethod(self):
+ import new
+ m = new.module('m')
+ class MyTestCase(unittest.TestCase):
+ def test(self):
+ pass
+ m.testcase_1 = MyTestCase
+
+ loader = unittest.TestLoader()
+ try:
+ loader.loadTestsFromNames(['testcase_1.testfoo'], m)
+ except AttributeError as e:
+ self.assertEqual(str(e), "type object 'MyTestCase' has no attribute 'testfoo'")
+ else:
+ self.fail("Failed to raise AttributeError")
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a ... TestSuite instance"
+ def test_loadTestsFromNames__callable__TestSuite(self):
+ import new
+ m = new.module('m')
+ testcase_1 = unittest.FunctionTestCase(lambda: None)
+ testcase_2 = unittest.FunctionTestCase(lambda: None)
+ def return_TestSuite():
+ return unittest.TestSuite([testcase_1, testcase_2])
+ m.return_TestSuite = return_TestSuite
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['return_TestSuite'], m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ expected = unittest.TestSuite([testcase_1, testcase_2])
+ self.assertEqual(list(suite), [expected])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a TestCase ... instance"
+ def test_loadTestsFromNames__callable__TestCase_instance(self):
+ import new
+ m = new.module('m')
+ testcase_1 = unittest.FunctionTestCase(lambda: None)
+ def return_TestCase():
+ return testcase_1
+ m.return_TestCase = return_TestCase
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['return_TestCase'], m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ ref_suite = unittest.TestSuite([testcase_1])
+ self.assertEqual(list(suite), [ref_suite])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a TestCase or TestSuite instance"
+ #
+ # Are staticmethods handled correctly?
+ def test_loadTestsFromNames__callable__call_staticmethod(self):
+ import new
+ m = new.module('m')
+ class Test1(unittest.TestCase):
+ def test(self):
+ pass
+
+ testcase_1 = Test1('test')
+ class Foo(unittest.TestCase):
+ @staticmethod
+ def foo():
+ return testcase_1
+ m.Foo = Foo
+
+ loader = unittest.TestLoader()
+ suite = loader.loadTestsFromNames(['Foo.foo'], m)
+ self.failUnless(isinstance(suite, loader.suiteClass))
+
+ ref_suite = unittest.TestSuite([testcase_1])
+ self.assertEqual(list(suite), [ref_suite])
+
+ # "The specifier name is a ``dotted name'' that may resolve ... to
+ # ... a callable object which returns a TestCase or TestSuite instance"
+ #
+ # What happens when the callable returns something else?
+ def test_loadTestsFromNames__callable__wrong_type(self):
+ import new
+ m = new.module('m')
+ def return_wrong():
+ return 6
+ m.return_wrong = return_wrong
+
+ loader = unittest.TestLoader()
+ try:
+ suite = loader.loadTestsFromNames(['return_wrong'], m)
+ except TypeError:
+ pass
+ else:
+ self.fail("TestLoader.loadTestsFromNames failed to raise TypeError")
+
+ # "The specifier can refer to modules and packages which have not been
+ # imported; they will be imported as a side-effect"
+ def test_loadTestsFromNames__module_not_loaded(self):
+ # We're going to try to load this module as a side-effect, so it
+ # better not be loaded before we try.
+ #
+ # Why pick audioop? Google shows it isn't used very often, so there's
+ # a good chance that it won't be imported when this test is run
+ module_name = 'audioop'
+
+ import sys
+ if module_name in sys.modules:
+ del sys.modules[module_name]
+
+ loader = unittest.TestLoader()
+ try:
+ suite = loader.loadTestsFromNames([module_name])
+
+ self.failUnless(isinstance(suite, loader.suiteClass))
+ self.assertEqual(list(suite), [unittest.TestSuite()])
+
+ # audioop should now be loaded, thanks to loadTestsFromName()
+ self.failUnless(module_name in sys.modules)
+ finally:
+ del sys.modules[module_name]
+
+ ################################################################
+ ### /Tests for TestLoader.loadTestsFromNames()
+
+ ### Tests for TestLoader.getTestCaseNames()
+ ################################################################
+
+ # "Return a sorted sequence of method names found within testCaseClass"
+ #
+ # Test.foobar is defined to make sure getTestCaseNames() respects
+ # loader.testMethodPrefix
+ def test_getTestCaseNames(self):
+ class Test(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foobar(self): pass
+
+ loader = unittest.TestLoader()
+
+ self.assertEqual(loader.getTestCaseNames(Test), ['test_1', 'test_2'])
+
+ # "Return a sorted sequence of method names found within testCaseClass"
+ #
+ # Does getTestCaseNames() behave appropriately if no tests are found?
+ def test_getTestCaseNames__no_tests(self):
+ class Test(unittest.TestCase):
+ def foobar(self): pass
+
+ loader = unittest.TestLoader()
+
+ self.assertEqual(loader.getTestCaseNames(Test), [])
+
+ # "Return a sorted sequence of method names found within testCaseClass"
+ #
+ # Are not-TestCases handled gracefully?
+ #
+ # XXX This should raise a TypeError, not return a list
+ #
+ # XXX It's too late in the 2.5 release cycle to fix this, but it should
+ # probably be revisited for 2.6
+ def test_getTestCaseNames__not_a_TestCase(self):
+ class BadCase(int):
+ def test_foo(self):
+ pass
+
+ loader = unittest.TestLoader()
+ names = loader.getTestCaseNames(BadCase)
+
+ self.assertEqual(names, ['test_foo'])
+
+ # "Return a sorted sequence of method names found within testCaseClass"
+ #
+ # Make sure inherited names are handled.
+ #
+ # TestP.foobar is defined to make sure getTestCaseNames() respects
+ # loader.testMethodPrefix
+ def test_getTestCaseNames__inheritance(self):
+ class TestP(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foobar(self): pass
+
+ class TestC(TestP):
+ def test_1(self): pass
+ def test_3(self): pass
+
+ loader = unittest.TestLoader()
+
+ names = ['test_1', 'test_2', 'test_3']
+ self.assertEqual(loader.getTestCaseNames(TestC), names)
+
+ ################################################################
+ ### /Tests for TestLoader.getTestCaseNames()
+
+ ### Tests for TestLoader.testMethodPrefix
+ ################################################################
+
+ # "String giving the prefix of method names which will be interpreted as
+ # test methods"
+ #
+ # Implicit in the documentation is that testMethodPrefix is respected by
+ # all loadTestsFrom* methods.
+ def test_testMethodPrefix__loadTestsFromTestCase(self):
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+
+ tests_1 = unittest.TestSuite([Foo('foo_bar')])
+ tests_2 = unittest.TestSuite([Foo('test_1'), Foo('test_2')])
+
+ loader = unittest.TestLoader()
+ loader.testMethodPrefix = 'foo'
+ self.assertEqual(loader.loadTestsFromTestCase(Foo), tests_1)
+
+ loader.testMethodPrefix = 'test'
+ self.assertEqual(loader.loadTestsFromTestCase(Foo), tests_2)
+
+ # "String giving the prefix of method names which will be interpreted as
+ # test methods"
+ #
+ # Implicit in the documentation is that testMethodPrefix is respected by
+ # all loadTestsFrom* methods.
+ def test_testMethodPrefix__loadTestsFromModule(self):
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+ m.Foo = Foo
+
+ tests_1 = [unittest.TestSuite([Foo('foo_bar')])]
+ tests_2 = [unittest.TestSuite([Foo('test_1'), Foo('test_2')])]
+
+ loader = unittest.TestLoader()
+ loader.testMethodPrefix = 'foo'
+ self.assertEqual(list(loader.loadTestsFromModule(m)), tests_1)
+
+ loader.testMethodPrefix = 'test'
+ self.assertEqual(list(loader.loadTestsFromModule(m)), tests_2)
+
+ # "String giving the prefix of method names which will be interpreted as
+ # test methods"
+ #
+ # Implicit in the documentation is that testMethodPrefix is respected by
+ # all loadTestsFrom* methods.
+ def test_testMethodPrefix__loadTestsFromName(self):
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+ m.Foo = Foo
+
+ tests_1 = unittest.TestSuite([Foo('foo_bar')])
+ tests_2 = unittest.TestSuite([Foo('test_1'), Foo('test_2')])
+
+ loader = unittest.TestLoader()
+ loader.testMethodPrefix = 'foo'
+ self.assertEqual(loader.loadTestsFromName('Foo', m), tests_1)
+
+ loader.testMethodPrefix = 'test'
+ self.assertEqual(loader.loadTestsFromName('Foo', m), tests_2)
+
+ # "String giving the prefix of method names which will be interpreted as
+ # test methods"
+ #
+ # Implicit in the documentation is that testMethodPrefix is respected by
+ # all loadTestsFrom* methods.
+ def test_testMethodPrefix__loadTestsFromNames(self):
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+ m.Foo = Foo
+
+ tests_1 = unittest.TestSuite([unittest.TestSuite([Foo('foo_bar')])])
+ tests_2 = unittest.TestSuite([Foo('test_1'), Foo('test_2')])
+ tests_2 = unittest.TestSuite([tests_2])
+
+ loader = unittest.TestLoader()
+ loader.testMethodPrefix = 'foo'
+ self.assertEqual(loader.loadTestsFromNames(['Foo'], m), tests_1)
+
+ loader.testMethodPrefix = 'test'
+ self.assertEqual(loader.loadTestsFromNames(['Foo'], m), tests_2)
+
+ # "The default value is 'test'"
+ def test_testMethodPrefix__default_value(self):
+ loader = unittest.TestLoader()
+ self.failUnless(loader.testMethodPrefix == 'test')
+
+ ################################################################
+ ### /Tests for TestLoader.testMethodPrefix
+
+ ### Tests for TestLoader.sortTestMethodsUsing
+ ################################################################
+
+ # "Function to be used to compare method names when sorting them in
+ # getTestCaseNames() and all the loadTestsFromX() methods"
+ def test_sortTestMethodsUsing__loadTestsFromTestCase(self):
+ def reversed_cmp(x, y):
+ return -cmp(x, y)
+
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+
+ loader = unittest.TestLoader()
+ loader.sortTestMethodsUsing = reversed_cmp
+
+ tests = loader.suiteClass([Foo('test_2'), Foo('test_1')])
+ self.assertEqual(loader.loadTestsFromTestCase(Foo), tests)
+
+ # "Function to be used to compare method names when sorting them in
+ # getTestCaseNames() and all the loadTestsFromX() methods"
+ def test_sortTestMethodsUsing__loadTestsFromModule(self):
+ def reversed_cmp(x, y):
+ return -cmp(x, y)
+
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ m.Foo = Foo
+
+ loader = unittest.TestLoader()
+ loader.sortTestMethodsUsing = reversed_cmp
+
+ tests = [loader.suiteClass([Foo('test_2'), Foo('test_1')])]
+ self.assertEqual(list(loader.loadTestsFromModule(m)), tests)
+
+ # "Function to be used to compare method names when sorting them in
+ # getTestCaseNames() and all the loadTestsFromX() methods"
+ def test_sortTestMethodsUsing__loadTestsFromName(self):
+ def reversed_cmp(x, y):
+ return -cmp(x, y)
+
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ m.Foo = Foo
+
+ loader = unittest.TestLoader()
+ loader.sortTestMethodsUsing = reversed_cmp
+
+ tests = loader.suiteClass([Foo('test_2'), Foo('test_1')])
+ self.assertEqual(loader.loadTestsFromName('Foo', m), tests)
+
+ # "Function to be used to compare method names when sorting them in
+ # getTestCaseNames() and all the loadTestsFromX() methods"
+ def test_sortTestMethodsUsing__loadTestsFromNames(self):
+ def reversed_cmp(x, y):
+ return -cmp(x, y)
+
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ m.Foo = Foo
+
+ loader = unittest.TestLoader()
+ loader.sortTestMethodsUsing = reversed_cmp
+
+ tests = [loader.suiteClass([Foo('test_2'), Foo('test_1')])]
+ self.assertEqual(list(loader.loadTestsFromNames(['Foo'], m)), tests)
+
+ # "Function to be used to compare method names when sorting them in
+ # getTestCaseNames()"
+ #
+ # Does it actually affect getTestCaseNames()?
+ def test_sortTestMethodsUsing__getTestCaseNames(self):
+ def reversed_cmp(x, y):
+ return -cmp(x, y)
+
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+
+ loader = unittest.TestLoader()
+ loader.sortTestMethodsUsing = reversed_cmp
+
+ test_names = ['test_2', 'test_1']
+ self.assertEqual(loader.getTestCaseNames(Foo), test_names)
+
+ # "The default value is the built-in cmp() function"
+ def test_sortTestMethodsUsing__default_value(self):
+ loader = unittest.TestLoader()
+ self.failUnless(loader.sortTestMethodsUsing is cmp)
+
+ # "it can be set to None to disable the sort."
+ #
+ # XXX How is this different from reassigning cmp? Are the tests returned
+ # in a random order or something? This behaviour should die
+ def test_sortTestMethodsUsing__None(self):
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+
+ loader = unittest.TestLoader()
+ loader.sortTestMethodsUsing = None
+
+ test_names = ['test_2', 'test_1']
+ self.assertEqual(set(loader.getTestCaseNames(Foo)), set(test_names))
+
+ ################################################################
+ ### /Tests for TestLoader.sortTestMethodsUsing
+
+ ### Tests for TestLoader.suiteClass
+ ################################################################
+
+ # "Callable object that constructs a test suite from a list of tests."
+ def test_suiteClass__loadTestsFromTestCase(self):
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+
+ tests = [Foo('test_1'), Foo('test_2')]
+
+ loader = unittest.TestLoader()
+ loader.suiteClass = list
+ self.assertEqual(loader.loadTestsFromTestCase(Foo), tests)
+
+ # It is implicit in the documentation for TestLoader.suiteClass that
+ # all TestLoader.loadTestsFrom* methods respect it. Let's make sure
+ def test_suiteClass__loadTestsFromModule(self):
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+ m.Foo = Foo
+
+ tests = [[Foo('test_1'), Foo('test_2')]]
+
+ loader = unittest.TestLoader()
+ loader.suiteClass = list
+ self.assertEqual(loader.loadTestsFromModule(m), tests)
+
+ # It is implicit in the documentation for TestLoader.suiteClass that
+ # all TestLoader.loadTestsFrom* methods respect it. Let's make sure
+ def test_suiteClass__loadTestsFromName(self):
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+ m.Foo = Foo
+
+ tests = [Foo('test_1'), Foo('test_2')]
+
+ loader = unittest.TestLoader()
+ loader.suiteClass = list
+ self.assertEqual(loader.loadTestsFromName('Foo', m), tests)
+
+ # It is implicit in the documentation for TestLoader.suiteClass that
+ # all TestLoader.loadTestsFrom* methods respect it. Let's make sure
+ def test_suiteClass__loadTestsFromNames(self):
+ import new
+ m = new.module('m')
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def foo_bar(self): pass
+ m.Foo = Foo
+
+ tests = [[Foo('test_1'), Foo('test_2')]]
+
+ loader = unittest.TestLoader()
+ loader.suiteClass = list
+ self.assertEqual(loader.loadTestsFromNames(['Foo'], m), tests)
+
+ # "The default value is the TestSuite class"
+ def test_suiteClass__default_value(self):
+ loader = unittest.TestLoader()
+ self.failUnless(loader.suiteClass is unittest.TestSuite)
+
+ ################################################################
+ ### /Tests for TestLoader.suiteClass
+
+### Support code for Test_TestSuite
+################################################################
+
+class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+ def test_3(self): pass
+ def runTest(self): pass
+
+def _mk_TestSuite(*names):
+ return unittest.TestSuite(Foo(n) for n in names)
+
+################################################################
+### /Support code for Test_TestSuite
+
+class Test_TestSuite(TestCase, TestEquality):
+
+ ### Set up attributes needed by inherited tests
+ ################################################################
+
+ # Used by TestEquality.test_eq
+ eq_pairs = [(unittest.TestSuite(), unittest.TestSuite())
+ ,(unittest.TestSuite(), unittest.TestSuite([]))
+ ,(_mk_TestSuite('test_1'), _mk_TestSuite('test_1'))]
+
+ # Used by TestEquality.test_ne
+ ne_pairs = [(unittest.TestSuite(), _mk_TestSuite('test_1'))
+ ,(unittest.TestSuite([]), _mk_TestSuite('test_1'))
+ ,(_mk_TestSuite('test_1', 'test_2'), _mk_TestSuite('test_1', 'test_3'))
+ ,(_mk_TestSuite('test_1'), _mk_TestSuite('test_2'))]
+
+ ################################################################
+ ### /Set up attributes needed by inherited tests
+
+ ### Tests for TestSuite.__init__
+ ################################################################
+
+ # "class TestSuite([tests])"
+ #
+ # The tests iterable should be optional
+ def test_init__tests_optional(self):
+ suite = unittest.TestSuite()
+
+ self.assertEqual(suite.countTestCases(), 0)
+
+ # "class TestSuite([tests])"
+ # ...
+ # "If tests is given, it must be an iterable of individual test cases
+ # or other test suites that will be used to build the suite initially"
+ #
+ # TestSuite should deal with empty tests iterables by allowing the
+ # creation of an empty suite
+ def test_init__empty_tests(self):
+ suite = unittest.TestSuite([])
+
+ self.assertEqual(suite.countTestCases(), 0)
+
+ # "class TestSuite([tests])"
+ # ...
+ # "If tests is given, it must be an iterable of individual test cases
+ # or other test suites that will be used to build the suite initially"
+ #
+ # TestSuite should allow any iterable to provide tests
+ def test_init__tests_from_any_iterable(self):
+ def tests():
+ yield unittest.FunctionTestCase(lambda: None)
+ yield unittest.FunctionTestCase(lambda: None)
+
+ suite_1 = unittest.TestSuite(tests())
+ self.assertEqual(suite_1.countTestCases(), 2)
+
+ suite_2 = unittest.TestSuite(suite_1)
+ self.assertEqual(suite_2.countTestCases(), 2)
+
+ suite_3 = unittest.TestSuite(set(suite_1))
+ self.assertEqual(suite_3.countTestCases(), 2)
+
+ # "class TestSuite([tests])"
+ # ...
+ # "If tests is given, it must be an iterable of individual test cases
+ # or other test suites that will be used to build the suite initially"
+ #
+ # Does TestSuite() also allow other TestSuite() instances to be present
+ # in the tests iterable?
+ def test_init__TestSuite_instances_in_tests(self):
+ def tests():
+ ftc = unittest.FunctionTestCase(lambda: None)
+ yield unittest.TestSuite([ftc])
+ yield unittest.FunctionTestCase(lambda: None)
+
+ suite = unittest.TestSuite(tests())
+ self.assertEqual(suite.countTestCases(), 2)
+
+ ################################################################
+ ### /Tests for TestSuite.__init__
+
+ # Container types should support the iter protocol
+ def test_iter(self):
+ test1 = unittest.FunctionTestCase(lambda: None)
+ test2 = unittest.FunctionTestCase(lambda: None)
+ suite = unittest.TestSuite((test1, test2))
+
+ self.assertEqual(list(suite), [test1, test2])
+
+ # "Return the number of tests represented by the this test object.
+ # ...this method is also implemented by the TestSuite class, which can
+ # return larger [greater than 1] values"
+ #
+ # Presumably an empty TestSuite returns 0?
+ def test_countTestCases_zero_simple(self):
+ suite = unittest.TestSuite()
+
+ self.assertEqual(suite.countTestCases(), 0)
+
+ # "Return the number of tests represented by the this test object.
+ # ...this method is also implemented by the TestSuite class, which can
+ # return larger [greater than 1] values"
+ #
+ # Presumably an empty TestSuite (even if it contains other empty
+ # TestSuite instances) returns 0?
+ def test_countTestCases_zero_nested(self):
+ class Test1(unittest.TestCase):
+ def test(self):
+ pass
+
+ suite = unittest.TestSuite([unittest.TestSuite()])
+
+ self.assertEqual(suite.countTestCases(), 0)
+
+ # "Return the number of tests represented by the this test object.
+ # ...this method is also implemented by the TestSuite class, which can
+ # return larger [greater than 1] values"
+ def test_countTestCases_simple(self):
+ test1 = unittest.FunctionTestCase(lambda: None)
+ test2 = unittest.FunctionTestCase(lambda: None)
+ suite = unittest.TestSuite((test1, test2))
+
+ self.assertEqual(suite.countTestCases(), 2)
+
+ # "Return the number of tests represented by the this test object.
+ # ...this method is also implemented by the TestSuite class, which can
+ # return larger [greater than 1] values"
+ #
+ # Make sure this holds for nested TestSuite instances, too
+ def test_countTestCases_nested(self):
+ class Test1(unittest.TestCase):
+ def test1(self): pass
+ def test2(self): pass
+
+ test2 = unittest.FunctionTestCase(lambda: None)
+ test3 = unittest.FunctionTestCase(lambda: None)
+ child = unittest.TestSuite((Test1('test2'), test2))
+ parent = unittest.TestSuite((test3, child, Test1('test1')))
+
+ self.assertEqual(parent.countTestCases(), 4)
+
+ # "Run the tests associated with this suite, collecting the result into
+ # the test result object passed as result."
+ #
+ # And if there are no tests? What then?
+ def test_run__empty_suite(self):
+ events = []
+ result = LoggingResult(events)
+
+ suite = unittest.TestSuite()
+
+ suite.run(result)
+
+ self.assertEqual(events, [])
+
+ # "Note that unlike TestCase.run(), TestSuite.run() requires the
+ # "result object to be passed in."
+ def test_run__requires_result(self):
+ suite = unittest.TestSuite()
+
+ try:
+ suite.run()
+ except TypeError:
+ pass
+ else:
+ self.fail("Failed to raise TypeError")
+
+ # "Run the tests associated with this suite, collecting the result into
+ # the test result object passed as result."
+ def test_run(self):
+ events = []
+ result = LoggingResult(events)
+
+ class LoggingCase(unittest.TestCase):
+ def run(self, result):
+ events.append('run %s' % self._testMethodName)
+
+ def test1(self): pass
+ def test2(self): pass
+
+ tests = [LoggingCase('test1'), LoggingCase('test2')]
+
+ unittest.TestSuite(tests).run(result)
+
+ self.assertEqual(events, ['run test1', 'run test2'])
+
+ # "Add a TestCase ... to the suite"
+ def test_addTest__TestCase(self):
+ class Foo(unittest.TestCase):
+ def test(self): pass
+
+ test = Foo('test')
+ suite = unittest.TestSuite()
+
+ suite.addTest(test)
+
+ self.assertEqual(suite.countTestCases(), 1)
+ self.assertEqual(list(suite), [test])
+
+ # "Add a ... TestSuite to the suite"
+ def test_addTest__TestSuite(self):
+ class Foo(unittest.TestCase):
+ def test(self): pass
+
+ suite_2 = unittest.TestSuite([Foo('test')])
+
+ suite = unittest.TestSuite()
+ suite.addTest(suite_2)
+
+ self.assertEqual(suite.countTestCases(), 1)
+ self.assertEqual(list(suite), [suite_2])
+
+ # "Add all the tests from an iterable of TestCase and TestSuite
+ # instances to this test suite."
+ #
+ # "This is equivalent to iterating over tests, calling addTest() for
+ # each element"
+ def test_addTests(self):
+ class Foo(unittest.TestCase):
+ def test_1(self): pass
+ def test_2(self): pass
+
+ test_1 = Foo('test_1')
+ test_2 = Foo('test_2')
+ inner_suite = unittest.TestSuite([test_2])
+
+ def gen():
+ yield test_1
+ yield test_2
+ yield inner_suite
+
+ suite_1 = unittest.TestSuite()
+ suite_1.addTests(gen())
+
+ self.assertEqual(list(suite_1), list(gen()))
+
+ # "This is equivalent to iterating over tests, calling addTest() for
+ # each element"
+ suite_2 = unittest.TestSuite()
+ for t in gen():
+ suite_2.addTest(t)
+
+ self.assertEqual(suite_1, suite_2)
+
+ # "Add all the tests from an iterable of TestCase and TestSuite
+ # instances to this test suite."
+ #
+ # What happens if it doesn't get an iterable?
+ def test_addTest__noniterable(self):
+ suite = unittest.TestSuite()
+
+ try:
+ suite.addTests(5)
+ except TypeError:
+ pass
+ else:
+ self.fail("Failed to raise TypeError")
+
+ def test_addTest__noncallable(self):
+ suite = unittest.TestSuite()
+ self.assertRaises(TypeError, suite.addTest, 5)
+
+ def test_addTest__casesuiteclass(self):
+ suite = unittest.TestSuite()
+ self.assertRaises(TypeError, suite.addTest, Test_TestSuite)
+ self.assertRaises(TypeError, suite.addTest, unittest.TestSuite)
+
+ def test_addTests__string(self):
+ suite = unittest.TestSuite()
+ self.assertRaises(TypeError, suite.addTests, "foo")
+
+
+class Test_FunctionTestCase(TestCase):
+
+ # "Return the number of tests represented by the this test object. For
+ # TestCase instances, this will always be 1"
+ def test_countTestCases(self):
+ test = unittest.FunctionTestCase(lambda: None)
+
+ self.assertEqual(test.countTestCases(), 1)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if setUp() raises
+ # an exception.
+ def test_run_call_order__error_in_setUp(self):
+ events = []
+ result = LoggingResult(events)
+
+ def setUp():
+ events.append('setUp')
+ raise RuntimeError('raised by setUp')
+
+ def test():
+ events.append('test')
+
+ def tearDown():
+ events.append('tearDown')
+
+ expected = ['startTest', 'setUp', 'addError', 'stopTest']
+ unittest.FunctionTestCase(test, setUp, tearDown).run(result)
+ self.assertEqual(events, expected)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if the test raises
+ # an error (as opposed to a failure).
+ def test_run_call_order__error_in_test(self):
+ events = []
+ result = LoggingResult(events)
+
+ def setUp():
+ events.append('setUp')
+
+ def test():
+ events.append('test')
+ raise RuntimeError('raised by test')
+
+ def tearDown():
+ events.append('tearDown')
+
+ expected = ['startTest', 'setUp', 'test', 'addError', 'tearDown',
+ 'stopTest']
+ unittest.FunctionTestCase(test, setUp, tearDown).run(result)
+ self.assertEqual(events, expected)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if the test signals
+ # a failure (as opposed to an error).
+ def test_run_call_order__failure_in_test(self):
+ events = []
+ result = LoggingResult(events)
+
+ def setUp():
+ events.append('setUp')
+
+ def test():
+ events.append('test')
+ self.fail('raised by test')
+
+ def tearDown():
+ events.append('tearDown')
+
+ expected = ['startTest', 'setUp', 'test', 'addFailure', 'tearDown',
+ 'stopTest']
+ unittest.FunctionTestCase(test, setUp, tearDown).run(result)
+ self.assertEqual(events, expected)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if tearDown() raises
+ # an exception.
+ def test_run_call_order__error_in_tearDown(self):
+ events = []
+ result = LoggingResult(events)
+
+ def setUp():
+ events.append('setUp')
+
+ def test():
+ events.append('test')
+
+ def tearDown():
+ events.append('tearDown')
+ raise RuntimeError('raised by tearDown')
+
+ expected = ['startTest', 'setUp', 'test', 'tearDown', 'addError',
+ 'stopTest']
+ unittest.FunctionTestCase(test, setUp, tearDown).run(result)
+ self.assertEqual(events, expected)
+
+ # "Return a string identifying the specific test case."
+ #
+ # Because of the vague nature of the docs, I'm not going to lock this
+ # test down too much. Really all that can be asserted is that the id()
+ # will be a string (either 8-byte or unicode -- again, because the docs
+ # just say "string")
+ def test_id(self):
+ test = unittest.FunctionTestCase(lambda: None)
+
+ self.failUnless(isinstance(test.id(), basestring))
+
+ # "Returns a one-line description of the test, or None if no description
+ # has been provided. The default implementation of this method returns
+ # the first line of the test method's docstring, if available, or None."
+ def test_shortDescription__no_docstring(self):
+ test = unittest.FunctionTestCase(lambda: None)
+
+ self.assertEqual(test.shortDescription(), None)
+
+ # "Returns a one-line description of the test, or None if no description
+ # has been provided. The default implementation of this method returns
+ # the first line of the test method's docstring, if available, or None."
+ def test_shortDescription__singleline_docstring(self):
+ desc = "this tests foo"
+ test = unittest.FunctionTestCase(lambda: None, description=desc)
+
+ self.assertEqual(test.shortDescription(), "this tests foo")
+
+class Test_TestResult(TestCase):
+ # Note: there are not separate tests for TestResult.wasSuccessful(),
+ # TestResult.errors, TestResult.failures, TestResult.testsRun or
+ # TestResult.shouldStop because these only have meaning in terms of
+ # other TestResult methods.
+ #
+ # Accordingly, tests for the aforenamed attributes are incorporated
+ # in with the tests for the defining methods.
+ ################################################################
+
+ def test_init(self):
+ result = unittest.TestResult()
+
+ self.failUnless(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.failures), 0)
+ self.assertEqual(result.testsRun, 0)
+ self.assertEqual(result.shouldStop, False)
+
+ # "This method can be called to signal that the set of tests being
+ # run should be aborted by setting the TestResult's shouldStop
+ # attribute to True."
+ def test_stop(self):
+ result = unittest.TestResult()
+
+ result.stop()
+
+ self.assertEqual(result.shouldStop, True)
+
+ # "Called when the test case test is about to be run. The default
+ # implementation simply increments the instance's testsRun counter."
+ def test_startTest(self):
+ class Foo(unittest.TestCase):
+ def test_1(self):
+ pass
+
+ test = Foo('test_1')
+
+ result = unittest.TestResult()
+
+ result.startTest(test)
+
+ self.failUnless(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.failures), 0)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ result.stopTest(test)
+
+ # "Called after the test case test has been executed, regardless of
+ # the outcome. The default implementation does nothing."
+ def test_stopTest(self):
+ class Foo(unittest.TestCase):
+ def test_1(self):
+ pass
+
+ test = Foo('test_1')
+
+ result = unittest.TestResult()
+
+ result.startTest(test)
+
+ self.failUnless(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.failures), 0)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ result.stopTest(test)
+
+ # Same tests as above; make sure nothing has changed
+ self.failUnless(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.failures), 0)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ # "addSuccess(test)"
+ # ...
+ # "Called when the test case test succeeds"
+ # ...
+ # "wasSuccessful() - Returns True if all tests run so far have passed,
+ # otherwise returns False"
+ # ...
+ # "testsRun - The total number of tests run so far."
+ # ...
+ # "errors - A list containing 2-tuples of TestCase instances and
+ # formatted tracebacks. Each tuple represents a test which raised an
+ # unexpected exception. Contains formatted
+ # tracebacks instead of sys.exc_info() results."
+ # ...
+ # "failures - A list containing 2-tuples of TestCase instances and
+ # formatted tracebacks. Each tuple represents a test where a failure was
+ # explicitly signalled using the TestCase.fail*() or TestCase.assert*()
+ # methods. Contains formatted tracebacks instead
+ # of sys.exc_info() results."
+ def test_addSuccess(self):
+ class Foo(unittest.TestCase):
+ def test_1(self):
+ pass
+
+ test = Foo('test_1')
+
+ result = unittest.TestResult()
+
+ result.startTest(test)
+ result.addSuccess(test)
+ result.stopTest(test)
+
+ self.failUnless(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.failures), 0)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ # "addFailure(test, err)"
+ # ...
+ # "Called when the test case test signals a failure. err is a tuple of
+ # the form returned by sys.exc_info(): (type, value, traceback)"
+ # ...
+ # "wasSuccessful() - Returns True if all tests run so far have passed,
+ # otherwise returns False"
+ # ...
+ # "testsRun - The total number of tests run so far."
+ # ...
+ # "errors - A list containing 2-tuples of TestCase instances and
+ # formatted tracebacks. Each tuple represents a test which raised an
+ # unexpected exception. Contains formatted
+ # tracebacks instead of sys.exc_info() results."
+ # ...
+ # "failures - A list containing 2-tuples of TestCase instances and
+ # formatted tracebacks. Each tuple represents a test where a failure was
+ # explicitly signalled using the TestCase.fail*() or TestCase.assert*()
+ # methods. Contains formatted tracebacks instead
+ # of sys.exc_info() results."
+ def test_addFailure(self):
+ import sys
+
+ class Foo(unittest.TestCase):
+ def test_1(self):
+ pass
+
+ test = Foo('test_1')
+ try:
+ test.fail("foo")
+ except:
+ exc_info_tuple = sys.exc_info()
+
+ result = unittest.TestResult()
+
+ result.startTest(test)
+ result.addFailure(test, exc_info_tuple)
+ result.stopTest(test)
+
+ self.failIf(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 0)
+ self.assertEqual(len(result.failures), 1)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ test_case, formatted_exc = result.failures[0]
+ self.failUnless(test_case is test)
+ self.failUnless(isinstance(formatted_exc, str))
+
+ # "addError(test, err)"
+ # ...
+ # "Called when the test case test raises an unexpected exception err
+ # is a tuple of the form returned by sys.exc_info():
+ # (type, value, traceback)"
+ # ...
+ # "wasSuccessful() - Returns True if all tests run so far have passed,
+ # otherwise returns False"
+ # ...
+ # "testsRun - The total number of tests run so far."
+ # ...
+ # "errors - A list containing 2-tuples of TestCase instances and
+ # formatted tracebacks. Each tuple represents a test which raised an
+ # unexpected exception. Contains formatted
+ # tracebacks instead of sys.exc_info() results."
+ # ...
+ # "failures - A list containing 2-tuples of TestCase instances and
+ # formatted tracebacks. Each tuple represents a test where a failure was
+ # explicitly signalled using the TestCase.fail*() or TestCase.assert*()
+ # methods. Contains formatted tracebacks instead
+ # of sys.exc_info() results."
+ def test_addError(self):
+ import sys
+
+ class Foo(unittest.TestCase):
+ def test_1(self):
+ pass
+
+ test = Foo('test_1')
+ try:
+ raise TypeError()
+ except:
+ exc_info_tuple = sys.exc_info()
+
+ result = unittest.TestResult()
+
+ result.startTest(test)
+ result.addError(test, exc_info_tuple)
+ result.stopTest(test)
+
+ self.failIf(result.wasSuccessful())
+ self.assertEqual(len(result.errors), 1)
+ self.assertEqual(len(result.failures), 0)
+ self.assertEqual(result.testsRun, 1)
+ self.assertEqual(result.shouldStop, False)
+
+ test_case, formatted_exc = result.errors[0]
+ self.failUnless(test_case is test)
+ self.failUnless(isinstance(formatted_exc, str))
+
+### Support code for Test_TestCase
+################################################################
+
+class Foo(unittest.TestCase):
+ def runTest(self): pass
+ def test1(self): pass
+
+class Bar(Foo):
+ def test2(self): pass
+
+################################################################
+### /Support code for Test_TestCase
+
+class Test_TestCase(TestCase, TestEquality, TestHashing):
+
+ ### Set up attributes used by inherited tests
+ ################################################################
+
+ # Used by TestHashing.test_hash and TestEquality.test_eq
+ eq_pairs = [(Foo('test1'), Foo('test1'))]
+
+ # Used by TestEquality.test_ne
+ ne_pairs = [(Foo('test1'), Foo('runTest'))
+ ,(Foo('test1'), Bar('test1'))
+ ,(Foo('test1'), Bar('test2'))]
+
+ ################################################################
+ ### /Set up attributes used by inherited tests
+
+
+ # "class TestCase([methodName])"
+ # ...
+ # "Each instance of TestCase will run a single test method: the
+ # method named methodName."
+ # ...
+ # "methodName defaults to "runTest"."
+ #
+ # Make sure it really is optional, and that it defaults to the proper
+ # thing.
+ def test_init__no_test_name(self):
+ class Test(unittest.TestCase):
+ def runTest(self): raise MyException()
+ def test(self): pass
+
+ self.assertEqual(Test().id()[-13:], '.Test.runTest')
+
+ # "class TestCase([methodName])"
+ # ...
+ # "Each instance of TestCase will run a single test method: the
+ # method named methodName."
+ def test_init__test_name__valid(self):
+ class Test(unittest.TestCase):
+ def runTest(self): raise MyException()
+ def test(self): pass
+
+ self.assertEqual(Test('test').id()[-10:], '.Test.test')
+
+ # "class TestCase([methodName])"
+ # ...
+ # "Each instance of TestCase will run a single test method: the
+ # method named methodName."
+ def test_init__test_name__invalid(self):
+ class Test(unittest.TestCase):
+ def runTest(self): raise MyException()
+ def test(self): pass
+
+ try:
+ Test('testfoo')
+ except ValueError:
+ pass
+ else:
+ self.fail("Failed to raise ValueError")
+
+ # "Return the number of tests represented by the this test object. For
+ # TestCase instances, this will always be 1"
+ def test_countTestCases(self):
+ class Foo(unittest.TestCase):
+ def test(self): pass
+
+ self.assertEqual(Foo('test').countTestCases(), 1)
+
+ # "Return the default type of test result object to be used to run this
+ # test. For TestCase instances, this will always be
+ # unittest.TestResult; subclasses of TestCase should
+ # override this as necessary."
+ def test_defaultTestResult(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ pass
+
+ result = Foo().defaultTestResult()
+ self.assertEqual(type(result), unittest.TestResult)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if setUp() raises
+ # an exception.
+ def test_run_call_order__error_in_setUp(self):
+ events = []
+ result = LoggingResult(events)
+
+ class Foo(unittest.TestCase):
+ def setUp(self):
+ events.append('setUp')
+ raise RuntimeError('raised by Foo.setUp')
+
+ def test(self):
+ events.append('test')
+
+ def tearDown(self):
+ events.append('tearDown')
+
+ Foo('test').run(result)
+ expected = ['startTest', 'setUp', 'addError', 'stopTest']
+ self.assertEqual(events, expected)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if the test raises
+ # an error (as opposed to a failure).
+ def test_run_call_order__error_in_test(self):
+ events = []
+ result = LoggingResult(events)
+
+ class Foo(unittest.TestCase):
+ def setUp(self):
+ events.append('setUp')
+
+ def test(self):
+ events.append('test')
+ raise RuntimeError('raised by Foo.test')
+
+ def tearDown(self):
+ events.append('tearDown')
+
+ expected = ['startTest', 'setUp', 'test', 'addError', 'tearDown',
+ 'stopTest']
+ Foo('test').run(result)
+ self.assertEqual(events, expected)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if the test signals
+ # a failure (as opposed to an error).
+ def test_run_call_order__failure_in_test(self):
+ events = []
+ result = LoggingResult(events)
+
+ class Foo(unittest.TestCase):
+ def setUp(self):
+ events.append('setUp')
+
+ def test(self):
+ events.append('test')
+ self.fail('raised by Foo.test')
+
+ def tearDown(self):
+ events.append('tearDown')
+
+ expected = ['startTest', 'setUp', 'test', 'addFailure', 'tearDown',
+ 'stopTest']
+ Foo('test').run(result)
+ self.assertEqual(events, expected)
+
+ # "When a setUp() method is defined, the test runner will run that method
+ # prior to each test. Likewise, if a tearDown() method is defined, the
+ # test runner will invoke that method after each test. In the example,
+ # setUp() was used to create a fresh sequence for each test."
+ #
+ # Make sure the proper call order is maintained, even if tearDown() raises
+ # an exception.
+ def test_run_call_order__error_in_tearDown(self):
+ events = []
+ result = LoggingResult(events)
+
+ class Foo(unittest.TestCase):
+ def setUp(self):
+ events.append('setUp')
+
+ def test(self):
+ events.append('test')
+
+ def tearDown(self):
+ events.append('tearDown')
+ raise RuntimeError('raised by Foo.tearDown')
+
+ Foo('test').run(result)
+ expected = ['startTest', 'setUp', 'test', 'tearDown', 'addError',
+ 'stopTest']
+ self.assertEqual(events, expected)
+
+ # "This class attribute gives the exception raised by the test() method.
+ # If a test framework needs to use a specialized exception, possibly to
+ # carry additional information, it must subclass this exception in
+ # order to ``play fair'' with the framework. The initial value of this
+ # attribute is AssertionError"
+ def test_failureException__default(self):
+ class Foo(unittest.TestCase):
+ def test(self):
+ pass
+
+ self.failUnless(Foo('test').failureException is AssertionError)
+
+ # "This class attribute gives the exception raised by the test() method.
+ # If a test framework needs to use a specialized exception, possibly to
+ # carry additional information, it must subclass this exception in
+ # order to ``play fair'' with the framework."
+ #
+ # Make sure TestCase.run() respects the designated failureException
+ def test_failureException__subclassing__explicit_raise(self):
+ events = []
+ result = LoggingResult(events)
+
+ class Foo(unittest.TestCase):
+ def test(self):
+ raise RuntimeError()
+
+ failureException = RuntimeError
+
+ self.failUnless(Foo('test').failureException is RuntimeError)
+
+
+ Foo('test').run(result)
+ expected = ['startTest', 'addFailure', 'stopTest']
+ self.assertEqual(events, expected)
+
+ # "This class attribute gives the exception raised by the test() method.
+ # If a test framework needs to use a specialized exception, possibly to
+ # carry additional information, it must subclass this exception in
+ # order to ``play fair'' with the framework."
+ #
+ # Make sure TestCase.run() respects the designated failureException
+ def test_failureException__subclassing__implicit_raise(self):
+ events = []
+ result = LoggingResult(events)
+
+ class Foo(unittest.TestCase):
+ def test(self):
+ self.fail("foo")
+
+ failureException = RuntimeError
+
+ self.failUnless(Foo('test').failureException is RuntimeError)
+
+
+ Foo('test').run(result)
+ expected = ['startTest', 'addFailure', 'stopTest']
+ self.assertEqual(events, expected)
+
+ # "The default implementation does nothing."
+ def test_setUp(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ pass
+
+ # ... and nothing should happen
+ Foo().setUp()
+
+ # "The default implementation does nothing."
+ def test_tearDown(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ pass
+
+ # ... and nothing should happen
+ Foo().tearDown()
+
+ # "Return a string identifying the specific test case."
+ #
+ # Because of the vague nature of the docs, I'm not going to lock this
+ # test down too much. Really all that can be asserted is that the id()
+ # will be a string (either 8-byte or unicode -- again, because the docs
+ # just say "string")
+ def test_id(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ pass
+
+ self.failUnless(isinstance(Foo().id(), basestring))
+
+ # "Returns a one-line description of the test, or None if no description
+ # has been provided. The default implementation of this method returns
+ # the first line of the test method's docstring, if available, or None."
+ def test_shortDescription__no_docstring(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ pass
+
+ self.assertEqual(Foo().shortDescription(), None)
+
+ # "Returns a one-line description of the test, or None if no description
+ # has been provided. The default implementation of this method returns
+ # the first line of the test method's docstring, if available, or None."
+ def test_shortDescription__singleline_docstring(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ "this tests foo"
+ pass
+
+ self.assertEqual(Foo().shortDescription(), "this tests foo")
+
+ # "Returns a one-line description of the test, or None if no description
+ # has been provided. The default implementation of this method returns
+ # the first line of the test method's docstring, if available, or None."
+ def test_shortDescription__multiline_docstring(self):
+ class Foo(unittest.TestCase):
+ def runTest(self):
+ """this tests foo
+ blah, bar and baz are also tested"""
+ pass
+
+ self.assertEqual(Foo().shortDescription(), "this tests foo")
+
+ # "If result is omitted or None, a temporary result object is created
+ # and used, but is not made available to the caller"
+ def test_run__uses_defaultTestResult(self):
+ events = []
+
+ class Foo(unittest.TestCase):
+ def test(self):
+ events.append('test')
+
+ def defaultTestResult(self):
+ return LoggingResult(events)
-def test_TestSuite_iter():
- """
- >>> test1 = unittest.FunctionTestCase(lambda: None)
- >>> test2 = unittest.FunctionTestCase(lambda: None)
- >>> suite = unittest.TestSuite((test1, test2))
- >>> tests = []
- >>> for test in suite:
- ... tests.append(test)
- >>> tests == [test1, test2]
- True
- """
+ # Make run() find a result object on its own
+ Foo('test').run()
+ expected = ['startTest', 'test', 'stopTest']
+ self.assertEqual(events, expected)
######################################################################
## Main
######################################################################
def test_main():
- from test import test_support, test_unittest
- test_support.run_doctest(test_unittest, verbosity=True)
+ test_support.run_unittest(Test_TestCase, Test_TestLoader,
+ Test_TestSuite, Test_TestResult, Test_FunctionTestCase)
-if __name__ == '__main__':
+if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_unpack.py b/Lib/test/test_unpack.py
index cd48689..75033ed 100644
--- a/Lib/test/test_unpack.py
+++ b/Lib/test/test_unpack.py
@@ -55,7 +55,7 @@ Unpacking non-sequence
>>> a, b, c = 7
Traceback (most recent call last):
...
- TypeError: unpack non-sequence
+ TypeError: 'int' object is not iterable
Unpacking tuple of wrong size
diff --git a/Lib/test/test_urllib.py b/Lib/test/test_urllib.py
index 16c612e..3a37525 100644
--- a/Lib/test/test_urllib.py
+++ b/Lib/test/test_urllib.py
@@ -122,6 +122,15 @@ class urlopen_HttpTests(unittest.TestCase):
finally:
self.unfakehttp()
+ def test_empty_socket(self):
+ """urlopen() raises IOError if the underlying socket does not send any
+ data. (#1680230) """
+ self.fakehttp('')
+ try:
+ self.assertRaises(IOError, urllib.urlopen, 'http://something')
+ finally:
+ self.unfakehttp()
+
class urlretrieve_FileTests(unittest.TestCase):
"""Test urllib.urlretrieve() on local files"""
diff --git a/Lib/test/test_urllib2.py b/Lib/test/test_urllib2.py
index 6187dad..10d8c46 100644
--- a/Lib/test/test_urllib2.py
+++ b/Lib/test/test_urllib2.py
@@ -625,11 +625,11 @@ class HandlerTests(unittest.TestCase):
for url in [
"file://localhost:80%s" % urlpath,
-# XXXX bug: these fail with socket.gaierror, should be URLError
-## "file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
-## os.getcwd(), TESTFN),
-## "file://somerandomhost.ontheinternet.com%s/%s" %
-## (os.getcwd(), TESTFN),
+ "file:///file_does_not_exist.txt",
+ "file://%s:80%s/%s" % (socket.gethostbyname('localhost'),
+ os.getcwd(), TESTFN),
+ "file://somerandomhost.ontheinternet.com%s/%s" %
+ (os.getcwd(), TESTFN),
]:
try:
f = open(TESTFN, "wb")
@@ -765,16 +765,24 @@ class HandlerTests(unittest.TestCase):
url = "http://example.com/"
req = Request(url)
- # 200 OK is passed through
+ # all 2xx are passed through
r = MockResponse(200, "OK", {}, "", url)
newr = h.http_response(req, r)
self.assert_(r is newr)
self.assert_(not hasattr(o, "proto")) # o.error not called
+ r = MockResponse(202, "Accepted", {}, "", url)
+ newr = h.http_response(req, r)
+ self.assert_(r is newr)
+ self.assert_(not hasattr(o, "proto")) # o.error not called
+ r = MockResponse(206, "Partial content", {}, "", url)
+ newr = h.http_response(req, r)
+ self.assert_(r is newr)
+ self.assert_(not hasattr(o, "proto")) # o.error not called
# anything else calls o.error (and MockOpener returns None, here)
- r = MockResponse(201, "Created", {}, "", url)
+ r = MockResponse(502, "Bad gateway", {}, "", url)
self.assert_(h.http_response(req, r) is None)
self.assertEqual(o.proto, "http") # o.error called
- self.assertEqual(o.args, (req, r, 201, "Created", {}))
+ self.assertEqual(o.args, (req, r, 502, "Bad gateway", {}))
def test_cookies(self):
cj = MockCookieJar()
diff --git a/Lib/test/test_urllib2net.py b/Lib/test/test_urllib2net.py
index 60d5f48..a52c3dd 100644
--- a/Lib/test/test_urllib2net.py
+++ b/Lib/test/test_urllib2net.py
@@ -264,7 +264,8 @@ class OtherNetworkTests(unittest.TestCase):
(expected_err, url, req, err))
self.assert_(isinstance(err, expected_err), msg)
else:
- buf = f.read()
+ with test_support.transient_internet():
+ buf = f.read()
f.close()
debug("read %d bytes" % len(buf))
debug("******** next url coming up...")
diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py
index 500971b..8c72463 100644
--- a/Lib/test/test_userdict.py
+++ b/Lib/test/test_userdict.py
@@ -171,7 +171,7 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
except RuntimeError as err:
self.assertEqual(err.args, (42,))
else:
- self.fail_("e[42] didn't raise RuntimeError")
+ self.fail("e[42] didn't raise RuntimeError")
class F(UserDict.UserDict):
def __init__(self):
# An instance variable __missing__ should have no effect
@@ -183,7 +183,7 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
except KeyError as err:
self.assertEqual(err.args, (42,))
else:
- self.fail_("f[42] didn't raise KeyError")
+ self.fail("f[42] didn't raise KeyError")
class G(UserDict.UserDict):
pass
g = G()
@@ -192,7 +192,7 @@ class UserDictTest(mapping_tests.TestHashMappingProtocol):
except KeyError as err:
self.assertEqual(err.args, (42,))
else:
- self.fail_("g[42] didn't raise KeyError")
+ self.fail("g[42] didn't raise KeyError")
##########################
# Test Dict Mixin
diff --git a/Lib/test/test_warnings.py b/Lib/test/test_warnings.py
index a7ccb6b..283806f 100644
--- a/Lib/test/test_warnings.py
+++ b/Lib/test/test_warnings.py
@@ -3,95 +3,97 @@ import os
import unittest
from test import test_support
-# The warnings module isn't easily tested, because it relies on module
-# globals to store configuration information. setUp() and tearDown()
-# preserve the current settings to avoid bashing them while running tests.
-
-# To capture the warning messages, a replacement for showwarning() is
-# used to save warning information in a global variable.
-
-class WarningMessage:
- "Holds results of latest showwarning() call"
- pass
-
-def showwarning(message, category, filename, lineno, file=None):
- msg.message = str(message)
- msg.category = category.__name__
- msg.filename = os.path.basename(filename)
- msg.lineno = lineno
+import warning_tests
class TestModule(unittest.TestCase):
-
def setUp(self):
- global msg
- msg = WarningMessage()
- self._filters = warnings.filters[:]
- self._showwarning = warnings.showwarning
- warnings.showwarning = showwarning
- self.ignored = [w[2].__name__ for w in self._filters
+ self.ignored = [w[2].__name__ for w in warnings.filters
if w[0]=='ignore' and w[1] is None and w[3] is None]
- def tearDown(self):
- warnings.filters = self._filters[:]
- warnings.showwarning = self._showwarning
-
def test_warn_default_category(self):
- for i in range(4):
- text = 'multi %d' %i # Different text on each call
- warnings.warn(text)
- self.assertEqual(msg.message, text)
- self.assertEqual(msg.category, 'UserWarning')
+ with test_support.catch_warning() as w:
+ for i in range(4):
+ text = 'multi %d' %i # Different text on each call
+ warnings.warn(text)
+ self.assertEqual(str(w.message), text)
+ self.assert_(w.category is UserWarning)
def test_warn_specific_category(self):
- text = 'None'
- for category in [DeprecationWarning, FutureWarning,
- PendingDeprecationWarning, RuntimeWarning,
- SyntaxWarning, UserWarning, Warning]:
- if category.__name__ in self.ignored:
- text = 'filtered out' + category.__name__
- warnings.warn(text, category)
- self.assertNotEqual(msg.message, text)
- else:
- text = 'unfiltered %s' % category.__name__
- warnings.warn(text, category)
- self.assertEqual(msg.message, text)
- self.assertEqual(msg.category, category.__name__)
+ with test_support.catch_warning() as w:
+ text = 'None'
+ for category in [DeprecationWarning, FutureWarning,
+ PendingDeprecationWarning, RuntimeWarning,
+ SyntaxWarning, UserWarning, Warning]:
+ if category.__name__ in self.ignored:
+ text = 'filtered out' + category.__name__
+ warnings.warn(text, category)
+ self.assertNotEqual(w.message, text)
+ else:
+ text = 'unfiltered %s' % category.__name__
+ warnings.warn(text, category)
+ self.assertEqual(str(w.message), text)
+ self.assert_(w.category is category)
def test_filtering(self):
+ with test_support.catch_warning() as w:
+ warnings.filterwarnings("error", "", Warning, "", 0)
+ self.assertRaises(UserWarning, warnings.warn, 'convert to error')
- warnings.filterwarnings("error", "", Warning, "", 0)
- self.assertRaises(UserWarning, warnings.warn, 'convert to error')
-
- warnings.resetwarnings()
- text = 'handle normally'
- warnings.warn(text)
- self.assertEqual(msg.message, text)
- self.assertEqual(msg.category, 'UserWarning')
+ warnings.resetwarnings()
+ text = 'handle normally'
+ warnings.warn(text)
+ self.assertEqual(str(w.message), text)
+ self.assert_(w.category is UserWarning)
- warnings.filterwarnings("ignore", "", Warning, "", 0)
- text = 'filtered out'
- warnings.warn(text)
- self.assertNotEqual(msg.message, text)
+ warnings.filterwarnings("ignore", "", Warning, "", 0)
+ text = 'filtered out'
+ warnings.warn(text)
+ self.assertNotEqual(str(w.message), text)
- warnings.resetwarnings()
- warnings.filterwarnings("error", "hex*", Warning, "", 0)
- self.assertRaises(UserWarning, warnings.warn, 'hex/oct')
- text = 'nonmatching text'
- warnings.warn(text)
- self.assertEqual(msg.message, text)
- self.assertEqual(msg.category, 'UserWarning')
+ warnings.resetwarnings()
+ warnings.filterwarnings("error", "hex*", Warning, "", 0)
+ self.assertRaises(UserWarning, warnings.warn, 'hex/oct')
+ text = 'nonmatching text'
+ warnings.warn(text)
+ self.assertEqual(str(w.message), text)
+ self.assert_(w.category is UserWarning)
def test_options(self):
# Uses the private _setoption() function to test the parsing
# of command-line warning arguments
- self.assertRaises(warnings._OptionError,
- warnings._setoption, '1:2:3:4:5:6')
- self.assertRaises(warnings._OptionError,
- warnings._setoption, 'bogus::Warning')
- self.assertRaises(warnings._OptionError,
- warnings._setoption, 'ignore:2::4:-5')
- warnings._setoption('error::Warning::0')
- self.assertRaises(UserWarning, warnings.warn, 'convert to error')
+ with test_support.guard_warnings_filter():
+ self.assertRaises(warnings._OptionError,
+ warnings._setoption, '1:2:3:4:5:6')
+ self.assertRaises(warnings._OptionError,
+ warnings._setoption, 'bogus::Warning')
+ self.assertRaises(warnings._OptionError,
+ warnings._setoption, 'ignore:2::4:-5')
+ warnings._setoption('error::Warning::0')
+ self.assertRaises(UserWarning, warnings.warn, 'convert to error')
+
+ def test_filename(self):
+ with test_support.catch_warning() as w:
+ warning_tests.inner("spam1")
+ self.assertEqual(os.path.basename(w.filename), "warning_tests.py")
+ warning_tests.outer("spam2")
+ self.assertEqual(os.path.basename(w.filename), "warning_tests.py")
+
+ def test_stacklevel(self):
+ # Test stacklevel argument
+ # make sure all messages are different, so the warning won't be skipped
+ with test_support.catch_warning() as w:
+ warning_tests.inner("spam3", stacklevel=1)
+ self.assertEqual(os.path.basename(w.filename), "warning_tests.py")
+ warning_tests.outer("spam4", stacklevel=1)
+ self.assertEqual(os.path.basename(w.filename), "warning_tests.py")
+
+ warning_tests.inner("spam5", stacklevel=2)
+ self.assertEqual(os.path.basename(w.filename), "test_warnings.py")
+ warning_tests.outer("spam6", stacklevel=2)
+ self.assertEqual(os.path.basename(w.filename), "warning_tests.py")
+
+ warning_tests.inner("spam7", stacklevel=9999)
+ self.assertEqual(os.path.basename(w.filename), "sys")
def test_main(verbose=None):
diff --git a/Lib/test/test_weakref.py b/Lib/test/test_weakref.py
index 46dd5ff..99a5178a 100644
--- a/Lib/test/test_weakref.py
+++ b/Lib/test/test_weakref.py
@@ -841,7 +841,7 @@ class MappingTestCase(TestBase):
items = dict.items()
for item in dict.items():
items.remove(item)
- self.assert_(len(items) == 0, "iteritems() did not touch all items")
+ self.assert_(len(items) == 0, "items() did not touch all items")
# key iterator, via __iter__():
keys = list(dict.keys())
@@ -1104,7 +1104,7 @@ None
... self.__counter += 1
... ob = (ob, self.__counter)
... return ob
-...
+...
>>> class A: # not in docs from here, just testing the ExtendedRef
... pass
...
diff --git a/Lib/test/test_wsgiref.py b/Lib/test/test_wsgiref.py
index d77e861..213c5cf 100755
--- a/Lib/test/test_wsgiref.py
+++ b/Lib/test/test_wsgiref.py
@@ -1,5 +1,5 @@
from __future__ import nested_scopes # Backward compat for 2.1
-from unittest import TestSuite, TestCase, makeSuite
+from unittest import TestCase
from wsgiref.util import setup_testing_defaults
from wsgiref.headers import Headers
from wsgiref.handlers import BaseHandler, BaseCGIHandler
@@ -11,6 +11,7 @@ from StringIO import StringIO
from SocketServer import BaseServer
import re, sys
+from test import test_support
class MockServer(WSGIServer):
"""Non-socket HTTP server"""
@@ -575,11 +576,7 @@ class HandlerTests(TestCase):
# This epilogue is needed for compatibility with the Python 2.5 regrtest module
def test_main():
- import unittest
- from test.test_support import run_suite
- run_suite(
- unittest.defaultTestLoader.loadTestsFromModule(sys.modules[__name__])
- )
+ test_support.run_unittest(__name__)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py
index d6f5534..a6500da 100644
--- a/Lib/test/test_zipfile.py
+++ b/Lib/test/test_zipfile.py
@@ -4,26 +4,30 @@ try:
except ImportError:
zlib = None
-import zipfile, os, unittest, sys, shutil
+import zipfile, os, unittest, sys, shutil, struct
from StringIO import StringIO
from tempfile import TemporaryFile
+from random import randint, random
+import test.test_support as support
from test.test_support import TESTFN, run_unittest
TESTFN2 = TESTFN + "2"
+FIXEDTEST_SIZE = 10
class TestsWithSourceFile(unittest.TestCase):
def setUp(self):
- line_gen = ("Test of zipfile line %d." % i for i in range(0, 1000))
- self.data = '\n'.join(line_gen)
+ self.line_gen = ("Zipfile test line %d. random float: %f" % (i, random())
+ for i in xrange(FIXEDTEST_SIZE))
+ self.data = '\n'.join(self.line_gen) + '\n'
# Make a source file with some lines
fp = open(TESTFN, "wb")
fp.write(self.data)
fp.close()
- def zipTest(self, f, compression):
+ def makeTestArchive(self, f, compression):
# Create the ZIP archive
zipfp = zipfile.ZipFile(f, "w", compression)
zipfp.write(TESTFN, "another"+os.extsep+"name")
@@ -31,6 +35,9 @@ class TestsWithSourceFile(unittest.TestCase):
zipfp.writestr("strfile", self.data)
zipfp.close()
+ def zipTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
# Read the ZIP archive
zipfp = zipfile.ZipFile(f, "r", compression)
self.assertEqual(zipfp.read(TESTFN), self.data)
@@ -85,22 +92,144 @@ class TestsWithSourceFile(unittest.TestCase):
# Check that testzip doesn't raise an exception
zipfp.testzip()
+ zipfp.close()
+ def testStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipTest(f, zipfile.ZIP_STORED)
+
+ def zipOpenTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r", compression)
+ zipdata1 = []
+ zipopen1 = zipfp.open(TESTFN)
+ while 1:
+ read_data = zipopen1.read(256)
+ if not read_data:
+ break
+ zipdata1.append(read_data)
+
+ zipdata2 = []
+ zipopen2 = zipfp.open("another"+os.extsep+"name")
+ while 1:
+ read_data = zipopen2.read(256)
+ if not read_data:
+ break
+ zipdata2.append(read_data)
+
+ self.assertEqual(''.join(zipdata1), self.data)
+ self.assertEqual(''.join(zipdata2), self.data)
zipfp.close()
+ def testOpenStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipOpenTest(f, zipfile.ZIP_STORED)
+ def zipRandomOpenTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r", compression)
+ zipdata1 = []
+ zipopen1 = zipfp.open(TESTFN)
+ while 1:
+ read_data = zipopen1.read(randint(1, 1024))
+ if not read_data:
+ break
+ zipdata1.append(read_data)
+
+ self.assertEqual(''.join(zipdata1), self.data)
+ zipfp.close()
- def testStored(self):
+ def testRandomOpenStored(self):
for f in (TESTFN2, TemporaryFile(), StringIO()):
- self.zipTest(f, zipfile.ZIP_STORED)
+ self.zipRandomOpenTest(f, zipfile.ZIP_STORED)
+
+ def zipReadlineTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ zipopen = zipfp.open(TESTFN)
+ for line in self.line_gen:
+ linedata = zipopen.readline()
+ self.assertEqual(linedata, line + '\n')
+
+ zipfp.close()
+
+ def zipReadlinesTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ ziplines = zipfp.open(TESTFN).readlines()
+ for line, zipline in zip(self.line_gen, ziplines):
+ self.assertEqual(zipline, line + '\n')
+
+ zipfp.close()
+
+ def zipIterlinesTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ for line, zipline in zip(self.line_gen, zipfp.open(TESTFN)):
+ self.assertEqual(zipline, line + '\n')
+
+ zipfp.close()
+
+ def testReadlineStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipReadlineTest(f, zipfile.ZIP_STORED)
+
+ def testReadlinesStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipReadlinesTest(f, zipfile.ZIP_STORED)
+
+ def testIterlinesStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipIterlinesTest(f, zipfile.ZIP_STORED)
if zlib:
def testDeflated(self):
for f in (TESTFN2, TemporaryFile(), StringIO()):
self.zipTest(f, zipfile.ZIP_DEFLATED)
+ def testOpenDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipOpenTest(f, zipfile.ZIP_DEFLATED)
+
+ def testRandomOpenDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipRandomOpenTest(f, zipfile.ZIP_DEFLATED)
+
+ def testReadlineDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipReadlineTest(f, zipfile.ZIP_DEFLATED)
+
+ def testReadlinesDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipReadlinesTest(f, zipfile.ZIP_DEFLATED)
+
+ def testIterlinesDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipIterlinesTest(f, zipfile.ZIP_DEFLATED)
+
+ def testLowCompression(self):
+ # Checks for cases where compressed data is larger than original
+ # Create the ZIP archive
+ zipfp = zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_DEFLATED)
+ zipfp.writestr("strfile", '12')
+ zipfp.close()
+
+ # Get an open object for strfile
+ zipfp = zipfile.ZipFile(TESTFN2, "r", zipfile.ZIP_DEFLATED)
+ openobj = zipfp.open("strfile")
+ self.assertEqual(openobj.read(1), '1')
+ self.assertEqual(openobj.read(1), '2')
+
def testAbsoluteArcnames(self):
zipfp = zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_STORED)
zipfp.write(TESTFN, "/absolute")
@@ -110,7 +239,6 @@ class TestsWithSourceFile(unittest.TestCase):
self.assertEqual(zipfp.namelist(), ["absolute"])
zipfp.close()
-
def tearDown(self):
os.remove(TESTFN)
os.remove(TESTFN2)
@@ -123,7 +251,7 @@ class TestZip64InSmallFiles(unittest.TestCase):
self._limit = zipfile.ZIP64_LIMIT
zipfile.ZIP64_LIMIT = 5
- line_gen = ("Test of zipfile line %d." % i for i in range(0, 1000))
+ line_gen = ("Test of zipfile line %d." % i for i in range(0, FIXEDTEST_SIZE))
self.data = '\n'.join(line_gen)
# Make a source file with some lines
@@ -310,10 +438,10 @@ class OtherTests(unittest.TestCase):
def testCreateNonExistentFileForAppend(self):
if os.path.exists(TESTFN):
os.unlink(TESTFN)
-
+
filename = 'testfile.txt'
content = 'hello, world. this is some content.'
-
+
try:
zf = zipfile.ZipFile(TESTFN, 'a')
zf.writestr(filename, content)
@@ -326,9 +454,7 @@ class OtherTests(unittest.TestCase):
zf = zipfile.ZipFile(TESTFN, 'r')
self.assertEqual(zf.read(filename), content)
zf.close()
-
- os.unlink(TESTFN)
-
+
def testCloseErroneousFile(self):
# This test checks that the ZipFile constructor closes the file object
# it opens if there's an error in the file. If it doesn't, the traceback
@@ -342,7 +468,25 @@ class OtherTests(unittest.TestCase):
try:
zf = zipfile.ZipFile(TESTFN)
except zipfile.BadZipfile:
- os.unlink(TESTFN)
+ pass
+
+ def testIsZipErroneousFile(self):
+ # This test checks that the is_zipfile function correctly identifies
+ # a file that is not a zip file
+ fp = open(TESTFN, "w")
+ fp.write("this is not a legal zip file\n")
+ fp.close()
+ chk = zipfile.is_zipfile(TESTFN)
+ self.assert_(chk is False)
+
+ def testIsZipValidFile(self):
+ # This test checks that the is_zipfile function correctly identifies
+ # a file that is a zip file
+ zipf = zipfile.ZipFile(TESTFN, mode="w")
+ zipf.writestr("foo.txt", "O, for a Muse of Fire!")
+ zipf.close()
+ chk = zipfile.is_zipfile(TESTFN)
+ self.assert_(chk is True)
def testNonExistentFileRaisesIOError(self):
# make sure we don't raise an AttributeError when a partially-constructed
@@ -371,6 +515,9 @@ class OtherTests(unittest.TestCase):
# and report that the first file in the archive was corrupt.
self.assertRaises(RuntimeError, zipf.testzip)
+ def tearDown(self):
+ support.unlink(TESTFN)
+ support.unlink(TESTFN2)
class DecryptionTests(unittest.TestCase):
# This test checks that ZIP decryption works. Since the library does not
@@ -406,15 +553,265 @@ class DecryptionTests(unittest.TestCase):
def testBadPassword(self):
self.zip.setpassword("perl")
self.assertRaises(RuntimeError, self.zip.read, "test.txt")
-
+
def testGoodPassword(self):
self.zip.setpassword("python")
self.assertEquals(self.zip.read("test.txt"), self.plain)
+
+class TestsWithRandomBinaryFiles(unittest.TestCase):
+ def setUp(self):
+ datacount = randint(16, 64)*1024 + randint(1, 1024)
+ self.data = ''.join((struct.pack('<f', random()*randint(-1000, 1000)) for i in xrange(datacount)))
+
+ # Make a source file with some lines
+ fp = open(TESTFN, "wb")
+ fp.write(self.data)
+ fp.close()
+
+ def tearDown(self):
+ support.unlink(TESTFN)
+ support.unlink(TESTFN2)
+
+ def makeTestArchive(self, f, compression):
+ # Create the ZIP archive
+ zipfp = zipfile.ZipFile(f, "w", compression)
+ zipfp.write(TESTFN, "another"+os.extsep+"name")
+ zipfp.write(TESTFN, TESTFN)
+ zipfp.close()
+
+ def zipTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r", compression)
+ testdata = zipfp.read(TESTFN)
+ self.assertEqual(len(testdata), len(self.data))
+ self.assertEqual(testdata, self.data)
+ self.assertEqual(zipfp.read("another"+os.extsep+"name"), self.data)
+ zipfp.close()
+
+ def testStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipTest(f, zipfile.ZIP_STORED)
+
+ def zipOpenTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r", compression)
+ zipdata1 = []
+ zipopen1 = zipfp.open(TESTFN)
+ while 1:
+ read_data = zipopen1.read(256)
+ if not read_data:
+ break
+ zipdata1.append(read_data)
+
+ zipdata2 = []
+ zipopen2 = zipfp.open("another"+os.extsep+"name")
+ while 1:
+ read_data = zipopen2.read(256)
+ if not read_data:
+ break
+ zipdata2.append(read_data)
+
+ testdata1 = ''.join(zipdata1)
+ self.assertEqual(len(testdata1), len(self.data))
+ self.assertEqual(testdata1, self.data)
+
+ testdata2 = ''.join(zipdata2)
+ self.assertEqual(len(testdata1), len(self.data))
+ self.assertEqual(testdata1, self.data)
+ zipfp.close()
+
+ def testOpenStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipOpenTest(f, zipfile.ZIP_STORED)
+
+ def zipRandomOpenTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r", compression)
+ zipdata1 = []
+ zipopen1 = zipfp.open(TESTFN)
+ while 1:
+ read_data = zipopen1.read(randint(1, 1024))
+ if not read_data:
+ break
+ zipdata1.append(read_data)
+
+ testdata = ''.join(zipdata1)
+ self.assertEqual(len(testdata), len(self.data))
+ self.assertEqual(testdata, self.data)
+ zipfp.close()
+
+ def testRandomOpenStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.zipRandomOpenTest(f, zipfile.ZIP_STORED)
+
+class TestsWithMultipleOpens(unittest.TestCase):
+ def setUp(self):
+ # Create the ZIP archive
+ zipfp = zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_DEFLATED)
+ zipfp.writestr('ones', '1'*FIXEDTEST_SIZE)
+ zipfp.writestr('twos', '2'*FIXEDTEST_SIZE)
+ zipfp.close()
+
+ def testSameFile(self):
+ # Verify that (when the ZipFile is in control of creating file objects)
+ # multiple open() calls can be made without interfering with each other.
+ zipf = zipfile.ZipFile(TESTFN2, mode="r")
+ zopen1 = zipf.open('ones')
+ zopen2 = zipf.open('ones')
+ data1 = zopen1.read(500)
+ data2 = zopen2.read(500)
+ data1 += zopen1.read(500)
+ data2 += zopen2.read(500)
+ self.assertEqual(data1, data2)
+ zipf.close()
+
+ def testDifferentFile(self):
+ # Verify that (when the ZipFile is in control of creating file objects)
+ # multiple open() calls can be made without interfering with each other.
+ zipf = zipfile.ZipFile(TESTFN2, mode="r")
+ zopen1 = zipf.open('ones')
+ zopen2 = zipf.open('twos')
+ data1 = zopen1.read(500)
+ data2 = zopen2.read(500)
+ data1 += zopen1.read(500)
+ data2 += zopen2.read(500)
+ self.assertEqual(data1, '1'*FIXEDTEST_SIZE)
+ self.assertEqual(data2, '2'*FIXEDTEST_SIZE)
+ zipf.close()
+
+ def testInterleaved(self):
+ # Verify that (when the ZipFile is in control of creating file objects)
+ # multiple open() calls can be made without interfering with each other.
+ zipf = zipfile.ZipFile(TESTFN2, mode="r")
+ zopen1 = zipf.open('ones')
+ data1 = zopen1.read(500)
+ zopen2 = zipf.open('twos')
+ data2 = zopen2.read(500)
+ data1 += zopen1.read(500)
+ data2 += zopen2.read(500)
+ self.assertEqual(data1, '1'*FIXEDTEST_SIZE)
+ self.assertEqual(data2, '2'*FIXEDTEST_SIZE)
+ zipf.close()
+
+ def tearDown(self):
+ os.remove(TESTFN2)
+
+
+class UniversalNewlineTests(unittest.TestCase):
+ def setUp(self):
+ self.line_gen = ["Test of zipfile line %d." % i for i in xrange(FIXEDTEST_SIZE)]
+ self.seps = ('\r', '\r\n', '\n')
+ self.arcdata, self.arcfiles = {}, {}
+ for n, s in enumerate(self.seps):
+ self.arcdata[s] = s.join(self.line_gen) + s
+ self.arcfiles[s] = '%s-%d' % (TESTFN, n)
+ file(self.arcfiles[s], "wb").write(self.arcdata[s])
+
+ def makeTestArchive(self, f, compression):
+ # Create the ZIP archive
+ zipfp = zipfile.ZipFile(f, "w", compression)
+ for fn in self.arcfiles.values():
+ zipfp.write(fn, fn)
+ zipfp.close()
+
+ def readTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ for sep, fn in self.arcfiles.items():
+ zipdata = zipfp.open(fn, "rU").read()
+ self.assertEqual(self.arcdata[sep], zipdata)
+
+ zipfp.close()
+
+ def readlineTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ for sep, fn in self.arcfiles.items():
+ zipopen = zipfp.open(fn, "rU")
+ for line in self.line_gen:
+ linedata = zipopen.readline()
+ self.assertEqual(linedata, line + '\n')
+
+ zipfp.close()
+
+ def readlinesTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ for sep, fn in self.arcfiles.items():
+ ziplines = zipfp.open(fn, "rU").readlines()
+ for line, zipline in zip(self.line_gen, ziplines):
+ self.assertEqual(zipline, line + '\n')
+
+ zipfp.close()
+
+ def iterlinesTest(self, f, compression):
+ self.makeTestArchive(f, compression)
+
+ # Read the ZIP archive
+ zipfp = zipfile.ZipFile(f, "r")
+ for sep, fn in self.arcfiles.items():
+ for line, zipline in zip(self.line_gen, zipfp.open(fn, "rU")):
+ self.assertEqual(zipline, line + '\n')
+
+ zipfp.close()
+
+ def testReadStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.readTest(f, zipfile.ZIP_STORED)
+
+ def testReadlineStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.readlineTest(f, zipfile.ZIP_STORED)
+
+ def testReadlinesStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.readlinesTest(f, zipfile.ZIP_STORED)
+
+ def testIterlinesStored(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.iterlinesTest(f, zipfile.ZIP_STORED)
+
+ if zlib:
+ def testReadDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.readTest(f, zipfile.ZIP_DEFLATED)
+
+ def testReadlineDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.readlineTest(f, zipfile.ZIP_DEFLATED)
+
+ def testReadlinesDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.readlinesTest(f, zipfile.ZIP_DEFLATED)
+
+ def testIterlinesDeflated(self):
+ for f in (TESTFN2, TemporaryFile(), StringIO()):
+ self.iterlinesTest(f, zipfile.ZIP_DEFLATED)
+
+ def tearDown(self):
+ for sep, fn in self.arcfiles.items():
+ os.remove(fn)
+ support.unlink(TESTFN)
+ support.unlink(TESTFN2)
+
+
def test_main():
- run_unittest(TestsWithSourceFile, TestZip64InSmallFiles, OtherTests,
- PyZipFileTests, DecryptionTests)
- #run_unittest(TestZip64InSmallFiles)
+ run_unittest(TestsWithSourceFile, TestZip64InSmallFiles, OtherTests,
+ PyZipFileTests, DecryptionTests, TestsWithMultipleOpens,
+ UniversalNewlineTests, TestsWithRandomBinaryFiles)
if __name__ == "__main__":
test_main()
diff --git a/Lib/test/test_zlib.py b/Lib/test/test_zlib.py
index 0be9668..dc4a7d8 100644
--- a/Lib/test/test_zlib.py
+++ b/Lib/test/test_zlib.py
@@ -3,22 +3,6 @@ from test import test_support
import zlib
import random
-# print test_support.TESTFN
-
-def getbuf():
- # This was in the original. Avoid non-repeatable sources.
- # Left here (unused) in case something wants to be done with it.
- import imp
- try:
- t = imp.find_module('test_zlib')
- file = t[0]
- except ImportError:
- file = open(__file__)
- buf = file.read() * 8
- file.close()
- return buf
-
-
class ChecksumTestCase(unittest.TestCase):
# checksum test cases
@@ -461,21 +445,3 @@ def test_main():
if __name__ == "__main__":
test_main()
-
-def test(tests=''):
- if not tests: tests = 'o'
- testcases = []
- if 'k' in tests: testcases.append(ChecksumTestCase)
- if 'x' in tests: testcases.append(ExceptionTestCase)
- if 'c' in tests: testcases.append(CompressTestCase)
- if 'o' in tests: testcases.append(CompressObjectTestCase)
- test_support.run_unittest(*testcases)
-
-if False:
- import sys
- sys.path.insert(1, '/Py23Src/python/dist/src/Lib/test')
- import test_zlib as tz
- ts, ut = tz.test_support, tz.unittest
- su = ut.TestSuite()
- su.addTest(ut.makeSuite(tz.CompressTestCase))
- ts.run_suite(su)
diff --git a/Lib/test/testtar.tar b/Lib/test/testtar.tar
index 1f4493f..c4c82b8 100644
--- a/Lib/test/testtar.tar
+++ b/Lib/test/testtar.tar
Binary files differ
diff --git a/Lib/test/warning_tests.py b/Lib/test/warning_tests.py
new file mode 100644
index 0000000..d0519ef
--- /dev/null
+++ b/Lib/test/warning_tests.py
@@ -0,0 +1,9 @@
+# Helper module for testing the skipmodules argument of warnings.warn()
+
+import warnings
+
+def outer(message, stacklevel=1):
+ inner(message, stacklevel)
+
+def inner(message, stacklevel=1):
+ warnings.warn(message, stacklevel=stacklevel)
diff --git a/Lib/textwrap.py b/Lib/textwrap.py
index 0917d75..e6e089f 100644
--- a/Lib/textwrap.py
+++ b/Lib/textwrap.py
@@ -63,6 +63,8 @@ class TextWrapper:
break_long_words (default: true)
Break words longer than 'width'. If false, those words will not
be broken, and some lines might be longer than 'width'.
+ drop_whitespace (default: true)
+ Drop leading and trailing whitespace from lines.
"""
whitespace_trans = string.maketrans(_whitespace, ' ' * len(_whitespace))
@@ -98,7 +100,8 @@ class TextWrapper:
expand_tabs=True,
replace_whitespace=True,
fix_sentence_endings=False,
- break_long_words=True):
+ break_long_words=True,
+ drop_whitespace=True):
self.width = width
self.initial_indent = initial_indent
self.subsequent_indent = subsequent_indent
@@ -106,6 +109,7 @@ class TextWrapper:
self.replace_whitespace = replace_whitespace
self.fix_sentence_endings = fix_sentence_endings
self.break_long_words = break_long_words
+ self.drop_whitespace = drop_whitespace
# -- Private methods -----------------------------------------------
@@ -140,7 +144,7 @@ class TextWrapper:
'use', ' ', 'the', ' ', '-b', ' ', 'option!'
"""
chunks = self.wordsep_re.split(text)
- chunks = filter(None, chunks)
+ chunks = filter(None, chunks) # remove empty chunks
return chunks
def _fix_sentence_endings(self, chunks):
@@ -228,7 +232,7 @@ class TextWrapper:
# First chunk on line is whitespace -- drop it, unless this
# is the very beginning of the text (ie. no lines started yet).
- if chunks[-1].strip() == '' and lines:
+ if self.drop_whitespace and chunks[-1].strip() == '' and lines:
del chunks[-1]
while chunks:
@@ -249,7 +253,7 @@ class TextWrapper:
self._handle_long_word(chunks, cur_line, cur_len, width)
# If the last chunk on this line is all whitespace, drop it.
- if cur_line and cur_line[-1].strip() == '':
+ if self.drop_whitespace and cur_line and cur_line[-1].strip() == '':
del cur_line[-1]
# Convert current line back to a string and store it in list
diff --git a/Lib/timeit.py b/Lib/timeit.py
index e760c62..a1a9b36 100644
--- a/Lib/timeit.py
+++ b/Lib/timeit.py
@@ -90,6 +90,17 @@ def reindent(src, indent):
"""Helper to reindent a multi-line statement."""
return src.replace("\n", "\n" + " "*indent)
+def _template_func(setup, func):
+ """Create a timer function. Used if the "statement" is a callable."""
+ def inner(_it, _timer):
+ setup()
+ _t0 = _timer()
+ for _i in _it:
+ func()
+ _t1 = _timer()
+ return _t1 - _t0
+ return inner
+
class Timer:
"""Class for timing execution speed of small code snippets.
@@ -109,14 +120,32 @@ class Timer:
def __init__(self, stmt="pass", setup="pass", timer=default_timer):
"""Constructor. See class doc string."""
self.timer = timer
- stmt = reindent(stmt, 8)
- setup = reindent(setup, 4)
- src = template % {'stmt': stmt, 'setup': setup}
- self.src = src # Save for traceback display
- code = compile(src, dummy_src_name, "exec")
ns = {}
- exec(code, globals(), ns)
- self.inner = ns["inner"]
+ if isinstance(stmt, basestring):
+ stmt = reindent(stmt, 8)
+ if isinstance(setup, basestring):
+ setup = reindent(setup, 4)
+ src = template % {'stmt': stmt, 'setup': setup}
+ elif callable(setup):
+ src = template % {'stmt': stmt, 'setup': '_setup()'}
+ ns['_setup'] = setup
+ else:
+ raise ValueError("setup is neither a string nor callable")
+ self.src = src # Save for traceback display
+ code = compile(src, dummy_src_name, "exec")
+ exec(code, globals(), ns)
+ self.inner = ns["inner"]
+ elif callable(stmt):
+ self.src = None
+ if isinstance(setup, basestring):
+ _setup = setup
+ def setup():
+ exec(_setup, globals(), ns)
+ elif not callable(setup):
+ raise ValueError("setup is neither a string nor callable")
+ self.inner = _template_func(setup, stmt)
+ else:
+ raise ValueError("stmt is neither a string nor callable")
def print_exc(self, file=None):
"""Helper to print a traceback from the timed code.
@@ -136,10 +165,13 @@ class Timer:
sent; it defaults to sys.stderr.
"""
import linecache, traceback
- linecache.cache[dummy_src_name] = (len(self.src),
- None,
- self.src.split("\n"),
- dummy_src_name)
+ if self.src is not None:
+ linecache.cache[dummy_src_name] = (len(self.src),
+ None,
+ self.src.split("\n"),
+ dummy_src_name)
+ # else the source is already stored somewhere else
+
traceback.print_exc(file=file)
def timeit(self, number=default_number):
@@ -189,6 +221,16 @@ class Timer:
r.append(t)
return r
+def timeit(stmt="pass", setup="pass", timer=default_timer,
+ number=default_number):
+ """Convenience function to create Timer object and call timeit method."""
+ return Timer(stmt, setup, timer).timeit(number)
+
+def repeat(stmt="pass", setup="pass", timer=default_timer,
+ repeat=default_repeat, number=default_number):
+ """Convenience function to create Timer object and call repeat method."""
+ return Timer(stmt, setup, timer).repeat(repeat, number)
+
def main(args=None):
"""Main program, used when run as a script.
diff --git a/Lib/unittest.py b/Lib/unittest.py
index eab0372..12017dd 100644
--- a/Lib/unittest.py
+++ b/Lib/unittest.py
@@ -25,7 +25,7 @@ Simple usage:
Further information is available in the bundled documentation, and from
- http://pyunit.sourceforge.net/
+ http://docs.python.org/lib/module-unittest.html
Copyright (c) 1999-2003 Steve Purcell
This module is free software, and you may redistribute it and/or modify
@@ -104,7 +104,7 @@ class TestResult:
self.failures = []
self.errors = []
self.testsRun = 0
- self.shouldStop = 0
+ self.shouldStop = False
def startTest(self, test):
"Called when the given test is about to be run"
@@ -232,6 +232,18 @@ class TestCase:
def id(self):
return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return False
+
+ return self._testMethodName == other._testMethodName
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash((type(self), self._testMethodName))
+
def __str__(self):
return "%s (%s)" % (self._testMethodName, _strclass(self.__class__))
@@ -288,10 +300,7 @@ class TestCase:
minimised; usually the top level of the traceback frame is not
needed.
"""
- exctype, excvalue, tb = sys.exc_info()
- if sys.platform[:4] == 'java': ## tracebacks look different in Jython
- return (exctype, excvalue, tb)
- return (exctype, excvalue, tb)
+ return sys.exc_info()
def fail(self, msg=None):
"""Fail immediately, with the given message."""
@@ -398,6 +407,14 @@ class TestSuite:
__str__ = __repr__
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return False
+ return self._tests == other._tests
+
+ def __ne__(self, other):
+ return not self == other
+
def __iter__(self):
return iter(self._tests)
@@ -408,9 +425,18 @@ class TestSuite:
return cases
def addTest(self, test):
+ # sanity checks
+ if not callable(test):
+ raise TypeError("the test to add must be callable")
+ if (isinstance(test, (type, types.ClassType)) and
+ issubclass(test, (TestCase, TestSuite))):
+ raise TypeError("TestCases and TestSuites must be instantiated "
+ "before passing them to addTest()")
self._tests.append(test)
def addTests(self, tests):
+ if isinstance(tests, basestring):
+ raise TypeError("tests must be an iterable of tests, not a string")
for test in tests:
self.addTest(test)
@@ -433,7 +459,7 @@ class FunctionTestCase(TestCase):
"""A test case that wraps a test function.
This is useful for slipping pre-existing test functions into the
- PyUnit framework. Optionally, set-up and tidy-up functions can be
+ unittest framework. Optionally, set-up and tidy-up functions can be
supplied. As with TestCase, the tidy-up ('tearDown') function will
always be called if the set-up ('setUp') function ran successfully.
"""
@@ -460,6 +486,22 @@ class FunctionTestCase(TestCase):
def id(self):
return self.__testFunc.__name__
+ def __eq__(self, other):
+ if type(self) is not type(other):
+ return False
+
+ return self.__setUpFunc == other.__setUpFunc and \
+ self.__tearDownFunc == other.__tearDownFunc and \
+ self.__testFunc == other.__testFunc and \
+ self.__description == other.__description
+
+ def __ne__(self, other):
+ return not self == other
+
+ def __hash__(self):
+ return hash((type(self), self.__setUpFunc, self.__tearDownFunc,
+ self.__testFunc, self.__description))
+
def __str__(self):
return "%s (%s)" % (_strclass(self.__class__), self.__testFunc.__name__)
@@ -479,7 +521,7 @@ class FunctionTestCase(TestCase):
class TestLoader:
"""This class is responsible for loading tests according to various
- criteria and returning them wrapped in a Test
+ criteria and returning them wrapped in a TestSuite
"""
testMethodPrefix = 'test'
sortTestMethodsUsing = cmp
@@ -533,18 +575,23 @@ class TestLoader:
elif (isinstance(obj, (type, types.ClassType)) and
issubclass(obj, TestCase)):
return self.loadTestsFromTestCase(obj)
- elif type(obj) == types.UnboundMethodType:
- return parent(obj.__name__)
+ elif (type(obj) == types.UnboundMethodType and
+ isinstance(parent, (type, types.ClassType)) and
+ issubclass(parent, TestCase)):
+ return TestSuite([parent(obj.__name__)])
elif isinstance(obj, TestSuite):
return obj
elif callable(obj):
test = obj()
- if not isinstance(test, (TestCase, TestSuite)):
- raise ValueError, \
- "calling %s returned %s, not a test" % (obj,test)
- return test
+ if isinstance(test, TestSuite):
+ return test
+ elif isinstance(test, TestCase):
+ return TestSuite([test])
+ else:
+ raise TypeError("calling %s returned %s, not a test" %
+ (obj, test))
else:
- raise ValueError, "don't know how to make test from: %s" % obj
+ raise TypeError("don't know how to make test from: %s" % obj)
def loadTestsFromNames(self, names, module=None):
"""Return a suite of all tests cases found using the given sequence
@@ -559,10 +606,6 @@ class TestLoader:
def isTestMethod(attrname, testCaseClass=testCaseClass, prefix=self.testMethodPrefix):
return attrname.startswith(prefix) and callable(getattr(testCaseClass, attrname))
testFnNames = filter(isTestMethod, dir(testCaseClass))
- for baseclass in testCaseClass.__bases__:
- for testFnName in self.getTestCaseNames(baseclass):
- if testFnName not in testFnNames: # handle overridden methods
- testFnNames.append(testFnName)
if self.sortTestMethodsUsing:
testFnNames.sort(self.sortTestMethodsUsing)
return testFnNames
@@ -738,7 +781,8 @@ Examples:
in MyTestCase
"""
def __init__(self, module='__main__', defaultTest=None,
- argv=None, testRunner=None, testLoader=defaultTestLoader):
+ argv=None, testRunner=TextTestRunner,
+ testLoader=defaultTestLoader):
if type(module) == type(''):
self.module = __import__(module)
for part in module.split('.')[1:]:
@@ -788,9 +832,16 @@ Examples:
self.module)
def runTests(self):
- if self.testRunner is None:
- self.testRunner = TextTestRunner(verbosity=self.verbosity)
- result = self.testRunner.run(self.test)
+ if isinstance(self.testRunner, (type, types.ClassType)):
+ try:
+ testRunner = self.testRunner(verbosity=self.verbosity)
+ except TypeError:
+ # didn't accept the verbosity argument
+ testRunner = self.testRunner()
+ else:
+ # it is assumed to be a TestRunner instance
+ testRunner = self.testRunner
+ result = testRunner.run(self.test)
sys.exit(not result.wasSuccessful())
main = TestProgram
diff --git a/Lib/urllib.py b/Lib/urllib.py
index 4e24a8f..b83b574 100644
--- a/Lib/urllib.py
+++ b/Lib/urllib.py
@@ -326,6 +326,11 @@ class URLopener:
h.send(data)
errcode, errmsg, headers = h.getreply()
fp = h.getfile()
+ if errcode == -1:
+ if fp: fp.close()
+ # something went wrong with the HTTP status line
+ raise IOError, ('http protocol error', 0,
+ 'got a bad status line', None)
if errcode == 200:
return addinfourl(fp, headers, "http:" + url)
else:
@@ -413,6 +418,11 @@ class URLopener:
h.send(data)
errcode, errmsg, headers = h.getreply()
fp = h.getfile()
+ if errcode == -1:
+ if fp: fp.close()
+ # something went wrong with the HTTP status line
+ raise IOError, ('http protocol error', 0,
+ 'got a bad status line', None)
if errcode == 200:
return addinfourl(fp, headers, "https:" + url)
else:
@@ -1470,7 +1480,7 @@ def test(args=[]):
'/etc/passwd',
'file:/etc/passwd',
'file://localhost/etc/passwd',
- 'ftp://ftp.python.org/pub/python/README',
+ 'ftp://ftp.gnu.org/pub/README',
## 'gopher://gopher.micro.umn.edu/1/',
'http://www.python.org/index.html',
]
diff --git a/Lib/urllib2.py b/Lib/urllib2.py
index 60ff260..a0be039 100644
--- a/Lib/urllib2.py
+++ b/Lib/urllib2.py
@@ -14,36 +14,36 @@ non-error returns. The HTTPRedirectHandler automatically deals with
HTTP 301, 302, 303 and 307 redirect errors, and the HTTPDigestAuthHandler
deals with digest authentication.
-urlopen(url, data=None) -- basic usage is the same as original
+urlopen(url, data=None) -- Basic usage is the same as original
urllib. pass the url and optionally data to post to an HTTP URL, and
get a file-like object back. One difference is that you can also pass
a Request instance instead of URL. Raises a URLError (subclass of
IOError); for HTTP errors, raises an HTTPError, which can also be
treated as a valid response.
-build_opener -- function that creates a new OpenerDirector instance.
-will install the default handlers. accepts one or more Handlers as
+build_opener -- Function that creates a new OpenerDirector instance.
+Will install the default handlers. Accepts one or more Handlers as
arguments, either instances or Handler classes that it will
-instantiate. if one of the argument is a subclass of the default
+instantiate. If one of the argument is a subclass of the default
handler, the argument will be installed instead of the default.
-install_opener -- installs a new opener as the default opener.
+install_opener -- Installs a new opener as the default opener.
objects of interest:
OpenerDirector --
-Request -- an object that encapsulates the state of a request. the
-state can be a simple as the URL. it can also include extra HTTP
+Request -- An object that encapsulates the state of a request. The
+state can be as simple as the URL. It can also include extra HTTP
headers, e.g. a User-Agent.
BaseHandler --
exceptions:
-URLError-- a subclass of IOError, individual protocols have their own
-specific subclass
+URLError -- A subclass of IOError, individual protocols have their own
+specific subclass.
-HTTPError-- also a valid HTTP response, so you can treat an HTTP error
-as an exceptional event or valid response
+HTTPError -- Also a valid HTTP response, so you can treat an HTTP error
+as an exceptional event or valid response.
internals:
BaseHandler and parent
@@ -55,7 +55,10 @@ import urllib2
# set up authentication info
authinfo = urllib2.HTTPBasicAuthHandler()
-authinfo.add_password('realm', 'host', 'username', 'password')
+authinfo.add_password(realm='PDQ Application',
+ uri='https://mahler:8092/site-updates.py',
+ user='klem',
+ passwd='geheim$parole')
proxy_support = urllib2.ProxyHandler({"http" : "http://ahad-haam:3128"})
@@ -334,7 +337,8 @@ class OpenerDirector:
added = True
if added:
- # XXX why does self.handlers need to be sorted?
+ # the handlers must work in an specific order, the order
+ # is specified in a Handler attribute
bisect.insort(self.handlers, handler)
handler.add_parent(self)
@@ -486,7 +490,9 @@ class HTTPErrorProcessor(BaseHandler):
def http_response(self, request, response):
code, msg, hdrs = response.code, response.msg, response.info()
- if code not in (200, 206):
+ # According to RFC 2616, "2xx" code indicates that the client's
+ # request was successfully received, understood, and accepted.
+ if not (200 <= code < 300):
response = self.parent.error(
'http', request, response, code, msg, hdrs)
@@ -766,11 +772,10 @@ class HTTPPasswordMgrWithDefaultRealm(HTTPPasswordMgr):
class AbstractBasicAuthHandler:
- rx = re.compile('[ \t]*([^ \t]+)[ \t]+realm="([^"]*)"', re.I)
+ # XXX this allows for multiple auth-schemes, but will stupidly pick
+ # the last one with a realm specified.
- # XXX there can actually be multiple auth-schemes in a
- # www-authenticate header. should probably be a lot more careful
- # in parsing them to extract multiple alternatives
+ rx = re.compile('(?:.*,)*[ \t]*([^ \t]+)[ \t]+realm="([^"]*)"', re.I)
# XXX could pre-emptively send auth info already accepted (RFC 2617,
# end of section 2, and section 1.2 immediately after "credentials"
@@ -1214,19 +1219,23 @@ class FileHandler(BaseHandler):
host = req.get_host()
file = req.get_selector()
localfile = url2pathname(file)
- stats = os.stat(localfile)
- size = stats.st_size
- modified = email.utils.formatdate(stats.st_mtime, usegmt=True)
- mtype = mimetypes.guess_type(file)[0]
- headers = mimetools.Message(StringIO(
- 'Content-type: %s\nContent-length: %d\nLast-modified: %s\n' %
- (mtype or 'text/plain', size, modified)))
- if host:
- host, port = splitport(host)
- if not host or \
- (not port and socket.gethostbyname(host) in self.get_names()):
- return addinfourl(open(localfile, 'rb'),
- headers, 'file:'+file)
+ try:
+ stats = os.stat(localfile)
+ size = stats.st_size
+ modified = email.utils.formatdate(stats.st_mtime, usegmt=True)
+ mtype = mimetypes.guess_type(file)[0]
+ headers = mimetools.Message(StringIO(
+ 'Content-type: %s\nContent-length: %d\nLast-modified: %s\n' %
+ (mtype or 'text/plain', size, modified)))
+ if host:
+ host, port = splitport(host)
+ if not host or \
+ (not port and socket.gethostbyname(host) in self.get_names()):
+ return addinfourl(open(localfile, 'rb'),
+ headers, 'file:'+file)
+ except OSError as msg:
+ # urllib2 users shouldn't expect OSErrors coming from urlopen()
+ raise URLError(msg)
raise URLError('file not on local host')
class FTPHandler(BaseHandler):
diff --git a/Lib/wave.py b/Lib/wave.py
index dd5a47a..81a7141 100644
--- a/Lib/wave.py
+++ b/Lib/wave.py
@@ -159,7 +159,12 @@ class Wave_read:
f = __builtin__.open(f, 'rb')
self._i_opened_the_file = f
# else, assume it is an open file object already
- self.initfp(f)
+ try:
+ self.initfp(f)
+ except:
+ if self._i_opened_the_file:
+ f.close()
+ raise
def __del__(self):
self.close()
@@ -297,7 +302,12 @@ class Wave_write:
if isinstance(f, basestring):
f = __builtin__.open(f, 'wb')
self._i_opened_the_file = f
- self.initfp(f)
+ try:
+ self.initfp(f)
+ except:
+ if self._i_opened_the_file:
+ f.close()
+ raise
def initfp(self, file):
self._file = file
diff --git a/Lib/webbrowser.py b/Lib/webbrowser.py
index 55aa04c..dd5e019 100644
--- a/Lib/webbrowser.py
+++ b/Lib/webbrowser.py
@@ -2,6 +2,7 @@
"""Interfaces for launching and remotely controlling Web browsers."""
import os
+import shlex
import sys
import stat
import subprocess
@@ -32,7 +33,11 @@ def get(using=None):
for browser in alternatives:
if '%s' in browser:
# User gave us a command line, split it into name and args
- return GenericBrowser(browser.split())
+ browser = shlex.split(browser)
+ if browser[-1] == '&':
+ return BackgroundBrowser(browser[:-1])
+ else:
+ return GenericBrowser(browser)
else:
# User gave us a browser name or path.
try:
@@ -437,19 +442,16 @@ class Grail(BaseBrowser):
# a console terminal or an X display to run.
def register_X_browsers():
- # The default Gnome browser
- if _iscommand("gconftool-2"):
- # get the web browser string from gconftool
- gc = 'gconftool-2 -g /desktop/gnome/url-handlers/http/command 2>/dev/null'
- out = os.popen(gc)
- commd = out.read().strip()
- retncode = out.close()
-
- # if successful, register it
- if retncode is None and commd:
- register("gnome", None, BackgroundBrowser(commd.split()))
-
- # First, the Mozilla/Netscape browsers
+
+ # The default GNOME browser
+ if "GNOME_DESKTOP_SESSION_ID" in os.environ and _iscommand("gnome-open"):
+ register("gnome-open", None, BackgroundBrowser("gnome-open"))
+
+ # The default KDE browser
+ if "KDE_FULL_SESSION" in os.environ and _iscommand("kfmclient"):
+ register("kfmclient", Konqueror, Konqueror("kfmclient"))
+
+ # The Mozilla/Netscape browsers
for browser in ("mozilla-firefox", "firefox",
"mozilla-firebird", "firebird",
"seamonkey", "mozilla", "netscape"):
@@ -508,17 +510,28 @@ if os.environ.get("TERM"):
if sys.platform[:3] == "win":
class WindowsDefault(BaseBrowser):
def open(self, url, new=0, autoraise=1):
- os.startfile(url)
- return True # Oh, my...
+ try:
+ os.startfile(url)
+ except WindowsError:
+ # [Error 22] No application is associated with the specified
+ # file for this operation: '<URL>'
+ return False
+ else:
+ return True
_tryorder = []
_browsers = {}
- # Prefer mozilla/netscape/opera if present
+
+ # First try to use the default Windows browser
+ register("windows-default", WindowsDefault)
+
+ # Detect some common Windows browsers, fallback to IE
+ iexplore = os.path.join(os.environ.get("PROGRAMFILES", "C:\\Program Files"),
+ "Internet Explorer\\IEXPLORE.EXE")
for browser in ("firefox", "firebird", "seamonkey", "mozilla",
- "netscape", "opera"):
+ "netscape", "opera", iexplore):
if _iscommand(browser):
register(browser, None, BackgroundBrowser(browser))
- register("windows-default", WindowsDefault)
#
# Platform support for MacOS
diff --git a/Lib/zipfile.py b/Lib/zipfile.py
index d0a1f65..fa7e910 100644
--- a/Lib/zipfile.py
+++ b/Lib/zipfile.py
@@ -361,6 +361,200 @@ class _ZipDecrypter:
self._UpdateKeys(c)
return c
+class ZipExtFile:
+ """File-like object for reading an archive member.
+ Is returned by ZipFile.open().
+ """
+
+ def __init__(self, fileobj, zipinfo, decrypt=None):
+ self.fileobj = fileobj
+ self.decrypter = decrypt
+ self.bytes_read = 0
+ self.rawbuffer = ''
+ self.readbuffer = ''
+ self.linebuffer = ''
+ self.eof = False
+ self.univ_newlines = False
+ self.nlSeps = ("\n", )
+ self.lastdiscard = ''
+
+ self.compress_type = zipinfo.compress_type
+ self.compress_size = zipinfo.compress_size
+
+ self.closed = False
+ self.mode = "r"
+ self.name = zipinfo.filename
+
+ # read from compressed files in 64k blocks
+ self.compreadsize = 64*1024
+ if self.compress_type == ZIP_DEFLATED:
+ self.dc = zlib.decompressobj(-15)
+
+ def set_univ_newlines(self, univ_newlines):
+ self.univ_newlines = univ_newlines
+
+ # pick line separator char(s) based on universal newlines flag
+ self.nlSeps = ("\n", )
+ if self.univ_newlines:
+ self.nlSeps = ("\r\n", "\r", "\n")
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ nextline = self.readline()
+ if not nextline:
+ raise StopIteration()
+
+ return nextline
+
+ def close(self):
+ self.closed = True
+
+ def _checkfornewline(self):
+ nl, nllen = -1, -1
+ if self.linebuffer:
+ # ugly check for cases where half of an \r\n pair was
+ # read on the last pass, and the \r was discarded. In this
+ # case we just throw away the \n at the start of the buffer.
+ if (self.lastdiscard, self.linebuffer[0]) == ('\r','\n'):
+ self.linebuffer = self.linebuffer[1:]
+
+ for sep in self.nlSeps:
+ nl = self.linebuffer.find(sep)
+ if nl >= 0:
+ nllen = len(sep)
+ return nl, nllen
+
+ return nl, nllen
+
+ def readline(self, size = -1):
+ """Read a line with approx. size. If size is negative,
+ read a whole line.
+ """
+ if size < 0:
+ size = sys.maxint
+ elif size == 0:
+ return ''
+
+ # check for a newline already in buffer
+ nl, nllen = self._checkfornewline()
+
+ if nl >= 0:
+ # the next line was already in the buffer
+ nl = min(nl, size)
+ else:
+ # no line break in buffer - try to read more
+ size -= len(self.linebuffer)
+ while nl < 0 and size > 0:
+ buf = self.read(min(size, 100))
+ if not buf:
+ break
+ self.linebuffer += buf
+ size -= len(buf)
+
+ # check for a newline in buffer
+ nl, nllen = self._checkfornewline()
+
+ # we either ran out of bytes in the file, or
+ # met the specified size limit without finding a newline,
+ # so return current buffer
+ if nl < 0:
+ s = self.linebuffer
+ self.linebuffer = ''
+ return s
+
+ buf = self.linebuffer[:nl]
+ self.lastdiscard = self.linebuffer[nl:nl + nllen]
+ self.linebuffer = self.linebuffer[nl + nllen:]
+
+ # line is always returned with \n as newline char (except possibly
+ # for a final incomplete line in the file, which is handled above).
+ return buf + "\n"
+
+ def readlines(self, sizehint = -1):
+ """Return a list with all (following) lines. The sizehint parameter
+ is ignored in this implementation.
+ """
+ result = []
+ while True:
+ line = self.readline()
+ if not line: break
+ result.append(line)
+ return result
+
+ def read(self, size = None):
+ # act like file() obj and return empty string if size is 0
+ if size == 0:
+ return ''
+
+ # determine read size
+ bytesToRead = self.compress_size - self.bytes_read
+
+ # adjust read size for encrypted files since the first 12 bytes
+ # are for the encryption/password information
+ if self.decrypter is not None:
+ bytesToRead -= 12
+
+ if size is not None and size >= 0:
+ if self.compress_type == ZIP_STORED:
+ lr = len(self.readbuffer)
+ bytesToRead = min(bytesToRead, size - lr)
+ elif self.compress_type == ZIP_DEFLATED:
+ if len(self.readbuffer) > size:
+ # the user has requested fewer bytes than we've already
+ # pulled through the decompressor; don't read any more
+ bytesToRead = 0
+ else:
+ # user will use up the buffer, so read some more
+ lr = len(self.rawbuffer)
+ bytesToRead = min(bytesToRead, self.compreadsize - lr)
+
+ # avoid reading past end of file contents
+ if bytesToRead + self.bytes_read > self.compress_size:
+ bytesToRead = self.compress_size - self.bytes_read
+
+ # try to read from file (if necessary)
+ if bytesToRead > 0:
+ bytes = self.fileobj.read(bytesToRead)
+ self.bytes_read += len(bytes)
+ self.rawbuffer += bytes
+
+ # handle contents of raw buffer
+ if self.rawbuffer:
+ newdata = self.rawbuffer
+ self.rawbuffer = ''
+
+ # decrypt new data if we were given an object to handle that
+ if newdata and self.decrypter is not None:
+ newdata = ''.join(map(self.decrypter, newdata))
+
+ # decompress newly read data if necessary
+ if newdata and self.compress_type == ZIP_DEFLATED:
+ newdata = self.dc.decompress(newdata)
+ self.rawbuffer = self.dc.unconsumed_tail
+ if self.eof and len(self.rawbuffer) == 0:
+ # we're out of raw bytes (both from the file and
+ # the local buffer); flush just to make sure the
+ # decompressor is done
+ newdata += self.dc.flush()
+ # prevent decompressor from being used again
+ self.dc = None
+
+ self.readbuffer += newdata
+
+
+ # return what the user asked for
+ if size is None or len(self.readbuffer) <= size:
+ bytes = self.readbuffer
+ self.readbuffer = ''
+ else:
+ bytes = self.readbuffer[:size]
+ self.readbuffer = self.readbuffer[size:]
+
+ return bytes
+
+
class ZipFile:
""" Class with methods to open, read, write, close, list zip files.
@@ -540,73 +734,75 @@ class ZipFile:
def read(self, name, pwd=None):
"""Return file bytes (as a string) for name."""
- if self.mode not in ("r", "a"):
- raise RuntimeError, 'read() requires mode "r" or "a"'
+ return self.open(name, "r", pwd).read()
+
+ def open(self, name, mode="r", pwd=None):
+ """Return file-like object for 'name'."""
+ if mode not in ("r", "U", "rU"):
+ raise RuntimeError, 'open() requires mode "r", "U", or "rU"'
if not self.fp:
raise RuntimeError, \
"Attempt to read ZIP archive that was already closed"
+
+ # Only open a new file for instances where we were not
+ # given a file object in the constructor
+ if self._filePassed:
+ zef_file = self.fp
+ else:
+ zef_file = open(self.filename, 'rb')
+
+ # Get info object for name
zinfo = self.getinfo(name)
- is_encrypted = zinfo.flag_bits & 0x1
- if is_encrypted:
- if not pwd:
- pwd = self.pwd
- if not pwd:
- raise RuntimeError, "File %s is encrypted, " \
- "password required for extraction" % name
- filepos = self.fp.tell()
- self.fp.seek(zinfo.header_offset, 0)
+ filepos = zef_file.tell()
+
+ zef_file.seek(zinfo.header_offset, 0)
# Skip the file header:
- fheader = self.fp.read(30)
+ fheader = zef_file.read(30)
if fheader[0:4] != stringFileHeader:
raise BadZipfile, "Bad magic number for file header"
fheader = struct.unpack(structFileHeader, fheader)
- fname = self.fp.read(fheader[_FH_FILENAME_LENGTH])
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
if fheader[_FH_EXTRA_FIELD_LENGTH]:
- self.fp.read(fheader[_FH_EXTRA_FIELD_LENGTH])
+ zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH])
if fname != zinfo.orig_filename:
raise BadZipfile, \
'File name in directory "%s" and header "%s" differ.' % (
zinfo.orig_filename, fname)
- bytes = self.fp.read(zinfo.compress_size)
- # Go with decryption
+ # check for encrypted flag & handle password
+ is_encrypted = zinfo.flag_bits & 0x1
+ zd = None
if is_encrypted:
+ if not pwd:
+ pwd = self.pwd
+ if not pwd:
+ raise RuntimeError, "File %s is encrypted, " \
+ "password required for extraction" % name
+
zd = _ZipDecrypter(pwd)
# The first 12 bytes in the cypher stream is an encryption header
# used to strengthen the algorithm. The first 11 bytes are
# completely random, while the 12th contains the MSB of the CRC,
# and is used to check the correctness of the password.
+ bytes = zef_file.read(12)
h = map(zd, bytes[0:12])
if ord(h[11]) != ((zinfo.CRC>>24)&255):
raise RuntimeError, "Bad password for file %s" % name
- bytes = "".join(map(zd, bytes[12:]))
- # Go with decompression
- self.fp.seek(filepos, 0)
- if zinfo.compress_type == ZIP_STORED:
- pass
- elif zinfo.compress_type == ZIP_DEFLATED:
- if not zlib:
- raise RuntimeError, \
- "De-compression requires the (missing) zlib module"
- # zlib compress/decompress code by Jeremy Hylton of CNRI
- dc = zlib.decompressobj(-15)
- bytes = dc.decompress(bytes)
- # need to feed in unused pad byte so that zlib won't choke
- ex = dc.decompress('Z') + dc.flush()
- if ex:
- bytes = bytes + ex
+
+ # build and return a ZipExtFile
+ if zd is None:
+ zef = ZipExtFile(zef_file, zinfo)
else:
- raise BadZipfile, \
- "Unsupported compression method %d for file %s" % \
- (zinfo.compress_type, name)
- crc = binascii.crc32(bytes)
- if crc != zinfo.CRC:
- raise BadZipfile, "Bad CRC-32 for file %s" % name
- return bytes
+ zef = ZipExtFile(zef_file, zinfo, zd)
+
+ # set universal newlines on ZipExtFile if necessary
+ if "U" in mode:
+ zef.set_univ_newlines(True)
+ return zef
def _writecheck(self, zinfo):
"""Check for errors before writing a file to the archive."""