From 12723baceab61f8812d68575c962696cc4e77fa1 Mon Sep 17 00:00:00 2001
From: Just van Rossum <just@letterror.com>
Date: Wed, 2 Jul 2003 20:03:04 +0000
Subject: Fix and test for bug #764548: Use isinstance() instead of comparing
 types directly, to enable subclasses of str and unicode to be used as
 patterns. Blessed by /F.

---
 Lib/sre.py          |  4 ++--
 Lib/sre_compile.py  | 10 +++++-----
 Lib/test/test_re.py | 10 ++++++++++
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/Lib/sre.py b/Lib/sre.py
index 7e107a6..7910c83 100644
--- a/Lib/sre.py
+++ b/Lib/sre.py
@@ -219,9 +219,9 @@ def _compile(*key):
     if p is not None:
         return p
     pattern, flags = key
-    if type(pattern) is _pattern_type:
+    if isinstance(pattern, _pattern_type):
         return pattern
-    if type(pattern) not in sre_compile.STRING_TYPES:
+    if not isinstance(pattern, sre_compile.STRING_TYPES):
         raise TypeError, "first argument must be string or compiled pattern"
     try:
         p = sre_compile.compile(pattern, flags)
diff --git a/Lib/sre_compile.py b/Lib/sre_compile.py
index 1d59d7e..96f337a 100644
--- a/Lib/sre_compile.py
+++ b/Lib/sre_compile.py
@@ -428,12 +428,12 @@ def _compile_info(code, pattern, flags):
         _compile_charset(charset, flags, code)
     code[skip] = len(code) - skip
 
-STRING_TYPES = [type("")]
-
 try:
-    STRING_TYPES.append(type(unicode("")))
+    unicode
 except NameError:
-    pass
+    STRING_TYPES = type("")
+else:
+    STRING_TYPES = (type(""), type(unicode("")))
 
 def _code(p, flags):
 
@@ -453,7 +453,7 @@ def _code(p, flags):
 def compile(p, flags=0):
     # internal: convert pattern list to internal format
 
-    if type(p) in STRING_TYPES:
+    if isinstance(p, STRING_TYPES):
         import sre_parse
         pattern = p
         p = sre_parse.parse(p, flags)
diff --git a/Lib/test/test_re.py b/Lib/test/test_re.py
index 50d7ed4..f724806 100644
--- a/Lib/test/test_re.py
+++ b/Lib/test/test_re.py
@@ -474,6 +474,16 @@ class ReTests(unittest.TestCase):
         self.assertEqual(re.match('(a)((?!(b)*))*', 'abb').groups(),
                          ('a', None, None))
 
+    def test_bug_764548(self):
+        # bug 764548: re.compile() raised TypeError on str/unicode subclasses
+        try:
+            unicode
+        except NameError:
+            return  # no problem if we have no unicode
+        class my_unicode(unicode): pass
+        pat = re.compile(my_unicode("abc"))
+        self.assertEqual(pat.match("xyz"), None)
+
     def test_finditer(self):
         iter = re.finditer(r":+", "a:b::c:::d")
         self.assertEqual([item.group(0) for item in iter],
-- 
cgit v0.12