summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
Diffstat (limited to 'Lib')
-rw-r--r--Lib/sre.py4
-rw-r--r--Lib/sre_compile.py10
-rw-r--r--Lib/test/test_re.py10
3 files changed, 17 insertions, 7 deletions
diff --git a/Lib/sre.py b/Lib/sre.py
index 7e107a6..7910c83 100644
--- a/Lib/sre.py
+++ b/Lib/sre.py
@@ -219,9 +219,9 @@ def _compile(*key):
if p is not None:
return p
pattern, flags = key
- if type(pattern) is _pattern_type:
+ if isinstance(pattern, _pattern_type):
return pattern
- if type(pattern) not in sre_compile.STRING_TYPES:
+ if not isinstance(pattern, sre_compile.STRING_TYPES):
raise TypeError, "first argument must be string or compiled pattern"
try:
p = sre_compile.compile(pattern, flags)
diff --git a/Lib/sre_compile.py b/Lib/sre_compile.py
index 1d59d7e..96f337a 100644
--- a/Lib/sre_compile.py
+++ b/Lib/sre_compile.py
@@ -428,12 +428,12 @@ def _compile_info(code, pattern, flags):
_compile_charset(charset, flags, code)
code[skip] = len(code) - skip
-STRING_TYPES = [type("")]
-
try:
- STRING_TYPES.append(type(unicode("")))
+ unicode
except NameError:
- pass
+ STRING_TYPES = type("")
+else:
+ STRING_TYPES = (type(""), type(unicode("")))
def _code(p, flags):
@@ -453,7 +453,7 @@ def _code(p, flags):
def compile(p, flags=0):
# internal: convert pattern list to internal format
- if type(p) in STRING_TYPES:
+ if isinstance(p, STRING_TYPES):
import sre_parse
pattern = p
p = sre_parse.parse(p, flags)
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index 50d7ed4..f724806 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -474,6 +474,16 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.match('(a)((?!(b)*))*', 'abb').groups(),
('a', None, None))
+ def test_bug_764548(self):
+ # bug 764548, re.compile() barfs on str/unicode subclasses
+ try:
+ unicode
+ except NameError:
+ return # no problem if we have no unicode
+ class my_unicode(unicode): pass
+ pat = re.compile(my_unicode("abc"))
+ self.assertEqual(pat.match("xyz"), None)
+
def test_finditer(self):
iter = re.finditer(r":+", "a:b::c:::d")
self.assertEqual([item.group(0) for item in iter],