summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorR David Murray <rdmurray@bitdance.com>2016-09-09 22:39:18 (GMT)
committerR David Murray <rdmurray@bitdance.com>2016-09-09 22:39:18 (GMT)
commit06ed218ed0020003ac388572fbcf09b88075b664 (patch)
tree5a44ac4cb9d85a35fe9d8423e6700176f5de8352 /Lib
parent37df068e862c4bbab16da00de72655c4a737ea94 (diff)
downloadcpython-06ed218ed0020003ac388572fbcf09b88075b664.zip
cpython-06ed218ed0020003ac388572fbcf09b88075b664.tar.gz
cpython-06ed218ed0020003ac388572fbcf09b88075b664.tar.bz2
#20476: add a message_factory policy attribute to email.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/email/_policybase.py4
-rw-r--r--Lib/email/feedparser.py9
-rw-r--r--Lib/email/message.py8
-rw-r--r--Lib/email/policy.py2
-rw-r--r--Lib/test/test_email/test_parser.py91
-rw-r--r--Lib/test/test_email/test_policy.py42
6 files changed, 99 insertions, 57 deletions
diff --git a/Lib/email/_policybase.py b/Lib/email/_policybase.py
index c0d98a4..d699484 100644
--- a/Lib/email/_policybase.py
+++ b/Lib/email/_policybase.py
@@ -154,6 +154,8 @@ class Policy(_PolicyBase, metaclass=abc.ABCMeta):
them. This is used when the message is being
serialized by a generator. Default: True.
+ message_factory -- the class to use to create new message objects.
+
"""
raise_on_defect = False
@@ -161,6 +163,8 @@ class Policy(_PolicyBase, metaclass=abc.ABCMeta):
cte_type = '8bit'
max_line_length = 78
mangle_from_ = False
+ # XXX To avoid circular imports, this is set in email.message.
+ message_factory = None
def handle_defect(self, obj, defect):
"""Based on policy, either raise defect or call register_defect.
diff --git a/Lib/email/feedparser.py b/Lib/email/feedparser.py
index 2fa77d7..3d74978 100644
--- a/Lib/email/feedparser.py
+++ b/Lib/email/feedparser.py
@@ -24,7 +24,6 @@ __all__ = ['FeedParser', 'BytesFeedParser']
import re
from email import errors
-from email import message
from email._policybase import compat32
from collections import deque
from io import StringIO
@@ -148,13 +147,7 @@ class FeedParser:
self.policy = policy
self._old_style_factory = False
if _factory is None:
- # What this should be:
- #self._factory = policy.default_message_factory
- # but, because we are post 3.4 feature freeze, fix with temp hack:
- if self.policy is compat32:
- self._factory = message.Message
- else:
- self._factory = message.EmailMessage
+ self._factory = policy.message_factory
else:
self._factory = _factory
try:
diff --git a/Lib/email/message.py b/Lib/email/message.py
index c07da43..f4380d9 100644
--- a/Lib/email/message.py
+++ b/Lib/email/message.py
@@ -4,18 +4,17 @@
"""Basic message object for the email package object model."""
-__all__ = ['Message']
+__all__ = ['Message', 'EmailMessage']
import re
import uu
import quopri
-import warnings
from io import BytesIO, StringIO
# Intrapackage imports
from email import utils
from email import errors
-from email._policybase import compat32
+from email._policybase import Policy, compat32
from email import charset as _charset
from email._encoded_words import decode_b
Charset = _charset.Charset
@@ -1163,3 +1162,6 @@ class EmailMessage(MIMEPart):
super().set_content(*args, **kw)
if 'MIME-Version' not in self:
self['MIME-Version'] = '1.0'
+
+# Set message_factory on Policy here to avoid a circular import.
+Policy.message_factory = Message
diff --git a/Lib/email/policy.py b/Lib/email/policy.py
index 35d0e69..5131311ac 100644
--- a/Lib/email/policy.py
+++ b/Lib/email/policy.py
@@ -7,6 +7,7 @@ from email._policybase import Policy, Compat32, compat32, _extend_docstrings
from email.utils import _has_surrogates
from email.headerregistry import HeaderRegistry as HeaderRegistry
from email.contentmanager import raw_data_manager
+from email.message import EmailMessage
__all__ = [
'Compat32',
@@ -82,6 +83,7 @@ class EmailPolicy(Policy):
"""
+ message_factory = EmailMessage
utf8 = False
refold_source = 'long'
header_factory = HeaderRegistry()
diff --git a/Lib/test/test_email/test_parser.py b/Lib/test/test_email/test_parser.py
index 8ddc1763..06c8640 100644
--- a/Lib/test/test_email/test_parser.py
+++ b/Lib/test/test_email/test_parser.py
@@ -1,7 +1,7 @@
import io
import email
import unittest
-from email.message import Message
+from email.message import Message, EmailMessage
from email.policy import default
from test.test_email import TestEmailBase
@@ -39,38 +39,71 @@ class TestParserBase:
# The unicode line splitter splits on unicode linebreaks, which are
# more numerous than allowed by the email RFCs; make sure we are only
# splitting on those two.
- msg = self.parser(
- "Next-Line: not\x85broken\r\n"
- "Null: not\x00broken\r\n"
- "Vertical-Tab: not\vbroken\r\n"
- "Form-Feed: not\fbroken\r\n"
- "File-Separator: not\x1Cbroken\r\n"
- "Group-Separator: not\x1Dbroken\r\n"
- "Record-Separator: not\x1Ebroken\r\n"
- "Line-Separator: not\u2028broken\r\n"
- "Paragraph-Separator: not\u2029broken\r\n"
- "\r\n",
- policy=default,
- )
- self.assertEqual(msg.items(), [
- ("Next-Line", "not\x85broken"),
- ("Null", "not\x00broken"),
- ("Vertical-Tab", "not\vbroken"),
- ("Form-Feed", "not\fbroken"),
- ("File-Separator", "not\x1Cbroken"),
- ("Group-Separator", "not\x1Dbroken"),
- ("Record-Separator", "not\x1Ebroken"),
- ("Line-Separator", "not\u2028broken"),
- ("Paragraph-Separator", "not\u2029broken"),
- ])
- self.assertEqual(msg.get_payload(), "")
+ for parser in self.parsers:
+ with self.subTest(parser=parser.__name__):
+ msg = parser(
+ "Next-Line: not\x85broken\r\n"
+ "Null: not\x00broken\r\n"
+ "Vertical-Tab: not\vbroken\r\n"
+ "Form-Feed: not\fbroken\r\n"
+ "File-Separator: not\x1Cbroken\r\n"
+ "Group-Separator: not\x1Dbroken\r\n"
+ "Record-Separator: not\x1Ebroken\r\n"
+ "Line-Separator: not\u2028broken\r\n"
+ "Paragraph-Separator: not\u2029broken\r\n"
+ "\r\n",
+ policy=default,
+ )
+ self.assertEqual(msg.items(), [
+ ("Next-Line", "not\x85broken"),
+ ("Null", "not\x00broken"),
+ ("Vertical-Tab", "not\vbroken"),
+ ("Form-Feed", "not\fbroken"),
+ ("File-Separator", "not\x1Cbroken"),
+ ("Group-Separator", "not\x1Dbroken"),
+ ("Record-Separator", "not\x1Ebroken"),
+ ("Line-Separator", "not\u2028broken"),
+ ("Paragraph-Separator", "not\u2029broken"),
+ ])
+ self.assertEqual(msg.get_payload(), "")
+
+ class MyMessage(EmailMessage):
+ pass
+
+ def test_custom_message_factory_on_policy(self):
+ for parser in self.parsers:
+ with self.subTest(parser=parser.__name__):
+ MyPolicy = default.clone(message_factory=self.MyMessage)
+ msg = parser("To: foo\n\ntest", policy=MyPolicy)
+ self.assertIsInstance(msg, self.MyMessage)
+
+ def test_factory_arg_overrides_policy(self):
+ for parser in self.parsers:
+ with self.subTest(parser=parser.__name__):
+ MyPolicy = default.clone(message_factory=self.MyMessage)
+ msg = parser("To: foo\n\ntest", Message, policy=MyPolicy)
+ self.assertNotIsInstance(msg, self.MyMessage)
+ self.assertIsInstance(msg, Message)
+
+# Play some games to get nice output in subTest. This code could be clearer
+# if staticmethod supported __name__.
+
+def message_from_file(s, *args, **kw):
+ f = io.StringIO(s)
+ return email.message_from_file(f, *args, **kw)
class TestParser(TestParserBase, TestEmailBase):
- parser = staticmethod(email.message_from_string)
+ parsers = (email.message_from_string, message_from_file)
+
+def message_from_bytes(s, *args, **kw):
+ return email.message_from_bytes(s.encode(), *args, **kw)
+
+def message_from_binary_file(s, *args, **kw):
+ f = io.BytesIO(s.encode())
+ return email.message_from_binary_file(f, *args, **kw)
class TestBytesParser(TestParserBase, TestEmailBase):
- def parser(self, s, *args, **kw):
- return email.message_from_bytes(s.encode(), *args, **kw)
+ parsers = (message_from_bytes, message_from_binary_file)
if __name__ == '__main__':
diff --git a/Lib/test/test_email/test_policy.py b/Lib/test/test_email/test_policy.py
index 70ac4db..1d95d03 100644
--- a/Lib/test/test_email/test_policy.py
+++ b/Lib/test/test_email/test_policy.py
@@ -5,6 +5,7 @@ import unittest
import email.policy
import email.parser
import email.generator
+import email.message
from email import headerregistry
def make_defaults(base_defaults, differences):
@@ -23,6 +24,7 @@ class PolicyAPITests(unittest.TestCase):
'cte_type': '8bit',
'raise_on_defect': False,
'mangle_from_': True,
+ 'message_factory': email.message.Message,
}
# These default values are the ones set on email.policy.default.
# If any of these defaults change, the docs must be updated.
@@ -34,6 +36,7 @@ class PolicyAPITests(unittest.TestCase):
'refold_source': 'long',
'content_manager': email.policy.EmailPolicy.content_manager,
'mangle_from_': False,
+ 'message_factory': email.message.EmailMessage,
})
# For each policy under test, we give here what we expect the defaults to
@@ -62,20 +65,22 @@ class PolicyAPITests(unittest.TestCase):
def test_defaults(self):
for policy, expected in self.policies.items():
for attr, value in expected.items():
- self.assertEqual(getattr(policy, attr), value,
- ("change {} docs/docstrings if defaults have "
- "changed").format(policy))
+ with self.subTest(policy=policy, attr=attr):
+ self.assertEqual(getattr(policy, attr), value,
+ ("change {} docs/docstrings if defaults have "
+ "changed").format(policy))
def test_all_attributes_covered(self):
for policy, expected in self.policies.items():
for attr in dir(policy):
- if (attr.startswith('_') or
- isinstance(getattr(email.policy.EmailPolicy, attr),
- types.FunctionType)):
- continue
- else:
- self.assertIn(attr, expected,
- "{} is not fully tested".format(attr))
+ with self.subTest(policy=policy, attr=attr):
+ if (attr.startswith('_') or
+ isinstance(getattr(email.policy.EmailPolicy, attr),
+ types.FunctionType)):
+ continue
+ else:
+ self.assertIn(attr, expected,
+ "{} is not fully tested".format(attr))
def test_abc(self):
with self.assertRaises(TypeError) as cm:
@@ -237,6 +242,9 @@ class PolicyAPITests(unittest.TestCase):
# wins), but that the order still works (right overrides left).
+class TestException(Exception):
+ pass
+
class TestPolicyPropagation(unittest.TestCase):
# The abstract methods are used by the parser but not by the wrapper
@@ -244,40 +252,40 @@ class TestPolicyPropagation(unittest.TestCase):
# policy was actually propagated all the way to feedparser.
class MyPolicy(email.policy.Policy):
def badmethod(self, *args, **kw):
- raise Exception("test")
+ raise TestException("test")
fold = fold_binary = header_fetch_parser = badmethod
header_source_parse = header_store_parse = badmethod
def test_message_from_string(self):
- with self.assertRaisesRegex(Exception, "^test$"):
+ with self.assertRaisesRegex(TestException, "^test$"):
email.message_from_string("Subject: test\n\n",
policy=self.MyPolicy)
def test_message_from_bytes(self):
- with self.assertRaisesRegex(Exception, "^test$"):
+ with self.assertRaisesRegex(TestException, "^test$"):
email.message_from_bytes(b"Subject: test\n\n",
policy=self.MyPolicy)
def test_message_from_file(self):
f = io.StringIO('Subject: test\n\n')
- with self.assertRaisesRegex(Exception, "^test$"):
+ with self.assertRaisesRegex(TestException, "^test$"):
email.message_from_file(f, policy=self.MyPolicy)
def test_message_from_binary_file(self):
f = io.BytesIO(b'Subject: test\n\n')
- with self.assertRaisesRegex(Exception, "^test$"):
+ with self.assertRaisesRegex(TestException, "^test$"):
email.message_from_binary_file(f, policy=self.MyPolicy)
# These are redundant, but we need them for black-box completeness.
def test_parser(self):
p = email.parser.Parser(policy=self.MyPolicy)
- with self.assertRaisesRegex(Exception, "^test$"):
+ with self.assertRaisesRegex(TestException, "^test$"):
p.parsestr('Subject: test\n\n')
def test_bytes_parser(self):
p = email.parser.BytesParser(policy=self.MyPolicy)
- with self.assertRaisesRegex(Exception, "^test$"):
+ with self.assertRaisesRegex(TestException, "^test$"):
p.parsebytes(b'Subject: test\n\n')
# Now that we've established that all the parse methods get the