summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/contextlib.py138
-rw-r--r--Lib/decimal.py29
-rw-r--r--Lib/test/contextmanager.py33
-rw-r--r--Lib/test/nested.py40
-rw-r--r--Lib/test/test_contextlib.py240
-rw-r--r--Lib/test/test_with.py84
-rw-r--r--Lib/threading.py29
-rw-r--r--Misc/NEWS9
-rw-r--r--Modules/threadmodule.c36
-rw-r--r--Objects/fileobject.c40
-rw-r--r--Python/ceval.c26
-rw-r--r--Python/errors.c1
12 files changed, 609 insertions, 96 deletions
diff --git a/Lib/contextlib.py b/Lib/contextlib.py
new file mode 100644
index 0000000..33d83a6
--- /dev/null
+++ b/Lib/contextlib.py
@@ -0,0 +1,138 @@
+"""Utilities for with-statement contexts. See PEP 343."""
+
+import sys
+
+__all__ = ["contextmanager", "nested", "closing"]
+
+class GeneratorContextManager(object):
+ """Helper for @contextmanager decorator."""
+
+ def __init__(self, gen):
+ self.gen = gen
+
+ def __context__(self):
+ return self
+
+ def __enter__(self):
+ try:
+ return self.gen.next()
+ except StopIteration:
+ raise RuntimeError("generator didn't yield")
+
+ def __exit__(self, type, value, traceback):
+ if type is None:
+ try:
+ self.gen.next()
+ except StopIteration:
+ return
+ else:
+ raise RuntimeError("generator didn't stop")
+ else:
+ try:
+ self.gen.throw(type, value, traceback)
+ except StopIteration:
+ pass
+
+
+def contextmanager(func):
+ """@contextmanager decorator.
+
+ Typical usage:
+
+ @contextmanager
+ def some_generator(<arguments>):
+ <setup>
+ try:
+ yield <value>
+ finally:
+ <cleanup>
+
+ This makes this:
+
+ with some_generator(<arguments>) as <variable>:
+ <body>
+
+ equivalent to this:
+
+ <setup>
+ try:
+ <variable> = <value>
+ <body>
+ finally:
+ <cleanup>
+
+ """
+ def helper(*args, **kwds):
+ return GeneratorContextManager(func(*args, **kwds))
+ try:
+ helper.__name__ = func.__name__
+ helper.__doc__ = func.__doc__
+ except:
+ pass
+ return helper
+
+
+@contextmanager
+def nested(*contexts):
+ """Support multiple context managers in a single with-statement.
+
+ Code like this:
+
+ with nested(A, B, C) as (X, Y, Z):
+ <body>
+
+ is equivalent to this:
+
+ with A as X:
+ with B as Y:
+ with C as Z:
+ <body>
+
+ """
+ exits = []
+ vars = []
+ exc = (None, None, None)
+ try:
+ try:
+ for context in contexts:
+ mgr = context.__context__()
+ exit = mgr.__exit__
+ enter = mgr.__enter__
+ vars.append(enter())
+ exits.append(exit)
+ yield vars
+ except:
+ exc = sys.exc_info()
+ finally:
+ while exits:
+ exit = exits.pop()
+ try:
+ exit(*exc)
+ except:
+ exc = sys.exc_info()
+ if exc != (None, None, None):
+ raise
+
+
+@contextmanager
+def closing(thing):
+ """Context manager to automatically close something at the end of a block.
+
+ Code like this:
+
+ with closing(<module>.open(<arguments>)) as f:
+ <block>
+
+ is equivalent to this:
+
+ f = <module>.open(<arguments>)
+ try:
+ <block>
+ finally:
+ f.close()
+
+ """
+ try:
+ yield thing
+ finally:
+ thing.close()
diff --git a/Lib/decimal.py b/Lib/decimal.py
index 677d26b..49f8115 100644
--- a/Lib/decimal.py
+++ b/Lib/decimal.py
@@ -2173,6 +2173,32 @@ for name in rounding_functions:
del name, val, globalname, rounding_functions
+class ContextManager(object):
+ """Helper class to simplify Context management.
+
+ Sample usage:
+
+ with decimal.ExtendedContext:
+ s = ...
+ return +s # Convert result to normal precision
+
+ with decimal.getcontext() as ctx:
+ ctx.prec += 2
+ s = ...
+ return +s
+
+ """
+ def __init__(self, new_context):
+ self.new_context = new_context
+ def __enter__(self):
+ self.saved_context = getcontext()
+ setcontext(self.new_context)
+ return self.new_context
+ def __exit__(self, t, v, tb):
+ setcontext(self.saved_context)
+ if t is not None:
+ raise t, v, tb
+
class Context(object):
"""Contains the context for a Decimal instance.
@@ -2224,6 +2250,9 @@ class Context(object):
s.append('traps=[' + ', '.join([t.__name__ for t, v in self.traps.items() if v]) + ']')
return ', '.join(s) + ')'
+ def __context__(self):
+ return ContextManager(self.copy())
+
def clear_flags(self):
"""Reset all flags to zero"""
for flag in self.flags:
diff --git a/Lib/test/contextmanager.py b/Lib/test/contextmanager.py
deleted file mode 100644
index 07fe61c..0000000
--- a/Lib/test/contextmanager.py
+++ /dev/null
@@ -1,33 +0,0 @@
-class GeneratorContextManager(object):
- def __init__(self, gen):
- self.gen = gen
-
- def __context__(self):
- return self
-
- def __enter__(self):
- try:
- return self.gen.next()
- except StopIteration:
- raise RuntimeError("generator didn't yield")
-
- def __exit__(self, type, value, traceback):
- if type is None:
- try:
- self.gen.next()
- except StopIteration:
- return
- else:
- raise RuntimeError("generator didn't stop")
- else:
- try:
- self.gen.throw(type, value, traceback)
- except (type, StopIteration):
- return
- else:
- raise RuntimeError("generator caught exception")
-
-def contextmanager(func):
- def helper(*args, **kwds):
- return GeneratorContextManager(func(*args, **kwds))
- return helper
diff --git a/Lib/test/nested.py b/Lib/test/nested.py
deleted file mode 100644
index b5030e0..0000000
--- a/Lib/test/nested.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import sys
-from collections import deque
-
-
-class nested(object):
- def __init__(self, *contexts):
- self.contexts = contexts
- self.entered = None
-
- def __context__(self):
- return self
-
- def __enter__(self):
- if self.entered is not None:
- raise RuntimeError("Context is not reentrant")
- self.entered = deque()
- vars = []
- try:
- for context in self.contexts:
- mgr = context.__context__()
- vars.append(mgr.__enter__())
- self.entered.appendleft(mgr)
- except:
- self.__exit__(*sys.exc_info())
- raise
- return vars
-
- def __exit__(self, *exc_info):
- # Behave like nested with statements
- # first in, last out
- # New exceptions override old ones
- ex = exc_info
- for mgr in self.entered:
- try:
- mgr.__exit__(*ex)
- except:
- ex = sys.exc_info()
- self.entered = None
- if ex is not exc_info:
- raise ex[0], ex[1], ex[2]
diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py
new file mode 100644
index 0000000..8c8d887
--- /dev/null
+++ b/Lib/test/test_contextlib.py
@@ -0,0 +1,240 @@
+"""Unit tests for contextlib.py, and other context managers."""
+
+from __future__ import with_statement
+
+import os
+import decimal
+import tempfile
+import unittest
+import threading
+from contextlib import * # Tests __all__
+
+class ContextManagerTestCase(unittest.TestCase):
+
+ def test_contextmanager_plain(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ yield 42
+ state.append(999)
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ self.assertEqual(state, [1, 42, 999])
+
+ def test_contextmanager_finally(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ finally:
+ state.append(999)
+ try:
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError()
+ except ZeroDivisionError:
+ pass
+ else:
+ self.fail("Expected ZeroDivisionError")
+ self.assertEqual(state, [1, 42, 999])
+
+ def test_contextmanager_except(self):
+ state = []
+ @contextmanager
+ def woohoo():
+ state.append(1)
+ try:
+ yield 42
+ except ZeroDivisionError, e:
+ state.append(e.args[0])
+ self.assertEqual(state, [1, 42, 999])
+ with woohoo() as x:
+ self.assertEqual(state, [1])
+ self.assertEqual(x, 42)
+ state.append(x)
+ raise ZeroDivisionError(999)
+ self.assertEqual(state, [1, 42, 999])
+
+class NestedTestCase(unittest.TestCase):
+
+ # XXX This needs more work
+
+ def test_nested(self):
+ @contextmanager
+ def a():
+ yield 1
+ @contextmanager
+ def b():
+ yield 2
+ @contextmanager
+ def c():
+ yield 3
+ with nested(a(), b(), c()) as (x, y, z):
+ self.assertEqual(x, 1)
+ self.assertEqual(y, 2)
+ self.assertEqual(z, 3)
+
+ def test_nested_cleanup(self):
+ state = []
+ @contextmanager
+ def a():
+ state.append(1)
+ try:
+ yield 2
+ finally:
+ state.append(3)
+ @contextmanager
+ def b():
+ state.append(4)
+ try:
+ yield 5
+ finally:
+ state.append(6)
+ try:
+ with nested(a(), b()) as (x, y):
+ state.append(x)
+ state.append(y)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(state, [1, 4, 2, 5, 6, 3])
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+class ClosingTestCase(unittest.TestCase):
+
+ # XXX This needs more work
+
+ def test_closing(self):
+ state = []
+ class C:
+ def close(self):
+ state.append(1)
+ x = C()
+ self.assertEqual(state, [])
+ with closing(x) as y:
+ self.assertEqual(x, y)
+ self.assertEqual(state, [1])
+
+ def test_closing_error(self):
+ state = []
+ class C:
+ def close(self):
+ state.append(1)
+ x = C()
+ self.assertEqual(state, [])
+ try:
+ with closing(x) as y:
+ self.assertEqual(x, y)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(state, [1])
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+class FileContextTestCase(unittest.TestCase):
+
+ def testWithOpen(self):
+ tfn = tempfile.mktemp()
+ try:
+ f = None
+ with open(tfn, "w") as f:
+ self.failIf(f.closed)
+ f.write("Booh\n")
+ self.failUnless(f.closed)
+ f = None
+ try:
+ with open(tfn, "r") as f:
+ self.failIf(f.closed)
+ self.assertEqual(f.read(), "Booh\n")
+ 1/0
+ except ZeroDivisionError:
+ self.failUnless(f.closed)
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+ finally:
+ try:
+ os.remove(tfn)
+ except os.error:
+ pass
+
+class LockContextTestCase(unittest.TestCase):
+
+ def boilerPlate(self, lock, locked):
+ self.failIf(locked())
+ with lock:
+ self.failUnless(locked())
+ self.failIf(locked())
+ try:
+ with lock:
+ self.failUnless(locked())
+ 1/0
+ except ZeroDivisionError:
+ self.failIf(locked())
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+ def testWithLock(self):
+ lock = threading.Lock()
+ self.boilerPlate(lock, lock.locked)
+
+ def testWithRLock(self):
+ lock = threading.RLock()
+ self.boilerPlate(lock, lock._is_owned)
+
+ def testWithCondition(self):
+ lock = threading.Condition()
+ def locked():
+ return lock._is_owned()
+ self.boilerPlate(lock, locked)
+
+ def testWithSemaphore(self):
+ lock = threading.Semaphore()
+ def locked():
+ if lock.acquire(False):
+ lock.release()
+ return False
+ else:
+ return True
+ self.boilerPlate(lock, locked)
+
+ def testWithBoundedSemaphore(self):
+ lock = threading.BoundedSemaphore()
+ def locked():
+ if lock.acquire(False):
+ lock.release()
+ return False
+ else:
+ return True
+ self.boilerPlate(lock, locked)
+
+class DecimalContextTestCase(unittest.TestCase):
+
+ # XXX Somebody should write more thorough tests for this
+
+ def testBasic(self):
+ ctx = decimal.getcontext()
+ ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
+ with decimal.ExtendedContext:
+ self.assertEqual(decimal.getcontext().prec,
+ decimal.ExtendedContext.prec)
+ self.assertEqual(decimal.getcontext().prec, save_prec)
+ try:
+ with decimal.ExtendedContext:
+ self.assertEqual(decimal.getcontext().prec,
+ decimal.ExtendedContext.prec)
+ 1/0
+ except ZeroDivisionError:
+ self.assertEqual(decimal.getcontext().prec, save_prec)
+ else:
+ self.fail("Didn't raise ZeroDivisionError")
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py
index ed072c9..36035e3 100644
--- a/Lib/test/test_with.py
+++ b/Lib/test/test_with.py
@@ -7,9 +7,10 @@ from __future__ import with_statement
__author__ = "Mike Bland"
__email__ = "mbland at acm dot org"
+import sys
import unittest
-from test.contextmanager import GeneratorContextManager
-from test.nested import nested
+from collections import deque
+from contextlib import GeneratorContextManager, contextmanager
from test.test_support import run_unittest
@@ -57,9 +58,48 @@ def mock_contextmanager_generator():
mock.stopped = True
-class MockNested(nested):
+class Nested(object):
+
+ def __init__(self, *contexts):
+ self.contexts = contexts
+ self.entered = None
+
+ def __context__(self):
+ return self
+
+ def __enter__(self):
+ if self.entered is not None:
+ raise RuntimeError("Context is not reentrant")
+ self.entered = deque()
+ vars = []
+ try:
+ for context in self.contexts:
+ mgr = context.__context__()
+ vars.append(mgr.__enter__())
+ self.entered.appendleft(mgr)
+ except:
+ self.__exit__(*sys.exc_info())
+ raise
+ return vars
+
+ def __exit__(self, *exc_info):
+ # Behave like nested with statements
+ # first in, last out
+ # New exceptions override old ones
+ ex = exc_info
+ for mgr in self.entered:
+ try:
+ mgr.__exit__(*ex)
+ except:
+ ex = sys.exc_info()
+ self.entered = None
+ if ex is not exc_info:
+ raise ex[0], ex[1], ex[2]
+
+
+class MockNested(Nested):
def __init__(self, *contexts):
- nested.__init__(self, *contexts)
+ Nested.__init__(self, *contexts)
self.context_called = False
self.enter_called = False
self.exit_called = False
@@ -67,16 +107,16 @@ class MockNested(nested):
def __context__(self):
self.context_called = True
- return nested.__context__(self)
+ return Nested.__context__(self)
def __enter__(self):
self.enter_called = True
- return nested.__enter__(self)
+ return Nested.__enter__(self)
def __exit__(self, *exc_info):
self.exit_called = True
self.exit_args = exc_info
- return nested.__exit__(self, *exc_info)
+ return Nested.__exit__(self, *exc_info)
class FailureTestCase(unittest.TestCase):
@@ -294,7 +334,7 @@ class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
class NestedNonexceptionalTestCase(unittest.TestCase,
ContextmanagerAssertionMixin):
def testSingleArgInlineGeneratorSyntax(self):
- with nested(mock_contextmanager_generator()):
+ with Nested(mock_contextmanager_generator()):
pass
def testSingleArgUnbound(self):
@@ -310,7 +350,7 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
m = mock_contextmanager_generator()
# This will bind all the arguments to nested() into a single list
# assigned to foo.
- with nested(m) as foo:
+ with Nested(m) as foo:
self.assertInWithManagerInvariants(m)
self.assertAfterWithManagerInvariantsNoError(m)
@@ -318,14 +358,13 @@ class NestedNonexceptionalTestCase(unittest.TestCase,
m = mock_contextmanager_generator()
# This will bind all the arguments to nested() into a single list
# assigned to foo.
- # FIXME: what should this do: with nested(m) as (foo,):
- with nested(m) as (foo):
+ with Nested(m) as (foo):
self.assertInWithManagerInvariants(m)
self.assertAfterWithManagerInvariantsNoError(m)
def testSingleArgBoundToMultipleElementTupleError(self):
def shouldThrowValueError():
- with nested(mock_contextmanager_generator()) as (foo, bar):
+ with Nested(mock_contextmanager_generator()) as (foo, bar):
pass
self.assertRaises(ValueError, shouldThrowValueError)
@@ -535,7 +574,9 @@ class AssignmentTargetTestCase(unittest.TestCase):
class C:
def __context__(self): return self
def __enter__(self): return 1, 2, 3
- def __exit__(self, *a): pass
+ def __exit__(self, t, v, tb):
+ if t is not None:
+ raise t, v, tb
targets = {1: [0, 1, 2]}
with C() as (targets[1][0], targets[1][1], targets[1][2]):
self.assertEqual(targets, {1: [1, 2, 3]})
@@ -551,11 +592,26 @@ class AssignmentTargetTestCase(unittest.TestCase):
self.assertEqual(blah.three, 3)
+class ExitSwallowsExceptionTestCase(unittest.TestCase):
+
+ def testExitSwallowsException(self):
+ class AfricanOrEuropean:
+ def __context__(self): return self
+ def __enter__(self): pass
+ def __exit__(self, t, v, tb): pass
+ try:
+ with AfricanOrEuropean():
+ 1/0
+ except ZeroDivisionError:
+ self.fail("ZeroDivisionError should have been swallowed")
+
+
def test_main():
run_unittest(FailureTestCase, NonexceptionalTestCase,
NestedNonexceptionalTestCase, ExceptionalTestCase,
NonLocalFlowControlTestCase,
- AssignmentTargetTestCase)
+ AssignmentTargetTestCase,
+ ExitSwallowsExceptionTestCase)
if __name__ == '__main__':
diff --git a/Lib/threading.py b/Lib/threading.py
index 9cc108e..5b485d5 100644
--- a/Lib/threading.py
+++ b/Lib/threading.py
@@ -90,6 +90,9 @@ class _RLock(_Verbose):
self.__owner and self.__owner.getName(),
self.__count)
+ def __context__(self):
+ return self
+
def acquire(self, blocking=1):
me = currentThread()
if self.__owner is me:
@@ -108,6 +111,8 @@ class _RLock(_Verbose):
self._note("%s.acquire(%s): failure", self, blocking)
return rc
+ __enter__ = acquire
+
def release(self):
me = currentThread()
assert self.__owner is me, "release() of un-acquire()d lock"
@@ -121,6 +126,11 @@ class _RLock(_Verbose):
if __debug__:
self._note("%s.release(): non-final release", self)
+ def __exit__(self, t, v, tb):
+ self.release()
+ if t is not None:
+ raise t, v, tb
+
# Internal methods used by condition variables
def _acquire_restore(self, (count, owner)):
@@ -156,6 +166,7 @@ class _Condition(_Verbose):
self.__lock = lock
# Export the lock's acquire() and release() methods
self.acquire = lock.acquire
+ self.__enter__ = self.acquire
self.release = lock.release
# If the lock defines _release_save() and/or _acquire_restore(),
# these override the default implementations (which just call
@@ -174,6 +185,14 @@ class _Condition(_Verbose):
pass
self.__waiters = []
+ def __context__(self):
+ return self
+
+ def __exit__(self, t, v, tb):
+ self.release()
+ if t is not None:
+ raise t, v, tb
+
def __repr__(self):
return "<Condition(%s, %d)>" % (self.__lock, len(self.__waiters))
@@ -267,6 +286,9 @@ class _Semaphore(_Verbose):
self.__cond = Condition(Lock())
self.__value = value
+ def __context__(self):
+ return self
+
def acquire(self, blocking=1):
rc = False
self.__cond.acquire()
@@ -286,6 +308,8 @@ class _Semaphore(_Verbose):
self.__cond.release()
return rc
+ __enter__ = acquire
+
def release(self):
self.__cond.acquire()
self.__value = self.__value + 1
@@ -295,6 +319,11 @@ class _Semaphore(_Verbose):
self.__cond.notify()
self.__cond.release()
+ def __exit__(self, t, v, tb):
+ self.release()
+ if t is not None:
+ raise t, v, tb
+
def BoundedSemaphore(*args, **kwargs):
return _BoundedSemaphore(*args, **kwargs)
diff --git a/Misc/NEWS b/Misc/NEWS
index 5599ba5..51e0aef 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -36,6 +36,12 @@ Core and builtins
with_statement``. Use of 'with' as a variable will generate a warning.
Use of 'as' as a variable will also generate a warning (unless it's
part of an import statement).
+ The following objects have __context__ methods:
+ - The built-in file type.
+ - The thread.LockType type.
+ - The following types defined by the threading module:
+ Lock, RLock, Condition, Semaphore, BoundedSemaphore.
+ - The decimal.Context class.
- Fix the encodings package codec search function to only search
inside its own package. Fixes problem reported in patch #1433198.
@@ -411,6 +417,9 @@ Extension Modules
Library
-------
+- PEP 343: new module contextlib.py defines decorator @contextmanager
+ and helpful context managers nested() and closing().
+
- The compiler package now supports future imports after the module docstring.
- Bug #1413790: zipfile now sanitizes absolute archive names that are
diff --git a/Modules/threadmodule.c b/Modules/threadmodule.c
index fccdd62..b0f7700 100644
--- a/Modules/threadmodule.c
+++ b/Modules/threadmodule.c
@@ -116,6 +116,36 @@ PyDoc_STRVAR(locked_doc,
\n\
Return whether the lock is in the locked state.");
+static PyObject *
+lock_context(lockobject *self)
+{
+ Py_INCREF(self);
+ return (PyObject *)self;
+}
+
+PyDoc_STRVAR(lock_exit_doc,
+"__exit__(type, value, tb)\n\
+\n\
+Releases the lock; then re-raises the exception if type is not None.");
+
+static PyObject *
+lock_exit(lockobject *self, PyObject *args)
+{
+ PyObject *type, *value, *tb, *result;
+ if (!PyArg_ParseTuple(args, "OOO:__exit__", &type, &value, &tb))
+ return NULL;
+ result = lock_PyThread_release_lock(self);
+ if (result != NULL && type != Py_None) {
+ Py_DECREF(result);
+ result = NULL;
+ Py_INCREF(type);
+ Py_INCREF(value);
+ Py_INCREF(tb);
+ PyErr_Restore(type, value, tb);
+ }
+ return result;
+}
+
static PyMethodDef lock_methods[] = {
{"acquire_lock", (PyCFunction)lock_PyThread_acquire_lock,
METH_VARARGS, acquire_doc},
@@ -129,6 +159,12 @@ static PyMethodDef lock_methods[] = {
METH_NOARGS, locked_doc},
{"locked", (PyCFunction)lock_locked_lock,
METH_NOARGS, locked_doc},
+ {"__context__", (PyCFunction)lock_context,
+ METH_NOARGS, PyDoc_STR("__context__() -> self.")},
+ {"__enter__", (PyCFunction)lock_PyThread_acquire_lock,
+ METH_VARARGS, acquire_doc},
+ {"__exit__", (PyCFunction)lock_exit,
+ METH_VARARGS, lock_exit_doc},
{NULL, NULL} /* sentinel */
};
diff --git a/Objects/fileobject.c b/Objects/fileobject.c
index d535869..b39a10f 100644
--- a/Objects/fileobject.c
+++ b/Objects/fileobject.c
@@ -1609,7 +1609,7 @@ file_writelines(PyFileObject *f, PyObject *seq)
}
static PyObject *
-file_getiter(PyFileObject *f)
+file_self(PyFileObject *f)
{
if (f->f_fp == NULL)
return err_closed();
@@ -1617,6 +1617,24 @@ file_getiter(PyFileObject *f)
return (PyObject *)f;
}
+static PyObject *
+file_exit(PyFileObject *f, PyObject *args)
+{
+ PyObject *type, *value, *tb, *result;
+ if (!PyArg_ParseTuple(args, "OOO:__exit__", &type, &value, &tb))
+ return NULL;
+ result = file_close(f);
+ if (result != NULL && type != Py_None) {
+ Py_DECREF(result);
+ result = NULL;
+ Py_INCREF(type);
+ Py_INCREF(value);
+ Py_INCREF(tb);
+ PyErr_Restore(type, value, tb);
+ }
+ return result;
+}
+
PyDoc_STRVAR(readline_doc,
"readline([size]) -> next line from the file, as a string.\n"
"\n"
@@ -1701,6 +1719,19 @@ PyDoc_STRVAR(close_doc,
PyDoc_STRVAR(isatty_doc,
"isatty() -> true or false. True if the file is connected to a tty device.");
+PyDoc_STRVAR(context_doc,
+ "__context__() -> self.");
+
+PyDoc_STRVAR(enter_doc,
+ "__enter__() -> self.");
+
+PyDoc_STRVAR(exit_doc,
+"__exit__(type, value, traceback).\n\
+\n\
+Closes the file; then re-raises the exception if type is not None.\n\
+If no exception is re-raised, the return value is the same as for close().\n\
+");
+
static PyMethodDef file_methods[] = {
{"readline", (PyCFunction)file_readline, METH_VARARGS, readline_doc},
{"read", (PyCFunction)file_read, METH_VARARGS, read_doc},
@@ -1713,11 +1744,14 @@ static PyMethodDef file_methods[] = {
{"tell", (PyCFunction)file_tell, METH_NOARGS, tell_doc},
{"readinto", (PyCFunction)file_readinto, METH_VARARGS, readinto_doc},
{"readlines", (PyCFunction)file_readlines,METH_VARARGS, readlines_doc},
- {"xreadlines",(PyCFunction)file_getiter, METH_NOARGS, xreadlines_doc},
+ {"xreadlines",(PyCFunction)file_self, METH_NOARGS, xreadlines_doc},
{"writelines",(PyCFunction)file_writelines, METH_O, writelines_doc},
{"flush", (PyCFunction)file_flush, METH_NOARGS, flush_doc},
{"close", (PyCFunction)file_close, METH_NOARGS, close_doc},
{"isatty", (PyCFunction)file_isatty, METH_NOARGS, isatty_doc},
+ {"__context__", (PyCFunction)file_self, METH_NOARGS, context_doc},
+ {"__enter__", (PyCFunction)file_self, METH_NOARGS, enter_doc},
+ {"__exit__", (PyCFunction)file_exit, METH_VARARGS, exit_doc},
{NULL, NULL} /* sentinel */
};
@@ -2044,7 +2078,7 @@ PyTypeObject PyFile_Type = {
0, /* tp_clear */
0, /* tp_richcompare */
offsetof(PyFileObject, weakreflist), /* tp_weaklistoffset */
- (getiterfunc)file_getiter, /* tp_iter */
+ (getiterfunc)file_self, /* tp_iter */
(iternextfunc)file_iternext, /* tp_iternext */
file_methods, /* tp_methods */
file_memberlist, /* tp_members */
diff --git a/Python/ceval.c b/Python/ceval.c
index 3732f6d..3ef853e 100644
--- a/Python/ceval.c
+++ b/Python/ceval.c
@@ -2200,23 +2200,37 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throw)
The code here just sets the stack up for the call;
separate CALL_FUNCTION(3) and POP_TOP opcodes are
emitted by the compiler.
+
+ In addition, if the stack represents an exception,
+ we "zap" this information; __exit__() should
+ re-raise the exception if it wants to, and if
+ __exit__() returns normally, END_FINALLY should
+ *not* re-raise the exception. (But non-local
+ gotos should still be resumed.)
*/
x = TOP();
u = SECOND();
if (PyInt_Check(u) || u == Py_None) {
u = v = w = Py_None;
+ Py_INCREF(u);
+ Py_INCREF(v);
+ Py_INCREF(w);
}
else {
v = THIRD();
w = FOURTH();
+ /* Zap the exception from the stack,
+ to fool END_FINALLY. */
+ STACKADJ(-2);
+ SET_TOP(x);
+ Py_INCREF(Py_None);
+ SET_SECOND(Py_None);
}
- Py_INCREF(u);
- Py_INCREF(v);
- Py_INCREF(w);
- PUSH(u);
- PUSH(v);
- PUSH(w);
+ STACKADJ(3);
+ SET_THIRD(u);
+ SET_SECOND(v);
+ SET_TOP(w);
break;
}
diff --git a/Python/errors.c b/Python/errors.c
index cbcc6fa..c33bd13 100644
--- a/Python/errors.c
+++ b/Python/errors.c
@@ -24,6 +24,7 @@ PyErr_Restore(PyObject *type, PyObject *value, PyObject *traceback)
if (traceback != NULL && !PyTraceBack_Check(traceback)) {
/* XXX Should never happen -- fatal error instead? */
+ /* Well, it could be None. */
Py_DECREF(traceback);
traceback = NULL;
}