summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorBrett Cannon <brett@python.org>2012-06-15 23:04:29 (GMT)
committerBrett Cannon <brett@python.org>2012-06-15 23:04:29 (GMT)
commit24aa693c7ef8f217fbd238eb7af7d828e13a07eb (patch)
tree7d01ab630c2e8eef1e168b1aa5d84131b60cfd50 /Lib
parent99d776fdf4aa5a66266ebcec2263fab501f03088 (diff)
parent016ef551a793f72f582d707ce5bb55bf4940cf27 (diff)
downloadcpython-24aa693c7ef8f217fbd238eb7af7d828e13a07eb.zip
cpython-24aa693c7ef8f217fbd238eb7af7d828e13a07eb.tar.gz
cpython-24aa693c7ef8f217fbd238eb7af7d828e13a07eb.tar.bz2
Merge
Diffstat (limited to 'Lib')
-rw-r--r--Lib/_strptime.py6
-rw-r--r--Lib/datetime.py6
-rw-r--r--Lib/hmac.py28
-rw-r--r--Lib/idlelib/AutoComplete.py2
-rw-r--r--Lib/mailbox.py1
-rw-r--r--Lib/multiprocessing/__init__.py11
-rw-r--r--Lib/multiprocessing/dummy/__init__.py4
-rw-r--r--Lib/multiprocessing/forking.py6
-rw-r--r--Lib/multiprocessing/managers.py94
-rw-r--r--Lib/multiprocessing/synchronize.py40
-rw-r--r--Lib/multiprocessing/util.py27
-rw-r--r--Lib/test/support.py2
-rw-r--r--Lib/test/test_hmac.py44
-rw-r--r--Lib/test/test_mailbox.py11
-rw-r--r--Lib/test/test_multiprocessing.py355
-rw-r--r--Lib/test/test_structseq.py5
-rw-r--r--Lib/test/test_time.py65
-rw-r--r--Lib/test/test_xml_etree.py226
-rw-r--r--Lib/test/test_xml_etree_c.py28
-rw-r--r--Lib/xml/etree/ElementTree.py32
20 files changed, 708 insertions, 285 deletions
diff --git a/Lib/_strptime.py b/Lib/_strptime.py
index fa06376..b0cd3d6 100644
--- a/Lib/_strptime.py
+++ b/Lib/_strptime.py
@@ -486,19 +486,19 @@ def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"):
return (year, month, day,
hour, minute, second,
- weekday, julian, tz, gmtoff, tzname), fraction
+ weekday, julian, tz, tzname, gmtoff), fraction
def _strptime_time(data_string, format="%a %b %d %H:%M:%S %Y"):
"""Return a time struct based on the input string and the
format string."""
tt = _strptime(data_string, format)[0]
- return time.struct_time(tt[:9])
+ return time.struct_time(tt[:time._STRUCT_TM_ITEMS])
def _strptime_datetime(cls, data_string, format="%a %b %d %H:%M:%S %Y"):
"""Return a class cls instance based on the input string and the
format string."""
tt, fraction = _strptime(data_string, format)
- gmtoff, tzname = tt[-2:]
+ tzname, gmtoff = tt[-2:]
args = tt[:6] + (fraction,)
if gmtoff is not None:
tzdelta = datetime_timedelta(seconds=gmtoff)
diff --git a/Lib/datetime.py b/Lib/datetime.py
index 5d8d9b3..21aab35 100644
--- a/Lib/datetime.py
+++ b/Lib/datetime.py
@@ -1670,10 +1670,8 @@ class datetime(date):
if mytz is ottz:
base_compare = True
else:
- if mytz is not None:
- myoff = self.utcoffset()
- if ottz is not None:
- otoff = other.utcoffset()
+ myoff = self.utcoffset()
+ otoff = other.utcoffset()
base_compare = myoff == otoff
if base_compare:
diff --git a/Lib/hmac.py b/Lib/hmac.py
index 13ffdbe..e47965b 100644
--- a/Lib/hmac.py
+++ b/Lib/hmac.py
@@ -13,24 +13,24 @@ trans_36 = bytes((x ^ 0x36) for x in range(256))
digest_size = None
-def secure_compare(a, b):
- """Returns the equivalent of 'a == b', but using a time-independent
- comparison method to prevent timing attacks."""
- if not ((isinstance(a, str) and isinstance(b, str)) or
- (isinstance(a, bytes) and isinstance(b, bytes))):
- raise TypeError("inputs must be strings or bytes")
-
+def compare_digest(a, b):
+ """Returns the equivalent of 'a == b', but avoids content based short
+ circuiting to reduce the vulnerability to timing attacks."""
+ # Consistent timing matters more here than data type flexibility
+ if not (isinstance(a, bytes) and isinstance(b, bytes)):
+ raise TypeError("inputs must be bytes instances")
+
+ # We assume the length of the expected digest is public knowledge,
+ # thus this early return isn't leaking anything an attacker wouldn't
+ # already know
if len(a) != len(b):
return False
+ # We assume that integers in the bytes range are all cached,
+ # thus timing shouldn't vary much due to integer object creation
result = 0
- if isinstance(a, bytes):
- for x, y in zip(a, b):
- result |= x ^ y
- else:
- for x, y in zip(a, b):
- result |= ord(x) ^ ord(y)
-
+ for x, y in zip(a, b):
+ result |= x ^ y
return result == 0
diff --git a/Lib/idlelib/AutoComplete.py b/Lib/idlelib/AutoComplete.py
index b38b108..929d358 100644
--- a/Lib/idlelib/AutoComplete.py
+++ b/Lib/idlelib/AutoComplete.py
@@ -140,7 +140,7 @@ class AutoComplete:
elif hp.is_in_code() and (not mode or mode==COMPLETE_ATTRIBUTES):
self._remove_autocomplete_window()
mode = COMPLETE_ATTRIBUTES
- while i and curline[i-1] in ID_CHARS or ord(curline[i-1]) > 127:
+ while i and (curline[i-1] in ID_CHARS or ord(curline[i-1]) > 127):
i -= 1
comp_start = curline[i:j]
if i and curline[i-1] == '.':
diff --git a/Lib/mailbox.py b/Lib/mailbox.py
index 7a29555..d86ad94 100644
--- a/Lib/mailbox.py
+++ b/Lib/mailbox.py
@@ -675,6 +675,7 @@ class _singlefileMailbox(Mailbox):
new_file.write(buffer)
new_toc[key] = (new_start, new_file.tell())
self._post_message_hook(new_file)
+ self._file_length = new_file.tell()
except:
new_file.close()
os.remove(new_file.name)
diff --git a/Lib/multiprocessing/__init__.py b/Lib/multiprocessing/__init__.py
index 02460f0..1f3e67c 100644
--- a/Lib/multiprocessing/__init__.py
+++ b/Lib/multiprocessing/__init__.py
@@ -23,8 +23,8 @@ __all__ = [
'Manager', 'Pipe', 'cpu_count', 'log_to_stderr', 'get_logger',
'allow_connection_pickling', 'BufferTooShort', 'TimeoutError',
'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition',
- 'Event', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool', 'Value', 'Array',
- 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING',
+ 'Event', 'Barrier', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool',
+ 'Value', 'Array', 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING',
]
__author__ = 'R. Oudkerk (r.m.oudkerk@gmail.com)'
@@ -186,6 +186,13 @@ def Event():
from multiprocessing.synchronize import Event
return Event()
+def Barrier(parties, action=None, timeout=None):
+ '''
+ Returns a barrier object
+ '''
+ from multiprocessing.synchronize import Barrier
+ return Barrier(parties, action, timeout)
+
def Queue(maxsize=0):
'''
Returns a queue object
diff --git a/Lib/multiprocessing/dummy/__init__.py b/Lib/multiprocessing/dummy/__init__.py
index 9bf8f6b..e31fc61 100644
--- a/Lib/multiprocessing/dummy/__init__.py
+++ b/Lib/multiprocessing/dummy/__init__.py
@@ -35,7 +35,7 @@
__all__ = [
'Process', 'current_process', 'active_children', 'freeze_support',
'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition',
- 'Event', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue'
+ 'Event', 'Barrier', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue'
]
#
@@ -49,7 +49,7 @@ import array
from multiprocessing.dummy.connection import Pipe
from threading import Lock, RLock, Semaphore, BoundedSemaphore
-from threading import Event, Condition
+from threading import Event, Condition, Barrier
from queue import Queue
#
diff --git a/Lib/multiprocessing/forking.py b/Lib/multiprocessing/forking.py
index 3a474cd..4baf548 100644
--- a/Lib/multiprocessing/forking.py
+++ b/Lib/multiprocessing/forking.py
@@ -13,7 +13,7 @@ import signal
from multiprocessing import util, process
-__all__ = ['Popen', 'assert_spawning', 'exit', 'duplicate', 'close', 'ForkingPickler']
+__all__ = ['Popen', 'assert_spawning', 'duplicate', 'close', 'ForkingPickler']
#
# Check that the current thread is spawning a child process
@@ -75,7 +75,6 @@ else:
#
if sys.platform != 'win32':
- exit = os._exit
duplicate = os.dup
close = os.close
@@ -168,7 +167,6 @@ else:
WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False))
WINSERVICE = sys.executable.lower().endswith("pythonservice.exe")
- exit = _winapi.ExitProcess
close = _winapi.CloseHandle
#
@@ -349,7 +347,7 @@ else:
from_parent.close()
exitcode = self._bootstrap()
- exit(exitcode)
+ sys.exit(exitcode)
def get_preparation_data(name):
diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py
index 6a7dccb..f6611af 100644
--- a/Lib/multiprocessing/managers.py
+++ b/Lib/multiprocessing/managers.py
@@ -22,7 +22,7 @@ import queue
from traceback import format_exc
from multiprocessing import Process, current_process, active_children, Pool, util, connection
from multiprocessing.process import AuthenticationString
-from multiprocessing.forking import exit, Popen, ForkingPickler
+from multiprocessing.forking import Popen, ForkingPickler
from time import time as _time
#
@@ -140,28 +140,38 @@ class Server(object):
self.id_to_obj = {'0': (None, ())}
self.id_to_refcount = {}
self.mutex = threading.RLock()
- self.stop = 0
def serve_forever(self):
'''
Run the server forever
'''
+ self.stop_event = threading.Event()
current_process()._manager_server = self
try:
+ accepter = threading.Thread(target=self.accepter)
+ accepter.daemon = True
+ accepter.start()
try:
- while 1:
- try:
- c = self.listener.accept()
- except (OSError, IOError):
- continue
- t = threading.Thread(target=self.handle_request, args=(c,))
- t.daemon = True
- t.start()
+ while not self.stop_event.is_set():
+ self.stop_event.wait(1)
except (KeyboardInterrupt, SystemExit):
pass
finally:
- self.stop = 999
- self.listener.close()
+ if sys.stdout != sys.__stdout__:
+ util.debug('resetting stdout, stderr')
+ sys.stdout = sys.__stdout__
+ sys.stderr = sys.__stderr__
+ sys.exit(0)
+
+ def accepter(self):
+ while True:
+ try:
+ c = self.listener.accept()
+ except (OSError, IOError):
+ continue
+ t = threading.Thread(target=self.handle_request, args=(c,))
+ t.daemon = True
+ t.start()
def handle_request(self, c):
'''
@@ -208,7 +218,7 @@ class Server(object):
send = conn.send
id_to_obj = self.id_to_obj
- while not self.stop:
+ while not self.stop_event.is_set():
try:
methodname = obj = None
@@ -318,32 +328,13 @@ class Server(object):
Shutdown this process
'''
try:
- try:
- util.debug('manager received shutdown message')
- c.send(('#RETURN', None))
-
- if sys.stdout != sys.__stdout__:
- util.debug('resetting stdout, stderr')
- sys.stdout = sys.__stdout__
- sys.stderr = sys.__stderr__
-
- util._run_finalizers(0)
-
- for p in active_children():
- util.debug('terminating a child process of manager')
- p.terminate()
-
- for p in active_children():
- util.debug('terminating a child process of manager')
- p.join()
-
- util._run_finalizers()
- util.info('manager exiting with exitcode 0')
- except:
- import traceback
- traceback.print_exc()
+ util.debug('manager received shutdown message')
+ c.send(('#RETURN', None))
+ except:
+ import traceback
+ traceback.print_exc()
finally:
- exit(0)
+ self.stop_event.set()
def create(self, c, typeid, *args, **kwds):
'''
@@ -455,10 +446,6 @@ class BaseManager(object):
self._serializer = serializer
self._Listener, self._Client = listener_client[serializer]
- def __reduce__(self):
- return type(self).from_address, \
- (self._address, self._authkey, self._serializer)
-
def get_server(self):
'''
Return server object with serve_forever() method and address attribute
@@ -595,7 +582,7 @@ class BaseManager(object):
except Exception:
pass
- process.join(timeout=0.2)
+ process.join(timeout=1.0)
if process.is_alive():
util.info('manager still alive')
if hasattr(process, 'terminate'):
@@ -1006,6 +993,26 @@ class EventProxy(BaseProxy):
def wait(self, timeout=None):
return self._callmethod('wait', (timeout,))
+
+class BarrierProxy(BaseProxy):
+ _exposed_ = ('__getattribute__', 'wait', 'abort', 'reset')
+ def wait(self, timeout=None):
+ return self._callmethod('wait', (timeout,))
+ def abort(self):
+ return self._callmethod('abort')
+ def reset(self):
+ return self._callmethod('reset')
+ @property
+ def parties(self):
+ return self._callmethod('__getattribute__', ('parties',))
+ @property
+ def n_waiting(self):
+ return self._callmethod('__getattribute__', ('n_waiting',))
+ @property
+ def broken(self):
+ return self._callmethod('__getattribute__', ('broken',))
+
+
class NamespaceProxy(BaseProxy):
_exposed_ = ('__getattribute__', '__setattr__', '__delattr__')
def __getattr__(self, key):
@@ -1097,6 +1104,7 @@ SyncManager.register('Semaphore', threading.Semaphore, AcquirerProxy)
SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore,
AcquirerProxy)
SyncManager.register('Condition', threading.Condition, ConditionProxy)
+SyncManager.register('Barrier', threading.Barrier, BarrierProxy)
SyncManager.register('Pool', Pool, PoolProxy)
SyncManager.register('list', list, ListProxy)
SyncManager.register('dict', dict, DictProxy)
diff --git a/Lib/multiprocessing/synchronize.py b/Lib/multiprocessing/synchronize.py
index 4502a97..22eabe5 100644
--- a/Lib/multiprocessing/synchronize.py
+++ b/Lib/multiprocessing/synchronize.py
@@ -333,3 +333,43 @@ class Event(object):
return False
finally:
self._cond.release()
+
+#
+# Barrier
+#
+
+class Barrier(threading.Barrier):
+
+ def __init__(self, parties, action=None, timeout=None):
+ import struct
+ from multiprocessing.heap import BufferWrapper
+ wrapper = BufferWrapper(struct.calcsize('i') * 2)
+ cond = Condition()
+ self.__setstate__((parties, action, timeout, cond, wrapper))
+ self._state = 0
+ self._count = 0
+
+ def __setstate__(self, state):
+ (self._parties, self._action, self._timeout,
+ self._cond, self._wrapper) = state
+ self._array = self._wrapper.create_memoryview().cast('i')
+
+ def __getstate__(self):
+ return (self._parties, self._action, self._timeout,
+ self._cond, self._wrapper)
+
+ @property
+ def _state(self):
+ return self._array[0]
+
+ @_state.setter
+ def _state(self, value):
+ self._array[0] = value
+
+ @property
+ def _count(self):
+ return self._array[1]
+
+ @_count.setter
+ def _count(self, value):
+ self._array[1] = value
diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py
index 48abe38..8a6aede 100644
--- a/Lib/multiprocessing/util.py
+++ b/Lib/multiprocessing/util.py
@@ -269,21 +269,24 @@ _exiting = False
def _exit_function():
global _exiting
- info('process shutting down')
- debug('running all "atexit" finalizers with priority >= 0')
- _run_finalizers(0)
+ if not _exiting:
+ _exiting = True
- for p in active_children():
- if p._daemonic:
- info('calling terminate() for daemon %s', p.name)
- p._popen.terminate()
+ info('process shutting down')
+ debug('running all "atexit" finalizers with priority >= 0')
+ _run_finalizers(0)
- for p in active_children():
- info('calling join() for process %s', p.name)
- p.join()
+ for p in active_children():
+ if p._daemonic:
+ info('calling terminate() for daemon %s', p.name)
+ p._popen.terminate()
- debug('running the remaining "atexit" finalizers')
- _run_finalizers()
+ for p in active_children():
+ info('calling join() for process %s', p.name)
+ p.join()
+
+ debug('running the remaining "atexit" finalizers')
+ _run_finalizers()
atexit.register(_exit_function)
diff --git a/Lib/test/support.py b/Lib/test/support.py
index 6749a511..3ff1df5 100644
--- a/Lib/test/support.py
+++ b/Lib/test/support.py
@@ -1593,7 +1593,7 @@ def strip_python_stderr(stderr):
This will typically be run on the result of the communicate() method
of a subprocess.Popen object.
"""
- stderr = re.sub(br"\[\d+ refs\]\r?\n?$", b"", stderr).strip()
+ stderr = re.sub(br"\[\d+ refs\]\r?\n?", b"", stderr).strip()
return stderr
def args_from_interpreter_flags():
diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py
index 042bc5d..4e5961d 100644
--- a/Lib/test/test_hmac.py
+++ b/Lib/test/test_hmac.py
@@ -302,40 +302,42 @@ class CopyTestCase(unittest.TestCase):
self.assertEqual(h1.hexdigest(), h2.hexdigest(),
"Hexdigest of copy doesn't match original hexdigest.")
-class SecureCompareTestCase(unittest.TestCase):
+class CompareDigestTestCase(unittest.TestCase):
def test_compare(self):
# Testing input type exception handling
a, b = 100, 200
- self.assertRaises(TypeError, hmac.secure_compare, a, b)
- a, b = 100, "foobar"
- self.assertRaises(TypeError, hmac.secure_compare, a, b)
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = 100, b"foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = b"foobar", 200
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
a, b = "foobar", b"foobar"
- self.assertRaises(TypeError, hmac.secure_compare, a, b)
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = b"foobar", "foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = "foobar", "foobar"
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
+ a, b = bytearray(b"foobar"), bytearray(b"foobar")
+ self.assertRaises(TypeError, hmac.compare_digest, a, b)
- # Testing str/bytes of different lengths
- a, b = "foobar", "foo"
- self.assertFalse(hmac.secure_compare(a, b))
+ # Testing bytes of different lengths
a, b = b"foobar", b"foo"
- self.assertFalse(hmac.secure_compare(a, b))
+ self.assertFalse(hmac.compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xde\xad"
- self.assertFalse(hmac.secure_compare(a, b))
+ self.assertFalse(hmac.compare_digest(a, b))
- # Testing str/bytes of same lengths, different values
- a, b = "foobar", "foobaz"
- self.assertFalse(hmac.secure_compare(a, b))
+ # Testing bytes of same lengths, different values
a, b = b"foobar", b"foobaz"
- self.assertFalse(hmac.secure_compare(a, b))
+ self.assertFalse(hmac.compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea"
- self.assertFalse(hmac.secure_compare(a, b))
+ self.assertFalse(hmac.compare_digest(a, b))
- # Testing str/bytes of same lengths, same values
- a, b = "foobar", "foobar"
- self.assertTrue(hmac.secure_compare(a, b))
+ # Testing bytes of same lengths, same values
a, b = b"foobar", b"foobar"
- self.assertTrue(hmac.secure_compare(a, b))
+ self.assertTrue(hmac.compare_digest(a, b))
a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef"
- self.assertTrue(hmac.secure_compare(a, b))
+ self.assertTrue(hmac.compare_digest(a, b))
def test_main():
support.run_unittest(
@@ -343,7 +345,7 @@ def test_main():
ConstructorTestCase,
SanityTestCase,
CopyTestCase,
- SecureCompareTestCase
+ CompareDigestTestCase
)
if __name__ == "__main__":
diff --git a/Lib/test/test_mailbox.py b/Lib/test/test_mailbox.py
index 0fe4bd9..cb54505 100644
--- a/Lib/test/test_mailbox.py
+++ b/Lib/test/test_mailbox.py
@@ -504,6 +504,17 @@ class TestMailbox(TestBase):
# Write changes to disk
self._test_flush_or_close(self._box.flush, True)
+ def test_popitem_and_flush_twice(self):
+ # See #15036.
+ self._box.add(self._template % 0)
+ self._box.add(self._template % 1)
+ self._box.flush()
+
+ self._box.popitem()
+ self._box.flush()
+ self._box.popitem()
+ self._box.flush()
+
def test_lock_unlock(self):
# Lock and unlock the mailbox
self.assertFalse(os.path.exists(self._get_lock_path()))
diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py
index 65e7b0b..2704827 100644
--- a/Lib/test/test_multiprocessing.py
+++ b/Lib/test/test_multiprocessing.py
@@ -18,6 +18,7 @@ import array
import socket
import random
import logging
+import struct
import test.support
@@ -1057,6 +1058,340 @@ class _TestEvent(BaseTestCase):
self.assertEqual(wait(), True)
#
+# Tests for Barrier - adapted from tests in test/lock_tests.py
+#
+
+# Many of the tests for threading.Barrier use a list as an atomic
+# counter: a value is appended to increment the counter, and the
+# length of the list gives the value. We use the class DummyList
+# for the same purpose.
+
+class _DummyList(object):
+
+ def __init__(self):
+ wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i'))
+ lock = multiprocessing.Lock()
+ self.__setstate__((wrapper, lock))
+ self._lengthbuf[0] = 0
+
+ def __setstate__(self, state):
+ (self._wrapper, self._lock) = state
+ self._lengthbuf = self._wrapper.create_memoryview().cast('i')
+
+ def __getstate__(self):
+ return (self._wrapper, self._lock)
+
+ def append(self, _):
+ with self._lock:
+ self._lengthbuf[0] += 1
+
+ def __len__(self):
+ with self._lock:
+ return self._lengthbuf[0]
+
+def _wait():
+ # A crude wait/yield function not relying on synchronization primitives.
+ time.sleep(0.01)
+
+
+class Bunch(object):
+ """
+ A bunch of threads.
+ """
+ def __init__(self, namespace, f, args, n, wait_before_exit=False):
+ """
+ Construct a bunch of `n` threads running the same function `f`.
+ If `wait_before_exit` is True, the threads won't terminate until
+ do_finish() is called.
+ """
+ self.f = f
+ self.args = args
+ self.n = n
+ self.started = namespace.DummyList()
+ self.finished = namespace.DummyList()
+ self._can_exit = namespace.Event()
+ if not wait_before_exit:
+ self._can_exit.set()
+ for i in range(n):
+ p = namespace.Process(target=self.task)
+ p.daemon = True
+ p.start()
+
+ def task(self):
+ pid = os.getpid()
+ self.started.append(pid)
+ try:
+ self.f(*self.args)
+ finally:
+ self.finished.append(pid)
+ self._can_exit.wait(30)
+ assert self._can_exit.is_set()
+
+ def wait_for_started(self):
+ while len(self.started) < self.n:
+ _wait()
+
+ def wait_for_finished(self):
+ while len(self.finished) < self.n:
+ _wait()
+
+ def do_finish(self):
+ self._can_exit.set()
+
+
+class AppendTrue(object):
+ def __init__(self, obj):
+ self.obj = obj
+ def __call__(self):
+ self.obj.append(True)
+
+
+class _TestBarrier(BaseTestCase):
+ """
+ Tests for Barrier objects.
+ """
+ N = 5
+ defaultTimeout = 10.0 # XXX Slow Windows buildbots need generous timeout
+
+ def setUp(self):
+ self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout)
+
+ def tearDown(self):
+ self.barrier.abort()
+ self.barrier = None
+
+ def DummyList(self):
+ if self.TYPE == 'threads':
+ return []
+ elif self.TYPE == 'manager':
+ return self.manager.list()
+ else:
+ return _DummyList()
+
+ def run_threads(self, f, args):
+ b = Bunch(self, f, args, self.N-1)
+ f(*args)
+ b.wait_for_finished()
+
+ @classmethod
+ def multipass(cls, barrier, results, n):
+ m = barrier.parties
+ assert m == cls.N
+ for i in range(n):
+ results[0].append(True)
+ assert len(results[1]) == i * m
+ barrier.wait()
+ results[1].append(True)
+ assert len(results[0]) == (i + 1) * m
+ barrier.wait()
+ try:
+ assert barrier.n_waiting == 0
+ except NotImplementedError:
+ pass
+ assert not barrier.broken
+
+ def test_barrier(self, passes=1):
+ """
+ Test that a barrier is passed in lockstep
+ """
+ results = [self.DummyList(), self.DummyList()]
+ self.run_threads(self.multipass, (self.barrier, results, passes))
+
+ def test_barrier_10(self):
+ """
+ Test that a barrier works for 10 consecutive runs
+ """
+ return self.test_barrier(10)
+
+ @classmethod
+ def _test_wait_return_f(cls, barrier, queue):
+ res = barrier.wait()
+ queue.put(res)
+
+ def test_wait_return(self):
+ """
+ test the return value from barrier.wait
+ """
+ queue = self.Queue()
+ self.run_threads(self._test_wait_return_f, (self.barrier, queue))
+ results = [queue.get() for i in range(self.N)]
+ self.assertEqual(results.count(0), 1)
+
+ @classmethod
+ def _test_action_f(cls, barrier, results):
+ barrier.wait()
+ if len(results) != 1:
+ raise RuntimeError
+
+ def test_action(self):
+ """
+ Test the 'action' callback
+ """
+ results = self.DummyList()
+ barrier = self.Barrier(self.N, action=AppendTrue(results))
+ self.run_threads(self._test_action_f, (barrier, results))
+ self.assertEqual(len(results), 1)
+
+ @classmethod
+ def _test_abort_f(cls, barrier, results1, results2):
+ try:
+ i = barrier.wait()
+ if i == cls.N//2:
+ raise RuntimeError
+ barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ barrier.abort()
+
+ def test_abort(self):
+ """
+ Test that an abort will put the barrier in a broken state
+ """
+ results1 = self.DummyList()
+ results2 = self.DummyList()
+ self.run_threads(self._test_abort_f,
+ (self.barrier, results1, results2))
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertTrue(self.barrier.broken)
+
+ @classmethod
+ def _test_reset_f(cls, barrier, results1, results2, results3):
+ i = barrier.wait()
+ if i == cls.N//2:
+ # Wait until the other threads are all in the barrier.
+ while barrier.n_waiting < cls.N-1:
+ time.sleep(0.001)
+ barrier.reset()
+ else:
+ try:
+ barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ # Now, pass the barrier again
+ barrier.wait()
+ results3.append(True)
+
+ def test_reset(self):
+ """
+ Test that a 'reset' on a barrier frees the waiting threads
+ """
+ results1 = self.DummyList()
+ results2 = self.DummyList()
+ results3 = self.DummyList()
+ self.run_threads(self._test_reset_f,
+ (self.barrier, results1, results2, results3))
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+ @classmethod
+ def _test_abort_and_reset_f(cls, barrier, barrier2,
+ results1, results2, results3):
+ try:
+ i = barrier.wait()
+ if i == cls.N//2:
+ raise RuntimeError
+ barrier.wait()
+ results1.append(True)
+ except threading.BrokenBarrierError:
+ results2.append(True)
+ except RuntimeError:
+ barrier.abort()
+ # Synchronize and reset the barrier. Must synchronize first so
+ # that everyone has left it when we reset, and after so that no
+ # one enters it before the reset.
+ if barrier2.wait() == cls.N//2:
+ barrier.reset()
+ barrier2.wait()
+ barrier.wait()
+ results3.append(True)
+
+ def test_abort_and_reset(self):
+ """
+ Test that a barrier can be reset after being broken.
+ """
+ results1 = self.DummyList()
+ results2 = self.DummyList()
+ results3 = self.DummyList()
+ barrier2 = self.Barrier(self.N)
+
+ self.run_threads(self._test_abort_and_reset_f,
+ (self.barrier, barrier2, results1, results2, results3))
+ self.assertEqual(len(results1), 0)
+ self.assertEqual(len(results2), self.N-1)
+ self.assertEqual(len(results3), self.N)
+
+ @classmethod
+ def _test_timeout_f(cls, barrier, results):
+ i = barrier.wait(20)
+ if i == cls.N//2:
+ # One thread is late!
+ time.sleep(4.0)
+ try:
+ barrier.wait(0.5)
+ except threading.BrokenBarrierError:
+ results.append(True)
+
+ def test_timeout(self):
+ """
+ Test wait(timeout)
+ """
+ results = self.DummyList()
+ self.run_threads(self._test_timeout_f, (self.barrier, results))
+ self.assertEqual(len(results), self.barrier.parties)
+
+ @classmethod
+ def _test_default_timeout_f(cls, barrier, results):
+ i = barrier.wait(20)
+ if i == cls.N//2:
+ # One thread is later than the default timeout
+ time.sleep(4.0)
+ try:
+ barrier.wait()
+ except threading.BrokenBarrierError:
+ results.append(True)
+
+ def test_default_timeout(self):
+ """
+ Test the barrier's default timeout
+ """
+ barrier = self.Barrier(self.N, timeout=1.0)
+ results = self.DummyList()
+ self.run_threads(self._test_default_timeout_f, (barrier, results))
+ self.assertEqual(len(results), barrier.parties)
+
+ def test_single_thread(self):
+ b = self.Barrier(1)
+ b.wait()
+ b.wait()
+
+ @classmethod
+ def _test_thousand_f(cls, barrier, passes, conn, lock):
+ for i in range(passes):
+ barrier.wait()
+ with lock:
+ conn.send(i)
+
+ def test_thousand(self):
+ if self.TYPE == 'manager':
+ return
+ passes = 1000
+ lock = self.Lock()
+ conn, child_conn = self.Pipe(False)
+ for j in range(self.N):
+ p = self.Process(target=self._test_thousand_f,
+ args=(self.barrier, passes, child_conn, lock))
+ p.start()
+
+ for i in range(passes):
+ for j in range(self.N):
+ self.assertEqual(conn.recv(), i)
+
+#
#
#
@@ -1485,6 +1820,11 @@ class _TestZZZNumberOfObjects(BaseTestCase):
# run after all the other tests for the manager. It tests that
# there have been no "reference leaks" for the manager's shared
# objects. Note the comment in _TestPool.test_terminate().
+
+ # If some other test using ManagerMixin.manager fails, then the
+ # raised exception may keep alive a frame which holds a reference
+ # to a managed object. This will cause test_number_of_objects to
+ # also fail.
ALLOWED_TYPES = ('manager',)
def test_number_of_objects(self):
@@ -1564,6 +1904,11 @@ class _TestMyManager(BaseTestCase):
manager.shutdown()
+ # If the manager process exited cleanly then the exitcode
+ # will be zero. Otherwise (after a short timeout)
+ # terminate() is used, resulting in an exitcode of -SIGTERM.
+ self.assertEqual(manager._process.exitcode, 0)
+
#
# Test of connecting to a remote server and using xmlrpclib for serialization
#
@@ -1923,7 +2268,7 @@ class _TestConnection(BaseTestCase):
class _TestListener(BaseTestCase):
- ALLOWED_TYPES = ('processes')
+ ALLOWED_TYPES = ('processes',)
def test_multiple_bind(self):
for family in self.connection.families:
@@ -2505,10 +2850,12 @@ def create_test_cases(Mixin, type):
result = {}
glob = globals()
Type = type.capitalize()
+ ALL_TYPES = {'processes', 'threads', 'manager'}
for name in list(glob.keys()):
if name.startswith('_Test'):
base = glob[name]
+ assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES)
if type in base.ALLOWED_TYPES:
newname = 'With' + Type + name[1:]
class Temp(base, unittest.TestCase, Mixin):
@@ -2527,7 +2874,7 @@ class ProcessesMixin(object):
Process = multiprocessing.Process
locals().update(get_attributes(multiprocessing, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
- 'Condition', 'Event', 'Value', 'Array', 'RawValue',
+ 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue',
'RawArray', 'current_process', 'active_children', 'Pipe',
'connection', 'JoinableQueue', 'Pool'
)))
@@ -2542,7 +2889,7 @@ class ManagerMixin(object):
manager = object.__new__(multiprocessing.managers.SyncManager)
locals().update(get_attributes(manager, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
- 'Condition', 'Event', 'Value', 'Array', 'list', 'dict',
+ 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict',
'Namespace', 'JoinableQueue', 'Pool'
)))
@@ -2555,7 +2902,7 @@ class ThreadsMixin(object):
Process = multiprocessing.dummy.Process
locals().update(get_attributes(multiprocessing.dummy, (
'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore',
- 'Condition', 'Event', 'Value', 'Array', 'current_process',
+ 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process',
'active_children', 'Pipe', 'connection', 'dict', 'list',
'Namespace', 'JoinableQueue', 'Pool'
)))
diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py
index d6c63b7..a89e955 100644
--- a/Lib/test/test_structseq.py
+++ b/Lib/test/test_structseq.py
@@ -78,8 +78,9 @@ class StructSeqTest(unittest.TestCase):
def test_fields(self):
t = time.gmtime()
- self.assertEqual(len(t), t.n_fields)
- self.assertEqual(t.n_fields, t.n_sequence_fields+t.n_unnamed_fields)
+ self.assertEqual(len(t), t.n_sequence_fields)
+ self.assertEqual(t.n_unnamed_fields, 0)
+ self.assertEqual(t.n_fields, time._STRUCT_TM_ITEMS)
def test_constructor(self):
t = time.struct_time
diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py
index 02f05c3..63e1453 100644
--- a/Lib/test/test_time.py
+++ b/Lib/test/test_time.py
@@ -31,15 +31,14 @@ class TimeTestCase(unittest.TestCase):
time.time()
info = time.get_clock_info('time')
self.assertFalse(info.monotonic)
- if sys.platform != 'win32':
- self.assertTrue(info.adjusted)
+ self.assertTrue(info.adjustable)
def test_clock(self):
time.clock()
info = time.get_clock_info('clock')
self.assertTrue(info.monotonic)
- self.assertFalse(info.adjusted)
+ self.assertFalse(info.adjustable)
@unittest.skipUnless(hasattr(time, 'clock_gettime'),
'need time.clock_gettime()')
@@ -371,10 +370,7 @@ class TimeTestCase(unittest.TestCase):
info = time.get_clock_info('monotonic')
self.assertTrue(info.monotonic)
- if sys.platform == 'linux':
- self.assertTrue(info.adjusted)
- else:
- self.assertFalse(info.adjusted)
+ self.assertFalse(info.adjustable)
def test_perf_counter(self):
time.perf_counter()
@@ -390,7 +386,7 @@ class TimeTestCase(unittest.TestCase):
info = time.get_clock_info('process_time')
self.assertTrue(info.monotonic)
- self.assertFalse(info.adjusted)
+ self.assertFalse(info.adjustable)
@unittest.skipUnless(hasattr(time, 'monotonic'),
'need time.monotonic')
@@ -441,7 +437,7 @@ class TimeTestCase(unittest.TestCase):
# 0.0 < resolution <= 1.0
self.assertGreater(info.resolution, 0.0)
self.assertLessEqual(info.resolution, 1.0)
- self.assertIsInstance(info.adjusted, bool)
+ self.assertIsInstance(info.adjustable, bool)
self.assertRaises(ValueError, time.get_clock_info, 'xxx')
@@ -624,7 +620,58 @@ class TestPytime(unittest.TestCase):
for invalid in self.invalid_values:
self.assertRaises(OverflowError, pytime_object_to_timespec, invalid)
+ @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support")
+ def test_localtime_timezone(self):
+
+ # Get the localtime and examine it for the offset and zone.
+ lt = time.localtime()
+ self.assertTrue(hasattr(lt, "tm_gmtoff"))
+ self.assertTrue(hasattr(lt, "tm_zone"))
+ # See if the offset and zone are similar to the module
+ # attributes.
+ if lt.tm_gmtoff is None:
+ self.assertTrue(not hasattr(time, "timezone"))
+ else:
+ self.assertEqual(lt.tm_gmtoff, -[time.timezone, time.altzone][lt.tm_isdst])
+ if lt.tm_zone is None:
+ self.assertTrue(not hasattr(time, "tzname"))
+ else:
+ self.assertEqual(lt.tm_zone, time.tzname[lt.tm_isdst])
+
+ # Try and make UNIX times from the localtime and a 9-tuple
+ # created from the localtime. Test to see that the times are
+ # the same.
+ t = time.mktime(lt); t9 = time.mktime(lt[:9])
+ self.assertEqual(t, t9)
+
+ # Make localtimes from the UNIX times and compare them to
+ # the original localtime, thus making a round trip.
+ new_lt = time.localtime(t); new_lt9 = time.localtime(t9)
+ self.assertEqual(new_lt, lt)
+ self.assertEqual(new_lt.tm_gmtoff, lt.tm_gmtoff)
+ self.assertEqual(new_lt.tm_zone, lt.tm_zone)
+ self.assertEqual(new_lt9, lt)
+ self.assertEqual(new_lt.tm_gmtoff, lt.tm_gmtoff)
+ self.assertEqual(new_lt9.tm_zone, lt.tm_zone)
+
+ @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support")
+ def test_strptime_timezone(self):
+ t = time.strptime("UTC", "%Z")
+ self.assertEqual(t.tm_zone, 'UTC')
+ t = time.strptime("+0500", "%z")
+ self.assertEqual(t.tm_gmtoff, 5 * 3600)
+
+ @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support")
+ def test_short_times(self):
+
+ import pickle
+
+ # Load a short time structure using pickle.
+ st = b"ctime\nstruct_time\np0\n((I2007\nI8\nI11\nI1\nI24\nI49\nI5\nI223\nI1\ntp1\n(dp2\ntp3\nRp4\n."
+ lt = pickle.loads(st)
+ self.assertIs(lt.tm_gmtoff, None)
+ self.assertIs(lt.tm_zone, None)
def test_main():
support.run_unittest(
diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py
index 49a5633..24ecae5 100644
--- a/Lib/test/test_xml_etree.py
+++ b/Lib/test/test_xml_etree.py
@@ -23,7 +23,8 @@ import weakref
from test import support
from test.support import findfile, import_fresh_module, gc_collect
-pyET = import_fresh_module('xml.etree.ElementTree', blocked=['_elementtree'])
+pyET = None
+ET = None
SIMPLE_XMLFILE = findfile("simple.xml", subdir="xmltestdata")
try:
@@ -209,10 +210,8 @@ def interface():
These methods return an iterable. See bug 6472.
- >>> check_method(element.iter("tag").__next__)
>>> check_method(element.iterfind("tag").__next__)
>>> check_method(element.iterfind("*").__next__)
- >>> check_method(tree.iter("tag").__next__)
>>> check_method(tree.iterfind("tag").__next__)
>>> check_method(tree.iterfind("*").__next__)
@@ -291,42 +290,6 @@ def cdata():
'<tag>hello</tag>'
"""
-# Only with Python implementation
-def simplefind():
- """
- Test find methods using the elementpath fallback.
-
- >>> ElementTree = pyET
-
- >>> CurrentElementPath = ElementTree.ElementPath
- >>> ElementTree.ElementPath = ElementTree._SimpleElementPath()
- >>> elem = ElementTree.XML(SAMPLE_XML)
- >>> elem.find("tag").tag
- 'tag'
- >>> ElementTree.ElementTree(elem).find("tag").tag
- 'tag'
- >>> elem.findtext("tag")
- 'text'
- >>> elem.findtext("tog")
- >>> elem.findtext("tog", "default")
- 'default'
- >>> ElementTree.ElementTree(elem).findtext("tag")
- 'text'
- >>> summarize_list(elem.findall("tag"))
- ['tag', 'tag']
- >>> summarize_list(elem.findall(".//tag"))
- ['tag', 'tag', 'tag']
-
- Path syntax doesn't work in this case.
-
- >>> elem.find("section/tag")
- >>> elem.findtext("section/tag")
- >>> summarize_list(elem.findall("section/tag"))
- []
-
- >>> ElementTree.ElementPath = CurrentElementPath
- """
-
def find():
"""
Test find methods (including xpath syntax).
@@ -1002,36 +965,6 @@ def methods():
'1 < 2\n'
"""
-def iterators():
- """
- Test iterators.
-
- >>> e = ET.XML("<html><body>this is a <i>paragraph</i>.</body>..</html>")
- >>> summarize_list(e.iter())
- ['html', 'body', 'i']
- >>> summarize_list(e.find("body").iter())
- ['body', 'i']
- >>> summarize(next(e.iter()))
- 'html'
- >>> "".join(e.itertext())
- 'this is a paragraph...'
- >>> "".join(e.find("body").itertext())
- 'this is a paragraph.'
- >>> next(e.itertext())
- 'this is a '
-
- Method iterparse should return an iterator. See bug 6472.
-
- >>> sourcefile = serialize(e, to_string=False)
- >>> next(ET.iterparse(sourcefile)) # doctest: +ELLIPSIS
- ('end', <Element 'i' at 0x...>)
-
- >>> tree = ET.ElementTree(None)
- >>> tree.iter()
- Traceback (most recent call last):
- AttributeError: 'NoneType' object has no attribute 'iter'
- """
-
ENTITY_XML = """\
<!DOCTYPE points [
<!ENTITY % user-entities SYSTEM 'user-entities.xml'>
@@ -1339,6 +1272,7 @@ XINCLUDE["default.xml"] = """\
</document>
""".format(html.escape(SIMPLE_XMLFILE, True))
+
def xinclude_loader(href, parse="xml", encoding=None):
try:
data = XINCLUDE[href]
@@ -1411,22 +1345,6 @@ def xinclude():
>>> # print(serialize(document)) # C5
"""
-def xinclude_default():
- """
- >>> from xml.etree import ElementInclude
-
- >>> document = xinclude_loader("default.xml")
- >>> ElementInclude.include(document)
- >>> print(serialize(document)) # default
- <document>
- <p>Example.</p>
- <root>
- <element key="value">text</element>
- <element>text</element>tail
- <empty-element />
- </root>
- </document>
- """
#
# badly formatted xi:include tags
@@ -1917,9 +1835,8 @@ class ElementTreeTest(unittest.TestCase):
self.assertIsInstance(ET.QName, type)
self.assertIsInstance(ET.ElementTree, type)
self.assertIsInstance(ET.Element, type)
- # XXX issue 14128 with C ElementTree
- # self.assertIsInstance(ET.TreeBuilder, type)
- # self.assertIsInstance(ET.XMLParser, type)
+ self.assertIsInstance(ET.TreeBuilder, type)
+ self.assertIsInstance(ET.XMLParser, type)
def test_Element_subclass_trivial(self):
class MyElement(ET.Element):
@@ -1953,6 +1870,73 @@ class ElementTreeTest(unittest.TestCase):
self.assertEqual(mye.newmethod(), 'joe')
+class ElementIterTest(unittest.TestCase):
+ def _ilist(self, elem, tag=None):
+ return summarize_list(elem.iter(tag))
+
+ def test_basic(self):
+ doc = ET.XML("<html><body>this is a <i>paragraph</i>.</body>..</html>")
+ self.assertEqual(self._ilist(doc), ['html', 'body', 'i'])
+ self.assertEqual(self._ilist(doc.find('body')), ['body', 'i'])
+ self.assertEqual(next(doc.iter()).tag, 'html')
+ self.assertEqual(''.join(doc.itertext()), 'this is a paragraph...')
+ self.assertEqual(''.join(doc.find('body').itertext()),
+ 'this is a paragraph.')
+ self.assertEqual(next(doc.itertext()), 'this is a ')
+
+ # iterparse should return an iterator
+ sourcefile = serialize(doc, to_string=False)
+ self.assertEqual(next(ET.iterparse(sourcefile))[0], 'end')
+
+ tree = ET.ElementTree(None)
+ self.assertRaises(AttributeError, tree.iter)
+
+ def test_corners(self):
+ # single root, no subelements
+ a = ET.Element('a')
+ self.assertEqual(self._ilist(a), ['a'])
+
+ # one child
+ b = ET.SubElement(a, 'b')
+ self.assertEqual(self._ilist(a), ['a', 'b'])
+
+ # one child and one grandchild
+ c = ET.SubElement(b, 'c')
+ self.assertEqual(self._ilist(a), ['a', 'b', 'c'])
+
+ # two children, only first with grandchild
+ d = ET.SubElement(a, 'd')
+ self.assertEqual(self._ilist(a), ['a', 'b', 'c', 'd'])
+
+ # replace first child by second
+ a[0] = a[1]
+ del a[1]
+ self.assertEqual(self._ilist(a), ['a', 'd'])
+
+ def test_iter_by_tag(self):
+ doc = ET.XML('''
+ <document>
+ <house>
+ <room>bedroom1</room>
+ <room>bedroom2</room>
+ </house>
+ <shed>nothing here
+ </shed>
+ <house>
+ <room>bedroom8</room>
+ </house>
+ </document>''')
+
+ self.assertEqual(self._ilist(doc, 'room'), ['room'] * 3)
+ self.assertEqual(self._ilist(doc, 'house'), ['house'] * 2)
+
+ # make sure both tag=None and tag='*' return all tags
+ all_tags = ['document', 'house', 'room', 'room',
+ 'shed', 'house', 'room']
+ self.assertEqual(self._ilist(doc), all_tags)
+ self.assertEqual(self._ilist(doc, '*'), all_tags)
+
+
class TreeBuilderTest(unittest.TestCase):
sample1 = ('<!DOCTYPE html PUBLIC'
' "-//W3C//DTD XHTML 1.0 Transitional//EN"'
@@ -2027,6 +2011,23 @@ class TreeBuilderTest(unittest.TestCase):
'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd'))
+@unittest.skip('Unstable due to module monkeypatching')
+class XincludeTest(unittest.TestCase):
+ def test_xinclude_default(self):
+ from xml.etree import ElementInclude
+ doc = xinclude_loader('default.xml')
+ ElementInclude.include(doc)
+ s = serialize(doc)
+ self.assertEqual(s.strip(), '''<document>
+ <p>Example.</p>
+ <root>
+ <element key="value">text</element>
+ <element>text</element>tail
+ <empty-element />
+</root>
+</document>''')
+
+
class XMLParserTest(unittest.TestCase):
sample1 = '<file><line>22</line></file>'
sample2 = ('<!DOCTYPE html PUBLIC'
@@ -2073,13 +2074,6 @@ class XMLParserTest(unittest.TestCase):
'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd'))
-class NoAcceleratorTest(unittest.TestCase):
- # Test that the C accelerator was not imported for pyET
- def test_correct_import_pyET(self):
- self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree')
- self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree')
-
-
class NamespaceParseTest(unittest.TestCase):
def test_find_with_namespace(self):
nsmap = {'h': 'hello', 'f': 'foo'}
@@ -2090,7 +2084,6 @@ class NamespaceParseTest(unittest.TestCase):
self.assertEqual(len(doc.findall('.//{foo}name', nsmap)), 1)
-
class ElementSlicingTest(unittest.TestCase):
def _elem_tags(self, elemlist):
return [e.tag for e in elemlist]
@@ -2232,6 +2225,14 @@ class KeywordArgsTest(unittest.TestCase):
with self.assertRaisesRegex(TypeError, 'must be dict, not str'):
ET.Element('a', attrib="I'm not a dict")
+# --------------------------------------------------------------------
+
+@unittest.skipUnless(pyET, 'only for the Python version')
+class NoAcceleratorTest(unittest.TestCase):
+ # Test that the C accelerator was not imported for pyET
+ def test_correct_import_pyET(self):
+ self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree')
+ self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree')
# --------------------------------------------------------------------
@@ -2276,31 +2277,42 @@ class CleanContext(object):
self.checkwarnings.__exit__(*args)
-def test_main(module=pyET):
- from test import test_xml_etree
+def test_main(module=None):
+ # When invoked without a module, runs the Python ET tests by loading pyET.
+ # Otherwise, uses the given module as the ET.
+ if module is None:
+ global pyET
+ pyET = import_fresh_module('xml.etree.ElementTree',
+ blocked=['_elementtree'])
+ module = pyET
- # The same doctests are used for both the Python and the C implementations
- test_xml_etree.ET = module
+ global ET
+ ET = module
test_classes = [
ElementSlicingTest,
BasicElementTest,
StringIOTest,
ParseErrorTest,
+ XincludeTest,
ElementTreeTest,
- NamespaceParseTest,
+ ElementIterTest,
TreeBuilderTest,
- XMLParserTest,
- KeywordArgsTest]
- if module is pyET:
- # Run the tests specific to the Python implementation
- test_classes += [NoAcceleratorTest]
+ ]
+
+ # These tests will only run for the pure-Python version that doesn't import
+ # _elementtree. We can't use skipUnless here, because pyET is filled in only
+ # after the module is loaded.
+ if pyET:
+ test_classes.extend([
+ NoAcceleratorTest,
+ ])
support.run_unittest(*test_classes)
# XXX the C module should give the same warnings as the Python module
with CleanContext(quiet=(module is not pyET)):
- support.run_doctest(test_xml_etree, verbosity=True)
+ support.run_doctest(sys.modules[__name__], verbosity=True)
if __name__ == '__main__':
test_main()
diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py
index 10416d2..142a22f 100644
--- a/Lib/test/test_xml_etree_c.py
+++ b/Lib/test/test_xml_etree_c.py
@@ -8,31 +8,6 @@ cET = import_fresh_module('xml.etree.ElementTree', fresh=['_elementtree'])
cET_alias = import_fresh_module('xml.etree.cElementTree', fresh=['_elementtree', 'xml.etree'])
-# cElementTree specific tests
-
-def sanity():
- r"""
- Import sanity.
-
- Issue #6697.
-
- >>> cElementTree = cET
- >>> e = cElementTree.Element('a')
- >>> getattr(e, '\uD800') # doctest: +ELLIPSIS
- Traceback (most recent call last):
- ...
- UnicodeEncodeError: ...
-
- >>> p = cElementTree.XMLParser()
- >>> p.version.split()[0]
- 'Expat'
- >>> getattr(p, '\uD800')
- Traceback (most recent call last):
- ...
- AttributeError: 'XMLParser' object has no attribute '\ud800'
- """
-
-
class MiscTests(unittest.TestCase):
# Issue #8651.
@support.bigmemtest(size=support._2G + 100, memuse=1)
@@ -46,6 +21,7 @@ class MiscTests(unittest.TestCase):
finally:
data = None
+
@unittest.skipUnless(cET, 'requires _elementtree')
class TestAliasWorking(unittest.TestCase):
# Test that the cET alias module is alive
@@ -53,6 +29,7 @@ class TestAliasWorking(unittest.TestCase):
e = cET_alias.Element('foo')
self.assertEqual(e.tag, 'foo')
+
@unittest.skipUnless(cET, 'requires _elementtree')
class TestAcceleratorImported(unittest.TestCase):
# Test that the C accelerator was imported, as expected
@@ -67,7 +44,6 @@ def test_main():
from test import test_xml_etree, test_xml_etree_c
# Run the tests specific to the C implementation
- support.run_doctest(test_xml_etree_c, verbosity=True)
support.run_unittest(
MiscTests,
TestAliasWorking,
diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py
index e068fc2..d30a83c 100644
--- a/Lib/xml/etree/ElementTree.py
+++ b/Lib/xml/etree/ElementTree.py
@@ -101,32 +101,8 @@ import sys
import re
import warnings
-class _SimpleElementPath:
- # emulate pre-1.2 find/findtext/findall behaviour
- def find(self, element, tag, namespaces=None):
- for elem in element:
- if elem.tag == tag:
- return elem
- return None
- def findtext(self, element, tag, default=None, namespaces=None):
- elem = self.find(element, tag)
- if elem is None:
- return default
- return elem.text or ""
- def iterfind(self, element, tag, namespaces=None):
- if tag[:3] == ".//":
- for elem in element.iter(tag[3:]):
- yield elem
- for elem in element:
- if elem.tag == tag:
- yield elem
- def findall(self, element, tag, namespaces=None):
- return list(self.iterfind(element, tag, namespaces))
+from . import ElementPath
-try:
- from . import ElementPath
-except ImportError:
- ElementPath = _SimpleElementPath()
##
# Parser error. This is a subclass of <b>SyntaxError</b>.
@@ -916,11 +892,7 @@ def _namespaces(elem, default_namespace=None):
_raise_serialization_error(qname)
# populate qname and namespaces table
- try:
- iterate = elem.iter
- except AttributeError:
- iterate = elem.getiterator # cET compatibility
- for elem in iterate():
+ for elem in elem.iter():
tag = elem.tag
if isinstance(tag, QName):
if tag.text not in qnames: