diff options
author | Brett Cannon <brett@python.org> | 2012-06-15 23:04:29 (GMT) |
---|---|---|
committer | Brett Cannon <brett@python.org> | 2012-06-15 23:04:29 (GMT) |
commit | 24aa693c7ef8f217fbd238eb7af7d828e13a07eb (patch) | |
tree | 7d01ab630c2e8eef1e168b1aa5d84131b60cfd50 /Lib | |
parent | 99d776fdf4aa5a66266ebcec2263fab501f03088 (diff) | |
parent | 016ef551a793f72f582d707ce5bb55bf4940cf27 (diff) | |
download | cpython-24aa693c7ef8f217fbd238eb7af7d828e13a07eb.zip cpython-24aa693c7ef8f217fbd238eb7af7d828e13a07eb.tar.gz cpython-24aa693c7ef8f217fbd238eb7af7d828e13a07eb.tar.bz2 |
Merge
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/_strptime.py | 6 | ||||
-rw-r--r-- | Lib/datetime.py | 6 | ||||
-rw-r--r-- | Lib/hmac.py | 28 | ||||
-rw-r--r-- | Lib/idlelib/AutoComplete.py | 2 | ||||
-rw-r--r-- | Lib/mailbox.py | 1 | ||||
-rw-r--r-- | Lib/multiprocessing/__init__.py | 11 | ||||
-rw-r--r-- | Lib/multiprocessing/dummy/__init__.py | 4 | ||||
-rw-r--r-- | Lib/multiprocessing/forking.py | 6 | ||||
-rw-r--r-- | Lib/multiprocessing/managers.py | 94 | ||||
-rw-r--r-- | Lib/multiprocessing/synchronize.py | 40 | ||||
-rw-r--r-- | Lib/multiprocessing/util.py | 27 | ||||
-rw-r--r-- | Lib/test/support.py | 2 | ||||
-rw-r--r-- | Lib/test/test_hmac.py | 44 | ||||
-rw-r--r-- | Lib/test/test_mailbox.py | 11 | ||||
-rw-r--r-- | Lib/test/test_multiprocessing.py | 355 | ||||
-rw-r--r-- | Lib/test/test_structseq.py | 5 | ||||
-rw-r--r-- | Lib/test/test_time.py | 65 | ||||
-rw-r--r-- | Lib/test/test_xml_etree.py | 226 | ||||
-rw-r--r-- | Lib/test/test_xml_etree_c.py | 28 | ||||
-rw-r--r-- | Lib/xml/etree/ElementTree.py | 32 |
20 files changed, 708 insertions, 285 deletions
diff --git a/Lib/_strptime.py b/Lib/_strptime.py index fa06376..b0cd3d6 100644 --- a/Lib/_strptime.py +++ b/Lib/_strptime.py @@ -486,19 +486,19 @@ def _strptime(data_string, format="%a %b %d %H:%M:%S %Y"): return (year, month, day, hour, minute, second, - weekday, julian, tz, gmtoff, tzname), fraction + weekday, julian, tz, tzname, gmtoff), fraction def _strptime_time(data_string, format="%a %b %d %H:%M:%S %Y"): """Return a time struct based on the input string and the format string.""" tt = _strptime(data_string, format)[0] - return time.struct_time(tt[:9]) + return time.struct_time(tt[:time._STRUCT_TM_ITEMS]) def _strptime_datetime(cls, data_string, format="%a %b %d %H:%M:%S %Y"): """Return a class cls instance based on the input string and the format string.""" tt, fraction = _strptime(data_string, format) - gmtoff, tzname = tt[-2:] + tzname, gmtoff = tt[-2:] args = tt[:6] + (fraction,) if gmtoff is not None: tzdelta = datetime_timedelta(seconds=gmtoff) diff --git a/Lib/datetime.py b/Lib/datetime.py index 5d8d9b3..21aab35 100644 --- a/Lib/datetime.py +++ b/Lib/datetime.py @@ -1670,10 +1670,8 @@ class datetime(date): if mytz is ottz: base_compare = True else: - if mytz is not None: - myoff = self.utcoffset() - if ottz is not None: - otoff = other.utcoffset() + myoff = self.utcoffset() + otoff = other.utcoffset() base_compare = myoff == otoff if base_compare: diff --git a/Lib/hmac.py b/Lib/hmac.py index 13ffdbe..e47965b 100644 --- a/Lib/hmac.py +++ b/Lib/hmac.py @@ -13,24 +13,24 @@ trans_36 = bytes((x ^ 0x36) for x in range(256)) digest_size = None -def secure_compare(a, b): - """Returns the equivalent of 'a == b', but using a time-independent - comparison method to prevent timing attacks.""" - if not ((isinstance(a, str) and isinstance(b, str)) or - (isinstance(a, bytes) and isinstance(b, bytes))): - raise TypeError("inputs must be strings or bytes") - +def compare_digest(a, b): + """Returns the equivalent of 'a == b', but avoids content based short + circuiting to reduce the vulnerability to timing attacks.""" + # Consistent timing matters more here than data type flexibility + if not (isinstance(a, bytes) and isinstance(b, bytes)): + raise TypeError("inputs must be bytes instances") + + # We assume the length of the expected digest is public knowledge, + # thus this early return isn't leaking anything an attacker wouldn't + # already know if len(a) != len(b): return False + # We assume that integers in the bytes range are all cached, + # thus timing shouldn't vary much due to integer object creation result = 0 - if isinstance(a, bytes): - for x, y in zip(a, b): - result |= x ^ y - else: - for x, y in zip(a, b): - result |= ord(x) ^ ord(y) - + for x, y in zip(a, b): + result |= x ^ y return result == 0 diff --git a/Lib/idlelib/AutoComplete.py b/Lib/idlelib/AutoComplete.py index b38b108..929d358 100644 --- a/Lib/idlelib/AutoComplete.py +++ b/Lib/idlelib/AutoComplete.py @@ -140,7 +140,7 @@ class AutoComplete: elif hp.is_in_code() and (not mode or mode==COMPLETE_ATTRIBUTES): self._remove_autocomplete_window() mode = COMPLETE_ATTRIBUTES - while i and curline[i-1] in ID_CHARS or ord(curline[i-1]) > 127: + while i and (curline[i-1] in ID_CHARS or ord(curline[i-1]) > 127): i -= 1 comp_start = curline[i:j] if i and curline[i-1] == '.': diff --git a/Lib/mailbox.py b/Lib/mailbox.py index 7a29555..d86ad94 100644 --- a/Lib/mailbox.py +++ b/Lib/mailbox.py @@ -675,6 +675,7 @@ class _singlefileMailbox(Mailbox): new_file.write(buffer) new_toc[key] = (new_start, new_file.tell()) self._post_message_hook(new_file) + self._file_length = new_file.tell() except: new_file.close() os.remove(new_file.name) diff --git a/Lib/multiprocessing/__init__.py b/Lib/multiprocessing/__init__.py index 02460f0..1f3e67c 100644 --- a/Lib/multiprocessing/__init__.py +++ b/Lib/multiprocessing/__init__.py @@ -23,8 +23,8 @@ __all__ = [ 'Manager', 'Pipe', 'cpu_count', 'log_to_stderr', 'get_logger', 'allow_connection_pickling', 'BufferTooShort', 'TimeoutError', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', - 'Event', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool', 'Value', 'Array', - 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING', + 'Event', 'Barrier', 'Queue', 'SimpleQueue', 'JoinableQueue', 'Pool', + 'Value', 'Array', 'RawValue', 'RawArray', 'SUBDEBUG', 'SUBWARNING', ] __author__ = 'R. Oudkerk (r.m.oudkerk@gmail.com)' @@ -186,6 +186,13 @@ def Event(): from multiprocessing.synchronize import Event return Event() +def Barrier(parties, action=None, timeout=None): + ''' + Returns a barrier object + ''' + from multiprocessing.synchronize import Barrier + return Barrier(parties, action, timeout) + def Queue(maxsize=0): ''' Returns a queue object diff --git a/Lib/multiprocessing/dummy/__init__.py b/Lib/multiprocessing/dummy/__init__.py index 9bf8f6b..e31fc61 100644 --- a/Lib/multiprocessing/dummy/__init__.py +++ b/Lib/multiprocessing/dummy/__init__.py @@ -35,7 +35,7 @@ __all__ = [ 'Process', 'current_process', 'active_children', 'freeze_support', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', 'Condition', - 'Event', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' + 'Event', 'Barrier', 'Queue', 'Manager', 'Pipe', 'Pool', 'JoinableQueue' ] # @@ -49,7 +49,7 @@ import array from multiprocessing.dummy.connection import Pipe from threading import Lock, RLock, Semaphore, BoundedSemaphore -from threading import Event, Condition +from threading import Event, Condition, Barrier from queue import Queue # diff --git a/Lib/multiprocessing/forking.py b/Lib/multiprocessing/forking.py index 3a474cd..4baf548 100644 --- a/Lib/multiprocessing/forking.py +++ b/Lib/multiprocessing/forking.py @@ -13,7 +13,7 @@ import signal from multiprocessing import util, process -__all__ = ['Popen', 'assert_spawning', 'exit', 'duplicate', 'close', 'ForkingPickler'] +__all__ = ['Popen', 'assert_spawning', 'duplicate', 'close', 'ForkingPickler'] # # Check that the current thread is spawning a child process @@ -75,7 +75,6 @@ else: # if sys.platform != 'win32': - exit = os._exit duplicate = os.dup close = os.close @@ -168,7 +167,6 @@ else: WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") - exit = _winapi.ExitProcess close = _winapi.CloseHandle # @@ -349,7 +347,7 @@ else: from_parent.close() exitcode = self._bootstrap() - exit(exitcode) + sys.exit(exitcode) def get_preparation_data(name): diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py index 6a7dccb..f6611af 100644 --- a/Lib/multiprocessing/managers.py +++ b/Lib/multiprocessing/managers.py @@ -22,7 +22,7 @@ import queue from traceback import format_exc from multiprocessing import Process, current_process, active_children, Pool, util, connection from multiprocessing.process import AuthenticationString -from multiprocessing.forking import exit, Popen, ForkingPickler +from multiprocessing.forking import Popen, ForkingPickler from time import time as _time # @@ -140,28 +140,38 @@ class Server(object): self.id_to_obj = {'0': (None, ())} self.id_to_refcount = {} self.mutex = threading.RLock() - self.stop = 0 def serve_forever(self): ''' Run the server forever ''' + self.stop_event = threading.Event() current_process()._manager_server = self try: + accepter = threading.Thread(target=self.accepter) + accepter.daemon = True + accepter.start() try: - while 1: - try: - c = self.listener.accept() - except (OSError, IOError): - continue - t = threading.Thread(target=self.handle_request, args=(c,)) - t.daemon = True - t.start() + while not self.stop_event.is_set(): + self.stop_event.wait(1) except (KeyboardInterrupt, SystemExit): pass finally: - self.stop = 999 - self.listener.close() + if sys.stdout != sys.__stdout__: + util.debug('resetting stdout, stderr') + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + sys.exit(0) + + def accepter(self): + while True: + try: + c = self.listener.accept() + except (OSError, IOError): + continue + t = threading.Thread(target=self.handle_request, args=(c,)) + t.daemon = True + t.start() def handle_request(self, c): ''' @@ -208,7 +218,7 @@ class Server(object): send = conn.send id_to_obj = self.id_to_obj - while not self.stop: + while not self.stop_event.is_set(): try: methodname = obj = None @@ -318,32 +328,13 @@ class Server(object): Shutdown this process ''' try: - try: - util.debug('manager received shutdown message') - c.send(('#RETURN', None)) - - if sys.stdout != sys.__stdout__: - util.debug('resetting stdout, stderr') - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ - - util._run_finalizers(0) - - for p in active_children(): - util.debug('terminating a child process of manager') - p.terminate() - - for p in active_children(): - util.debug('terminating a child process of manager') - p.join() - - util._run_finalizers() - util.info('manager exiting with exitcode 0') - except: - import traceback - traceback.print_exc() + util.debug('manager received shutdown message') + c.send(('#RETURN', None)) + except: + import traceback + traceback.print_exc() finally: - exit(0) + self.stop_event.set() def create(self, c, typeid, *args, **kwds): ''' @@ -455,10 +446,6 @@ class BaseManager(object): self._serializer = serializer self._Listener, self._Client = listener_client[serializer] - def __reduce__(self): - return type(self).from_address, \ - (self._address, self._authkey, self._serializer) - def get_server(self): ''' Return server object with serve_forever() method and address attribute @@ -595,7 +582,7 @@ class BaseManager(object): except Exception: pass - process.join(timeout=0.2) + process.join(timeout=1.0) if process.is_alive(): util.info('manager still alive') if hasattr(process, 'terminate'): @@ -1006,6 +993,26 @@ class EventProxy(BaseProxy): def wait(self, timeout=None): return self._callmethod('wait', (timeout,)) + +class BarrierProxy(BaseProxy): + _exposed_ = ('__getattribute__', 'wait', 'abort', 'reset') + def wait(self, timeout=None): + return self._callmethod('wait', (timeout,)) + def abort(self): + return self._callmethod('abort') + def reset(self): + return self._callmethod('reset') + @property + def parties(self): + return self._callmethod('__getattribute__', ('parties',)) + @property + def n_waiting(self): + return self._callmethod('__getattribute__', ('n_waiting',)) + @property + def broken(self): + return self._callmethod('__getattribute__', ('broken',)) + + class NamespaceProxy(BaseProxy): _exposed_ = ('__getattribute__', '__setattr__', '__delattr__') def __getattr__(self, key): @@ -1097,6 +1104,7 @@ SyncManager.register('Semaphore', threading.Semaphore, AcquirerProxy) SyncManager.register('BoundedSemaphore', threading.BoundedSemaphore, AcquirerProxy) SyncManager.register('Condition', threading.Condition, ConditionProxy) +SyncManager.register('Barrier', threading.Barrier, BarrierProxy) SyncManager.register('Pool', Pool, PoolProxy) SyncManager.register('list', list, ListProxy) SyncManager.register('dict', dict, DictProxy) diff --git a/Lib/multiprocessing/synchronize.py b/Lib/multiprocessing/synchronize.py index 4502a97..22eabe5 100644 --- a/Lib/multiprocessing/synchronize.py +++ b/Lib/multiprocessing/synchronize.py @@ -333,3 +333,43 @@ class Event(object): return False finally: self._cond.release() + +# +# Barrier +# + +class Barrier(threading.Barrier): + + def __init__(self, parties, action=None, timeout=None): + import struct + from multiprocessing.heap import BufferWrapper + wrapper = BufferWrapper(struct.calcsize('i') * 2) + cond = Condition() + self.__setstate__((parties, action, timeout, cond, wrapper)) + self._state = 0 + self._count = 0 + + def __setstate__(self, state): + (self._parties, self._action, self._timeout, + self._cond, self._wrapper) = state + self._array = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._parties, self._action, self._timeout, + self._cond, self._wrapper) + + @property + def _state(self): + return self._array[0] + + @_state.setter + def _state(self, value): + self._array[0] = value + + @property + def _count(self): + return self._array[1] + + @_count.setter + def _count(self, value): + self._array[1] = value diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py index 48abe38..8a6aede 100644 --- a/Lib/multiprocessing/util.py +++ b/Lib/multiprocessing/util.py @@ -269,21 +269,24 @@ _exiting = False def _exit_function(): global _exiting - info('process shutting down') - debug('running all "atexit" finalizers with priority >= 0') - _run_finalizers(0) + if not _exiting: + _exiting = True - for p in active_children(): - if p._daemonic: - info('calling terminate() for daemon %s', p.name) - p._popen.terminate() + info('process shutting down') + debug('running all "atexit" finalizers with priority >= 0') + _run_finalizers(0) - for p in active_children(): - info('calling join() for process %s', p.name) - p.join() + for p in active_children(): + if p._daemonic: + info('calling terminate() for daemon %s', p.name) + p._popen.terminate() - debug('running the remaining "atexit" finalizers') - _run_finalizers() + for p in active_children(): + info('calling join() for process %s', p.name) + p.join() + + debug('running the remaining "atexit" finalizers') + _run_finalizers() atexit.register(_exit_function) diff --git a/Lib/test/support.py b/Lib/test/support.py index 6749a511..3ff1df5 100644 --- a/Lib/test/support.py +++ b/Lib/test/support.py @@ -1593,7 +1593,7 @@ def strip_python_stderr(stderr): This will typically be run on the result of the communicate() method of a subprocess.Popen object. """ - stderr = re.sub(br"\[\d+ refs\]\r?\n?$", b"", stderr).strip() + stderr = re.sub(br"\[\d+ refs\]\r?\n?", b"", stderr).strip() return stderr def args_from_interpreter_flags(): diff --git a/Lib/test/test_hmac.py b/Lib/test/test_hmac.py index 042bc5d..4e5961d 100644 --- a/Lib/test/test_hmac.py +++ b/Lib/test/test_hmac.py @@ -302,40 +302,42 @@ class CopyTestCase(unittest.TestCase): self.assertEqual(h1.hexdigest(), h2.hexdigest(), "Hexdigest of copy doesn't match original hexdigest.") -class SecureCompareTestCase(unittest.TestCase): +class CompareDigestTestCase(unittest.TestCase): def test_compare(self): # Testing input type exception handling a, b = 100, 200 - self.assertRaises(TypeError, hmac.secure_compare, a, b) - a, b = 100, "foobar" - self.assertRaises(TypeError, hmac.secure_compare, a, b) + self.assertRaises(TypeError, hmac.compare_digest, a, b) + a, b = 100, b"foobar" + self.assertRaises(TypeError, hmac.compare_digest, a, b) + a, b = b"foobar", 200 + self.assertRaises(TypeError, hmac.compare_digest, a, b) a, b = "foobar", b"foobar" - self.assertRaises(TypeError, hmac.secure_compare, a, b) + self.assertRaises(TypeError, hmac.compare_digest, a, b) + a, b = b"foobar", "foobar" + self.assertRaises(TypeError, hmac.compare_digest, a, b) + a, b = "foobar", "foobar" + self.assertRaises(TypeError, hmac.compare_digest, a, b) + a, b = bytearray(b"foobar"), bytearray(b"foobar") + self.assertRaises(TypeError, hmac.compare_digest, a, b) - # Testing str/bytes of different lengths - a, b = "foobar", "foo" - self.assertFalse(hmac.secure_compare(a, b)) + # Testing bytes of different lengths a, b = b"foobar", b"foo" - self.assertFalse(hmac.secure_compare(a, b)) + self.assertFalse(hmac.compare_digest(a, b)) a, b = b"\xde\xad\xbe\xef", b"\xde\xad" - self.assertFalse(hmac.secure_compare(a, b)) + self.assertFalse(hmac.compare_digest(a, b)) - # Testing str/bytes of same lengths, different values - a, b = "foobar", "foobaz" - self.assertFalse(hmac.secure_compare(a, b)) + # Testing bytes of same lengths, different values a, b = b"foobar", b"foobaz" - self.assertFalse(hmac.secure_compare(a, b)) + self.assertFalse(hmac.compare_digest(a, b)) a, b = b"\xde\xad\xbe\xef", b"\xab\xad\x1d\xea" - self.assertFalse(hmac.secure_compare(a, b)) + self.assertFalse(hmac.compare_digest(a, b)) - # Testing str/bytes of same lengths, same values - a, b = "foobar", "foobar" - self.assertTrue(hmac.secure_compare(a, b)) + # Testing bytes of same lengths, same values a, b = b"foobar", b"foobar" - self.assertTrue(hmac.secure_compare(a, b)) + self.assertTrue(hmac.compare_digest(a, b)) a, b = b"\xde\xad\xbe\xef", b"\xde\xad\xbe\xef" - self.assertTrue(hmac.secure_compare(a, b)) + self.assertTrue(hmac.compare_digest(a, b)) def test_main(): support.run_unittest( @@ -343,7 +345,7 @@ def test_main(): ConstructorTestCase, SanityTestCase, CopyTestCase, - SecureCompareTestCase + CompareDigestTestCase ) if __name__ == "__main__": diff --git a/Lib/test/test_mailbox.py b/Lib/test/test_mailbox.py index 0fe4bd9..cb54505 100644 --- a/Lib/test/test_mailbox.py +++ b/Lib/test/test_mailbox.py @@ -504,6 +504,17 @@ class TestMailbox(TestBase): # Write changes to disk self._test_flush_or_close(self._box.flush, True) + def test_popitem_and_flush_twice(self): + # See #15036. + self._box.add(self._template % 0) + self._box.add(self._template % 1) + self._box.flush() + + self._box.popitem() + self._box.flush() + self._box.popitem() + self._box.flush() + def test_lock_unlock(self): # Lock and unlock the mailbox self.assertFalse(os.path.exists(self._get_lock_path())) diff --git a/Lib/test/test_multiprocessing.py b/Lib/test/test_multiprocessing.py index 65e7b0b..2704827 100644 --- a/Lib/test/test_multiprocessing.py +++ b/Lib/test/test_multiprocessing.py @@ -18,6 +18,7 @@ import array import socket import random import logging +import struct import test.support @@ -1057,6 +1058,340 @@ class _TestEvent(BaseTestCase): self.assertEqual(wait(), True) # +# Tests for Barrier - adapted from tests in test/lock_tests.py +# + +# Many of the tests for threading.Barrier use a list as an atomic +# counter: a value is appended to increment the counter, and the +# length of the list gives the value. We use the class DummyList +# for the same purpose. + +class _DummyList(object): + + def __init__(self): + wrapper = multiprocessing.heap.BufferWrapper(struct.calcsize('i')) + lock = multiprocessing.Lock() + self.__setstate__((wrapper, lock)) + self._lengthbuf[0] = 0 + + def __setstate__(self, state): + (self._wrapper, self._lock) = state + self._lengthbuf = self._wrapper.create_memoryview().cast('i') + + def __getstate__(self): + return (self._wrapper, self._lock) + + def append(self, _): + with self._lock: + self._lengthbuf[0] += 1 + + def __len__(self): + with self._lock: + return self._lengthbuf[0] + +def _wait(): + # A crude wait/yield function not relying on synchronization primitives. + time.sleep(0.01) + + +class Bunch(object): + """ + A bunch of threads. + """ + def __init__(self, namespace, f, args, n, wait_before_exit=False): + """ + Construct a bunch of `n` threads running the same function `f`. + If `wait_before_exit` is True, the threads won't terminate until + do_finish() is called. + """ + self.f = f + self.args = args + self.n = n + self.started = namespace.DummyList() + self.finished = namespace.DummyList() + self._can_exit = namespace.Event() + if not wait_before_exit: + self._can_exit.set() + for i in range(n): + p = namespace.Process(target=self.task) + p.daemon = True + p.start() + + def task(self): + pid = os.getpid() + self.started.append(pid) + try: + self.f(*self.args) + finally: + self.finished.append(pid) + self._can_exit.wait(30) + assert self._can_exit.is_set() + + def wait_for_started(self): + while len(self.started) < self.n: + _wait() + + def wait_for_finished(self): + while len(self.finished) < self.n: + _wait() + + def do_finish(self): + self._can_exit.set() + + +class AppendTrue(object): + def __init__(self, obj): + self.obj = obj + def __call__(self): + self.obj.append(True) + + +class _TestBarrier(BaseTestCase): + """ + Tests for Barrier objects. + """ + N = 5 + defaultTimeout = 10.0 # XXX Slow Windows buildbots need generous timeout + + def setUp(self): + self.barrier = self.Barrier(self.N, timeout=self.defaultTimeout) + + def tearDown(self): + self.barrier.abort() + self.barrier = None + + def DummyList(self): + if self.TYPE == 'threads': + return [] + elif self.TYPE == 'manager': + return self.manager.list() + else: + return _DummyList() + + def run_threads(self, f, args): + b = Bunch(self, f, args, self.N-1) + f(*args) + b.wait_for_finished() + + @classmethod + def multipass(cls, barrier, results, n): + m = barrier.parties + assert m == cls.N + for i in range(n): + results[0].append(True) + assert len(results[1]) == i * m + barrier.wait() + results[1].append(True) + assert len(results[0]) == (i + 1) * m + barrier.wait() + try: + assert barrier.n_waiting == 0 + except NotImplementedError: + pass + assert not barrier.broken + + def test_barrier(self, passes=1): + """ + Test that a barrier is passed in lockstep + """ + results = [self.DummyList(), self.DummyList()] + self.run_threads(self.multipass, (self.barrier, results, passes)) + + def test_barrier_10(self): + """ + Test that a barrier works for 10 consecutive runs + """ + return self.test_barrier(10) + + @classmethod + def _test_wait_return_f(cls, barrier, queue): + res = barrier.wait() + queue.put(res) + + def test_wait_return(self): + """ + test the return value from barrier.wait + """ + queue = self.Queue() + self.run_threads(self._test_wait_return_f, (self.barrier, queue)) + results = [queue.get() for i in range(self.N)] + self.assertEqual(results.count(0), 1) + + @classmethod + def _test_action_f(cls, barrier, results): + barrier.wait() + if len(results) != 1: + raise RuntimeError + + def test_action(self): + """ + Test the 'action' callback + """ + results = self.DummyList() + barrier = self.Barrier(self.N, action=AppendTrue(results)) + self.run_threads(self._test_action_f, (barrier, results)) + self.assertEqual(len(results), 1) + + @classmethod + def _test_abort_f(cls, barrier, results1, results2): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + + def test_abort(self): + """ + Test that an abort will put the barrier in a broken state + """ + results1 = self.DummyList() + results2 = self.DummyList() + self.run_threads(self._test_abort_f, + (self.barrier, results1, results2)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertTrue(self.barrier.broken) + + @classmethod + def _test_reset_f(cls, barrier, results1, results2, results3): + i = barrier.wait() + if i == cls.N//2: + # Wait until the other threads are all in the barrier. + while barrier.n_waiting < cls.N-1: + time.sleep(0.001) + barrier.reset() + else: + try: + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + # Now, pass the barrier again + barrier.wait() + results3.append(True) + + def test_reset(self): + """ + Test that a 'reset' on a barrier frees the waiting threads + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + self.run_threads(self._test_reset_f, + (self.barrier, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_abort_and_reset_f(cls, barrier, barrier2, + results1, results2, results3): + try: + i = barrier.wait() + if i == cls.N//2: + raise RuntimeError + barrier.wait() + results1.append(True) + except threading.BrokenBarrierError: + results2.append(True) + except RuntimeError: + barrier.abort() + # Synchronize and reset the barrier. Must synchronize first so + # that everyone has left it when we reset, and after so that no + # one enters it before the reset. + if barrier2.wait() == cls.N//2: + barrier.reset() + barrier2.wait() + barrier.wait() + results3.append(True) + + def test_abort_and_reset(self): + """ + Test that a barrier can be reset after being broken. + """ + results1 = self.DummyList() + results2 = self.DummyList() + results3 = self.DummyList() + barrier2 = self.Barrier(self.N) + + self.run_threads(self._test_abort_and_reset_f, + (self.barrier, barrier2, results1, results2, results3)) + self.assertEqual(len(results1), 0) + self.assertEqual(len(results2), self.N-1) + self.assertEqual(len(results3), self.N) + + @classmethod + def _test_timeout_f(cls, barrier, results): + i = barrier.wait(20) + if i == cls.N//2: + # One thread is late! + time.sleep(4.0) + try: + barrier.wait(0.5) + except threading.BrokenBarrierError: + results.append(True) + + def test_timeout(self): + """ + Test wait(timeout) + """ + results = self.DummyList() + self.run_threads(self._test_timeout_f, (self.barrier, results)) + self.assertEqual(len(results), self.barrier.parties) + + @classmethod + def _test_default_timeout_f(cls, barrier, results): + i = barrier.wait(20) + if i == cls.N//2: + # One thread is later than the default timeout + time.sleep(4.0) + try: + barrier.wait() + except threading.BrokenBarrierError: + results.append(True) + + def test_default_timeout(self): + """ + Test the barrier's default timeout + """ + barrier = self.Barrier(self.N, timeout=1.0) + results = self.DummyList() + self.run_threads(self._test_default_timeout_f, (barrier, results)) + self.assertEqual(len(results), barrier.parties) + + def test_single_thread(self): + b = self.Barrier(1) + b.wait() + b.wait() + + @classmethod + def _test_thousand_f(cls, barrier, passes, conn, lock): + for i in range(passes): + barrier.wait() + with lock: + conn.send(i) + + def test_thousand(self): + if self.TYPE == 'manager': + return + passes = 1000 + lock = self.Lock() + conn, child_conn = self.Pipe(False) + for j in range(self.N): + p = self.Process(target=self._test_thousand_f, + args=(self.barrier, passes, child_conn, lock)) + p.start() + + for i in range(passes): + for j in range(self.N): + self.assertEqual(conn.recv(), i) + +# # # @@ -1485,6 +1820,11 @@ class _TestZZZNumberOfObjects(BaseTestCase): # run after all the other tests for the manager. It tests that # there have been no "reference leaks" for the manager's shared # objects. Note the comment in _TestPool.test_terminate(). + + # If some other test using ManagerMixin.manager fails, then the + # raised exception may keep alive a frame which holds a reference + # to a managed object. This will cause test_number_of_objects to + # also fail. ALLOWED_TYPES = ('manager',) def test_number_of_objects(self): @@ -1564,6 +1904,11 @@ class _TestMyManager(BaseTestCase): manager.shutdown() + # If the manager process exited cleanly then the exitcode + # will be zero. Otherwise (after a short timeout) + # terminate() is used, resulting in an exitcode of -SIGTERM. + self.assertEqual(manager._process.exitcode, 0) + # # Test of connecting to a remote server and using xmlrpclib for serialization # @@ -1923,7 +2268,7 @@ class _TestConnection(BaseTestCase): class _TestListener(BaseTestCase): - ALLOWED_TYPES = ('processes') + ALLOWED_TYPES = ('processes',) def test_multiple_bind(self): for family in self.connection.families: @@ -2505,10 +2850,12 @@ def create_test_cases(Mixin, type): result = {} glob = globals() Type = type.capitalize() + ALL_TYPES = {'processes', 'threads', 'manager'} for name in list(glob.keys()): if name.startswith('_Test'): base = glob[name] + assert set(base.ALLOWED_TYPES) <= ALL_TYPES, set(base.ALLOWED_TYPES) if type in base.ALLOWED_TYPES: newname = 'With' + Type + name[1:] class Temp(base, unittest.TestCase, Mixin): @@ -2527,7 +2874,7 @@ class ProcessesMixin(object): Process = multiprocessing.Process locals().update(get_attributes(multiprocessing, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'RawValue', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'RawValue', 'RawArray', 'current_process', 'active_children', 'Pipe', 'connection', 'JoinableQueue', 'Pool' ))) @@ -2542,7 +2889,7 @@ class ManagerMixin(object): manager = object.__new__(multiprocessing.managers.SyncManager) locals().update(get_attributes(manager, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'list', 'dict', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'list', 'dict', 'Namespace', 'JoinableQueue', 'Pool' ))) @@ -2555,7 +2902,7 @@ class ThreadsMixin(object): Process = multiprocessing.dummy.Process locals().update(get_attributes(multiprocessing.dummy, ( 'Queue', 'Lock', 'RLock', 'Semaphore', 'BoundedSemaphore', - 'Condition', 'Event', 'Value', 'Array', 'current_process', + 'Condition', 'Event', 'Barrier', 'Value', 'Array', 'current_process', 'active_children', 'Pipe', 'connection', 'dict', 'list', 'Namespace', 'JoinableQueue', 'Pool' ))) diff --git a/Lib/test/test_structseq.py b/Lib/test/test_structseq.py index d6c63b7..a89e955 100644 --- a/Lib/test/test_structseq.py +++ b/Lib/test/test_structseq.py @@ -78,8 +78,9 @@ class StructSeqTest(unittest.TestCase): def test_fields(self): t = time.gmtime() - self.assertEqual(len(t), t.n_fields) - self.assertEqual(t.n_fields, t.n_sequence_fields+t.n_unnamed_fields) + self.assertEqual(len(t), t.n_sequence_fields) + self.assertEqual(t.n_unnamed_fields, 0) + self.assertEqual(t.n_fields, time._STRUCT_TM_ITEMS) def test_constructor(self): t = time.struct_time diff --git a/Lib/test/test_time.py b/Lib/test/test_time.py index 02f05c3..63e1453 100644 --- a/Lib/test/test_time.py +++ b/Lib/test/test_time.py @@ -31,15 +31,14 @@ class TimeTestCase(unittest.TestCase): time.time() info = time.get_clock_info('time') self.assertFalse(info.monotonic) - if sys.platform != 'win32': - self.assertTrue(info.adjusted) + self.assertTrue(info.adjustable) def test_clock(self): time.clock() info = time.get_clock_info('clock') self.assertTrue(info.monotonic) - self.assertFalse(info.adjusted) + self.assertFalse(info.adjustable) @unittest.skipUnless(hasattr(time, 'clock_gettime'), 'need time.clock_gettime()') @@ -371,10 +370,7 @@ class TimeTestCase(unittest.TestCase): info = time.get_clock_info('monotonic') self.assertTrue(info.monotonic) - if sys.platform == 'linux': - self.assertTrue(info.adjusted) - else: - self.assertFalse(info.adjusted) + self.assertFalse(info.adjustable) def test_perf_counter(self): time.perf_counter() @@ -390,7 +386,7 @@ class TimeTestCase(unittest.TestCase): info = time.get_clock_info('process_time') self.assertTrue(info.monotonic) - self.assertFalse(info.adjusted) + self.assertFalse(info.adjustable) @unittest.skipUnless(hasattr(time, 'monotonic'), 'need time.monotonic') @@ -441,7 +437,7 @@ class TimeTestCase(unittest.TestCase): # 0.0 < resolution <= 1.0 self.assertGreater(info.resolution, 0.0) self.assertLessEqual(info.resolution, 1.0) - self.assertIsInstance(info.adjusted, bool) + self.assertIsInstance(info.adjustable, bool) self.assertRaises(ValueError, time.get_clock_info, 'xxx') @@ -624,7 +620,58 @@ class TestPytime(unittest.TestCase): for invalid in self.invalid_values: self.assertRaises(OverflowError, pytime_object_to_timespec, invalid) + @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support") + def test_localtime_timezone(self): + + # Get the localtime and examine it for the offset and zone. + lt = time.localtime() + self.assertTrue(hasattr(lt, "tm_gmtoff")) + self.assertTrue(hasattr(lt, "tm_zone")) + # See if the offset and zone are similar to the module + # attributes. + if lt.tm_gmtoff is None: + self.assertTrue(not hasattr(time, "timezone")) + else: + self.assertEqual(lt.tm_gmtoff, -[time.timezone, time.altzone][lt.tm_isdst]) + if lt.tm_zone is None: + self.assertTrue(not hasattr(time, "tzname")) + else: + self.assertEqual(lt.tm_zone, time.tzname[lt.tm_isdst]) + + # Try and make UNIX times from the localtime and a 9-tuple + # created from the localtime. Test to see that the times are + # the same. + t = time.mktime(lt); t9 = time.mktime(lt[:9]) + self.assertEqual(t, t9) + + # Make localtimes from the UNIX times and compare them to + # the original localtime, thus making a round trip. + new_lt = time.localtime(t); new_lt9 = time.localtime(t9) + self.assertEqual(new_lt, lt) + self.assertEqual(new_lt.tm_gmtoff, lt.tm_gmtoff) + self.assertEqual(new_lt.tm_zone, lt.tm_zone) + self.assertEqual(new_lt9, lt) + self.assertEqual(new_lt.tm_gmtoff, lt.tm_gmtoff) + self.assertEqual(new_lt9.tm_zone, lt.tm_zone) + + @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support") + def test_strptime_timezone(self): + t = time.strptime("UTC", "%Z") + self.assertEqual(t.tm_zone, 'UTC') + t = time.strptime("+0500", "%z") + self.assertEqual(t.tm_gmtoff, 5 * 3600) + + @unittest.skipUnless(time._STRUCT_TM_ITEMS == 11, "needs tm_zone support") + def test_short_times(self): + + import pickle + + # Load a short time structure using pickle. + st = b"ctime\nstruct_time\np0\n((I2007\nI8\nI11\nI1\nI24\nI49\nI5\nI223\nI1\ntp1\n(dp2\ntp3\nRp4\n." + lt = pickle.loads(st) + self.assertIs(lt.tm_gmtoff, None) + self.assertIs(lt.tm_zone, None) def test_main(): support.run_unittest( diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 49a5633..24ecae5 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -23,7 +23,8 @@ import weakref from test import support from test.support import findfile, import_fresh_module, gc_collect -pyET = import_fresh_module('xml.etree.ElementTree', blocked=['_elementtree']) +pyET = None +ET = None SIMPLE_XMLFILE = findfile("simple.xml", subdir="xmltestdata") try: @@ -209,10 +210,8 @@ def interface(): These methods return an iterable. See bug 6472. - >>> check_method(element.iter("tag").__next__) >>> check_method(element.iterfind("tag").__next__) >>> check_method(element.iterfind("*").__next__) - >>> check_method(tree.iter("tag").__next__) >>> check_method(tree.iterfind("tag").__next__) >>> check_method(tree.iterfind("*").__next__) @@ -291,42 +290,6 @@ def cdata(): '<tag>hello</tag>' """ -# Only with Python implementation -def simplefind(): - """ - Test find methods using the elementpath fallback. - - >>> ElementTree = pyET - - >>> CurrentElementPath = ElementTree.ElementPath - >>> ElementTree.ElementPath = ElementTree._SimpleElementPath() - >>> elem = ElementTree.XML(SAMPLE_XML) - >>> elem.find("tag").tag - 'tag' - >>> ElementTree.ElementTree(elem).find("tag").tag - 'tag' - >>> elem.findtext("tag") - 'text' - >>> elem.findtext("tog") - >>> elem.findtext("tog", "default") - 'default' - >>> ElementTree.ElementTree(elem).findtext("tag") - 'text' - >>> summarize_list(elem.findall("tag")) - ['tag', 'tag'] - >>> summarize_list(elem.findall(".//tag")) - ['tag', 'tag', 'tag'] - - Path syntax doesn't work in this case. - - >>> elem.find("section/tag") - >>> elem.findtext("section/tag") - >>> summarize_list(elem.findall("section/tag")) - [] - - >>> ElementTree.ElementPath = CurrentElementPath - """ - def find(): """ Test find methods (including xpath syntax). @@ -1002,36 +965,6 @@ def methods(): '1 < 2\n' """ -def iterators(): - """ - Test iterators. - - >>> e = ET.XML("<html><body>this is a <i>paragraph</i>.</body>..</html>") - >>> summarize_list(e.iter()) - ['html', 'body', 'i'] - >>> summarize_list(e.find("body").iter()) - ['body', 'i'] - >>> summarize(next(e.iter())) - 'html' - >>> "".join(e.itertext()) - 'this is a paragraph...' - >>> "".join(e.find("body").itertext()) - 'this is a paragraph.' - >>> next(e.itertext()) - 'this is a ' - - Method iterparse should return an iterator. See bug 6472. - - >>> sourcefile = serialize(e, to_string=False) - >>> next(ET.iterparse(sourcefile)) # doctest: +ELLIPSIS - ('end', <Element 'i' at 0x...>) - - >>> tree = ET.ElementTree(None) - >>> tree.iter() - Traceback (most recent call last): - AttributeError: 'NoneType' object has no attribute 'iter' - """ - ENTITY_XML = """\ <!DOCTYPE points [ <!ENTITY % user-entities SYSTEM 'user-entities.xml'> @@ -1339,6 +1272,7 @@ XINCLUDE["default.xml"] = """\ </document> """.format(html.escape(SIMPLE_XMLFILE, True)) + def xinclude_loader(href, parse="xml", encoding=None): try: data = XINCLUDE[href] @@ -1411,22 +1345,6 @@ def xinclude(): >>> # print(serialize(document)) # C5 """ -def xinclude_default(): - """ - >>> from xml.etree import ElementInclude - - >>> document = xinclude_loader("default.xml") - >>> ElementInclude.include(document) - >>> print(serialize(document)) # default - <document> - <p>Example.</p> - <root> - <element key="value">text</element> - <element>text</element>tail - <empty-element /> - </root> - </document> - """ # # badly formatted xi:include tags @@ -1917,9 +1835,8 @@ class ElementTreeTest(unittest.TestCase): self.assertIsInstance(ET.QName, type) self.assertIsInstance(ET.ElementTree, type) self.assertIsInstance(ET.Element, type) - # XXX issue 14128 with C ElementTree - # self.assertIsInstance(ET.TreeBuilder, type) - # self.assertIsInstance(ET.XMLParser, type) + self.assertIsInstance(ET.TreeBuilder, type) + self.assertIsInstance(ET.XMLParser, type) def test_Element_subclass_trivial(self): class MyElement(ET.Element): @@ -1953,6 +1870,73 @@ class ElementTreeTest(unittest.TestCase): self.assertEqual(mye.newmethod(), 'joe') +class ElementIterTest(unittest.TestCase): + def _ilist(self, elem, tag=None): + return summarize_list(elem.iter(tag)) + + def test_basic(self): + doc = ET.XML("<html><body>this is a <i>paragraph</i>.</body>..</html>") + self.assertEqual(self._ilist(doc), ['html', 'body', 'i']) + self.assertEqual(self._ilist(doc.find('body')), ['body', 'i']) + self.assertEqual(next(doc.iter()).tag, 'html') + self.assertEqual(''.join(doc.itertext()), 'this is a paragraph...') + self.assertEqual(''.join(doc.find('body').itertext()), + 'this is a paragraph.') + self.assertEqual(next(doc.itertext()), 'this is a ') + + # iterparse should return an iterator + sourcefile = serialize(doc, to_string=False) + self.assertEqual(next(ET.iterparse(sourcefile))[0], 'end') + + tree = ET.ElementTree(None) + self.assertRaises(AttributeError, tree.iter) + + def test_corners(self): + # single root, no subelements + a = ET.Element('a') + self.assertEqual(self._ilist(a), ['a']) + + # one child + b = ET.SubElement(a, 'b') + self.assertEqual(self._ilist(a), ['a', 'b']) + + # one child and one grandchild + c = ET.SubElement(b, 'c') + self.assertEqual(self._ilist(a), ['a', 'b', 'c']) + + # two children, only first with grandchild + d = ET.SubElement(a, 'd') + self.assertEqual(self._ilist(a), ['a', 'b', 'c', 'd']) + + # replace first child by second + a[0] = a[1] + del a[1] + self.assertEqual(self._ilist(a), ['a', 'd']) + + def test_iter_by_tag(self): + doc = ET.XML(''' + <document> + <house> + <room>bedroom1</room> + <room>bedroom2</room> + </house> + <shed>nothing here + </shed> + <house> + <room>bedroom8</room> + </house> + </document>''') + + self.assertEqual(self._ilist(doc, 'room'), ['room'] * 3) + self.assertEqual(self._ilist(doc, 'house'), ['house'] * 2) + + # make sure both tag=None and tag='*' return all tags + all_tags = ['document', 'house', 'room', 'room', + 'shed', 'house', 'room'] + self.assertEqual(self._ilist(doc), all_tags) + self.assertEqual(self._ilist(doc, '*'), all_tags) + + class TreeBuilderTest(unittest.TestCase): sample1 = ('<!DOCTYPE html PUBLIC' ' "-//W3C//DTD XHTML 1.0 Transitional//EN"' @@ -2027,6 +2011,23 @@ class TreeBuilderTest(unittest.TestCase): 'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd')) +@unittest.skip('Unstable due to module monkeypatching') +class XincludeTest(unittest.TestCase): + def test_xinclude_default(self): + from xml.etree import ElementInclude + doc = xinclude_loader('default.xml') + ElementInclude.include(doc) + s = serialize(doc) + self.assertEqual(s.strip(), '''<document> + <p>Example.</p> + <root> + <element key="value">text</element> + <element>text</element>tail + <empty-element /> +</root> +</document>''') + + class XMLParserTest(unittest.TestCase): sample1 = '<file><line>22</line></file>' sample2 = ('<!DOCTYPE html PUBLIC' @@ -2073,13 +2074,6 @@ class XMLParserTest(unittest.TestCase): 'http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd')) -class NoAcceleratorTest(unittest.TestCase): - # Test that the C accelerator was not imported for pyET - def test_correct_import_pyET(self): - self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree') - self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree') - - class NamespaceParseTest(unittest.TestCase): def test_find_with_namespace(self): nsmap = {'h': 'hello', 'f': 'foo'} @@ -2090,7 +2084,6 @@ class NamespaceParseTest(unittest.TestCase): self.assertEqual(len(doc.findall('.//{foo}name', nsmap)), 1) - class ElementSlicingTest(unittest.TestCase): def _elem_tags(self, elemlist): return [e.tag for e in elemlist] @@ -2232,6 +2225,14 @@ class KeywordArgsTest(unittest.TestCase): with self.assertRaisesRegex(TypeError, 'must be dict, not str'): ET.Element('a', attrib="I'm not a dict") +# -------------------------------------------------------------------- + +@unittest.skipUnless(pyET, 'only for the Python version') +class NoAcceleratorTest(unittest.TestCase): + # Test that the C accelerator was not imported for pyET + def test_correct_import_pyET(self): + self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree') + self.assertEqual(pyET.SubElement.__module__, 'xml.etree.ElementTree') # -------------------------------------------------------------------- @@ -2276,31 +2277,42 @@ class CleanContext(object): self.checkwarnings.__exit__(*args) -def test_main(module=pyET): - from test import test_xml_etree +def test_main(module=None): + # When invoked without a module, runs the Python ET tests by loading pyET. + # Otherwise, uses the given module as the ET. + if module is None: + global pyET + pyET = import_fresh_module('xml.etree.ElementTree', + blocked=['_elementtree']) + module = pyET - # The same doctests are used for both the Python and the C implementations - test_xml_etree.ET = module + global ET + ET = module test_classes = [ ElementSlicingTest, BasicElementTest, StringIOTest, ParseErrorTest, + XincludeTest, ElementTreeTest, - NamespaceParseTest, + ElementIterTest, TreeBuilderTest, - XMLParserTest, - KeywordArgsTest] - if module is pyET: - # Run the tests specific to the Python implementation - test_classes += [NoAcceleratorTest] + ] + + # These tests will only run for the pure-Python version that doesn't import + # _elementtree. We can't use skipUnless here, because pyET is filled in only + # after the module is loaded. + if pyET: + test_classes.extend([ + NoAcceleratorTest, + ]) support.run_unittest(*test_classes) # XXX the C module should give the same warnings as the Python module with CleanContext(quiet=(module is not pyET)): - support.run_doctest(test_xml_etree, verbosity=True) + support.run_doctest(sys.modules[__name__], verbosity=True) if __name__ == '__main__': test_main() diff --git a/Lib/test/test_xml_etree_c.py b/Lib/test/test_xml_etree_c.py index 10416d2..142a22f 100644 --- a/Lib/test/test_xml_etree_c.py +++ b/Lib/test/test_xml_etree_c.py @@ -8,31 +8,6 @@ cET = import_fresh_module('xml.etree.ElementTree', fresh=['_elementtree']) cET_alias = import_fresh_module('xml.etree.cElementTree', fresh=['_elementtree', 'xml.etree']) -# cElementTree specific tests - -def sanity(): - r""" - Import sanity. - - Issue #6697. - - >>> cElementTree = cET - >>> e = cElementTree.Element('a') - >>> getattr(e, '\uD800') # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - UnicodeEncodeError: ... - - >>> p = cElementTree.XMLParser() - >>> p.version.split()[0] - 'Expat' - >>> getattr(p, '\uD800') - Traceback (most recent call last): - ... - AttributeError: 'XMLParser' object has no attribute '\ud800' - """ - - class MiscTests(unittest.TestCase): # Issue #8651. @support.bigmemtest(size=support._2G + 100, memuse=1) @@ -46,6 +21,7 @@ class MiscTests(unittest.TestCase): finally: data = None + @unittest.skipUnless(cET, 'requires _elementtree') class TestAliasWorking(unittest.TestCase): # Test that the cET alias module is alive @@ -53,6 +29,7 @@ class TestAliasWorking(unittest.TestCase): e = cET_alias.Element('foo') self.assertEqual(e.tag, 'foo') + @unittest.skipUnless(cET, 'requires _elementtree') class TestAcceleratorImported(unittest.TestCase): # Test that the C accelerator was imported, as expected @@ -67,7 +44,6 @@ def test_main(): from test import test_xml_etree, test_xml_etree_c # Run the tests specific to the C implementation - support.run_doctest(test_xml_etree_c, verbosity=True) support.run_unittest( MiscTests, TestAliasWorking, diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index e068fc2..d30a83c 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -101,32 +101,8 @@ import sys import re import warnings -class _SimpleElementPath: - # emulate pre-1.2 find/findtext/findall behaviour - def find(self, element, tag, namespaces=None): - for elem in element: - if elem.tag == tag: - return elem - return None - def findtext(self, element, tag, default=None, namespaces=None): - elem = self.find(element, tag) - if elem is None: - return default - return elem.text or "" - def iterfind(self, element, tag, namespaces=None): - if tag[:3] == ".//": - for elem in element.iter(tag[3:]): - yield elem - for elem in element: - if elem.tag == tag: - yield elem - def findall(self, element, tag, namespaces=None): - return list(self.iterfind(element, tag, namespaces)) +from . import ElementPath -try: - from . import ElementPath -except ImportError: - ElementPath = _SimpleElementPath() ## # Parser error. This is a subclass of <b>SyntaxError</b>. @@ -916,11 +892,7 @@ def _namespaces(elem, default_namespace=None): _raise_serialization_error(qname) # populate qname and namespaces table - try: - iterate = elem.iter - except AttributeError: - iterate = elem.getiterator # cET compatibility - for elem in iterate(): + for elem in elem.iter(): tag = elem.tag if isinstance(tag, QName): if tag.text not in qnames: |