summaryrefslogtreecommitdiffstats
path: root/Lib/unittest
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest')
-rw-r--r--Lib/unittest/mock.py79
-rw-r--r--Lib/unittest/test/testmock/testmock.py24
-rw-r--r--Lib/unittest/test/testmock/testwith.py4
3 files changed, 68 insertions, 39 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py
index 74f918a..3fbe846 100644
--- a/Lib/unittest/mock.py
+++ b/Lib/unittest/mock.py
@@ -2278,6 +2278,24 @@ def mock_open(mock=None, read_data=''):
`read_data` is a string for the `read` methoddline`, and `readlines` of the
file handle to return. This is an empty string by default.
"""
+ def _readlines_side_effect(*args, **kwargs):
+ if handle.readlines.return_value is not None:
+ return handle.readlines.return_value
+ return list(_state[0])
+
+ def _read_side_effect(*args, **kwargs):
+ if handle.read.return_value is not None:
+ return handle.read.return_value
+ return ''.join(_state[0])
+
+ def _readline_side_effect():
+ if handle.readline.return_value is not None:
+ while True:
+ yield handle.readline.return_value
+ for line in _state[0]:
+ yield line
+
+
global file_spec
if file_spec is None:
import _io
@@ -2286,42 +2304,31 @@ def mock_open(mock=None, read_data=''):
if mock is None:
mock = MagicMock(name='open', spec=open)
- def make_handle(*args, **kwargs):
- # Arg checking is handled by __call__
- def _readlines_side_effect(*args, **kwargs):
- if handle.readlines.return_value is not None:
- return handle.readlines.return_value
- return list(_data)
-
- def _read_side_effect(*args, **kwargs):
- if handle.read.return_value is not None:
- return handle.read.return_value
- return ''.join(_data)
-
- def _readline_side_effect():
- if handle.readline.return_value is not None:
- while True:
- yield handle.readline.return_value
- for line in _data:
- yield line
-
- handle = MagicMock(spec=file_spec)
- handle.__enter__.return_value = handle
-
- _data = _iterate_read_data(read_data)
-
- handle.write.return_value = None
- handle.read.return_value = None
- handle.readline.return_value = None
- handle.readlines.return_value = None
-
- handle.read.side_effect = _read_side_effect
- handle.readline.side_effect = _readline_side_effect()
- handle.readlines.side_effect = _readlines_side_effect
- _check_and_set_parent(mock, handle, None, '()')
- return handle
-
- mock.side_effect = make_handle
+ handle = MagicMock(spec=file_spec)
+ handle.__enter__.return_value = handle
+
+ _state = [_iterate_read_data(read_data), None]
+
+ handle.write.return_value = None
+ handle.read.return_value = None
+ handle.readline.return_value = None
+ handle.readlines.return_value = None
+
+ handle.read.side_effect = _read_side_effect
+ _state[1] = _readline_side_effect()
+ handle.readline.side_effect = _state[1]
+ handle.readlines.side_effect = _readlines_side_effect
+
+ def reset_data(*args, **kwargs):
+ _state[0] = _iterate_read_data(read_data)
+ if handle.readline.side_effect == _state[1]:
+ # Only reset the side effect if the user hasn't overridden it.
+ _state[1] = _readline_side_effect()
+ handle.readline.side_effect = _state[1]
+ return DEFAULT
+
+ mock.side_effect = reset_data
+ mock.return_value = handle
return mock
diff --git a/Lib/unittest/test/testmock/testmock.py b/Lib/unittest/test/testmock/testmock.py
index 32703e6..976c40f 100644
--- a/Lib/unittest/test/testmock/testmock.py
+++ b/Lib/unittest/test/testmock/testmock.py
@@ -1,5 +1,6 @@
import copy
import sys
+import tempfile
import unittest
from unittest.test.testmock.support import is_instance
@@ -1329,8 +1330,29 @@ class MockTest(unittest.TestCase):
def test_mock_open_reuse_issue_21750(self):
mocked_open = mock.mock_open(read_data='data')
f1 = mocked_open('a-name')
+ f1_data = f1.read()
f2 = mocked_open('another-name')
- self.assertEqual(f1.read(), f2.read())
+ f2_data = f2.read()
+ self.assertEqual(f1_data, f2_data)
+
+ def test_mock_open_write(self):
+ # Test exception in file writing write()
+ mock_namedtemp = mock.mock_open(mock.MagicMock(name='JLV'))
+ with mock.patch('tempfile.NamedTemporaryFile', mock_namedtemp):
+ mock_filehandle = mock_namedtemp.return_value
+ mock_write = mock_filehandle.write
+ mock_write.side_effect = OSError('Test 2 Error')
+ def attempt():
+ tempfile.NamedTemporaryFile().write('asd')
+ self.assertRaises(OSError, attempt)
+
+ def test_mock_open_alter_readline(self):
+ mopen = mock.mock_open(read_data='foo\nbarn')
+ mopen.return_value.readline.side_effect = lambda *args:'abc'
+ first = mopen().readline()
+ second = mopen().readline()
+ self.assertEqual('abc', first)
+ self.assertEqual('abc', second)
def test_mock_parents(self):
for Klass in Mock, MagicMock:
diff --git a/Lib/unittest/test/testmock/testwith.py b/Lib/unittest/test/testmock/testwith.py
index ddcfe77..b6bfb75 100644
--- a/Lib/unittest/test/testmock/testwith.py
+++ b/Lib/unittest/test/testmock/testwith.py
@@ -141,6 +141,7 @@ class TestMockOpen(unittest.TestCase):
def test_mock_open_context_manager(self):
mock = mock_open()
+ handle = mock.return_value
with patch('%s.open' % __name__, mock, create=True):
with open('foo') as f:
f.read()
@@ -148,8 +149,7 @@ class TestMockOpen(unittest.TestCase):
expected_calls = [call('foo'), call().__enter__(), call().read(),
call().__exit__(None, None, None)]
self.assertEqual(mock.mock_calls, expected_calls)
- # mock_open.return_value is no longer static, because
- # readline support requires that it mutate state
+ self.assertIs(f, handle)
def test_mock_open_context_manager_multiple_times(self):
mock = mock_open()