summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/dbm/__init__.py42
-rw-r--r--Lib/dbm/bsd.py3
-rw-r--r--Lib/test/test_dbm.py51
-rw-r--r--Modules/_dbmmodule.c3
-rw-r--r--Modules/_gdbmmodule.c2
5 files changed, 50 insertions, 51 deletions
diff --git a/Lib/dbm/__init__.py b/Lib/dbm/__init__.py
index 9fdd414..2082e07 100644
--- a/Lib/dbm/__init__.py
+++ b/Lib/dbm/__init__.py
@@ -48,27 +48,26 @@ class error(Exception):
pass
_names = ['dbm.bsd', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb']
-_errors = [error]
_defaultmod = None
_modules = {}
-for _name in _names:
- try:
- _mod = __import__(_name, fromlist=['open'])
- except ImportError:
- continue
- if not _defaultmod:
- _defaultmod = _mod
- _modules[_name] = _mod
- _errors.append(_mod.error)
-
-if not _defaultmod:
- raise ImportError("no dbm clone found; tried %s" % _names)
-
-error = tuple(_errors)
+error = (error, IOError)
def open(file, flag = 'r', mode = 0o666):
+ global _defaultmod
+ if _defaultmod is None:
+ for name in _names:
+ try:
+ mod = __import__(name, fromlist=['open'])
+ except ImportError:
+ continue
+ if not _defaultmod:
+ _defaultmod = mod
+ _modules[name] = mod
+ if not _defaultmod:
+ raise ImportError("no dbm clone found; tried %s" % _names)
+
# guess the type of an existing database
result = whichdb(file)
if result is None:
@@ -81,19 +80,14 @@ def open(file, flag = 'r', mode = 0o666):
elif result == "":
# db type cannot be determined
raise error("db type could not be determined")
+ elif result not in _modules:
+ raise error("db type is {0}, but the module is not "
+ "available".format(result))
else:
mod = _modules[result]
return mod.open(file, flag, mode)
-try:
- from dbm import ndbm
- _dbmerror = ndbm.error
-except ImportError:
- ndbm = None
- # just some sort of valid exception which might be raised in the ndbm test
- _dbmerror = IOError
-
def whichdb(filename):
"""Guess which db package to use to open a db file.
@@ -129,7 +123,7 @@ def whichdb(filename):
d = ndbm.open(filename)
d.close()
return "dbm.ndbm"
- except (IOError, _dbmerror):
+ except IOError:
pass
# Check for dumbdbm next -- this has a .dir and a .dat file
diff --git a/Lib/dbm/bsd.py b/Lib/dbm/bsd.py
index 8353f50..2dccadb 100644
--- a/Lib/dbm/bsd.py
+++ b/Lib/dbm/bsd.py
@@ -4,7 +4,8 @@ import bsddb
__all__ = ["error", "open"]
-error = bsddb.error
+class error(bsddb.error, IOError):
+ pass
def open(file, flag = 'r', mode=0o666):
return bsddb.hashopen(file, flag, mode)
diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py
index aab1388..41c37cb 100644
--- a/Lib/test/test_dbm.py
+++ b/Lib/test/test_dbm.py
@@ -14,11 +14,13 @@ _fname = test.support.TESTFN
# setting dbm to use each in turn, and yielding that module
#
def dbm_iterator():
- old_default = dbm._defaultmod
- for module in dbm._modules.values():
- dbm._defaultmod = module
- yield module
- dbm._defaultmod = old_default
+ for name in dbm._names:
+ try:
+ mod = __import__(name, fromlist=['open'])
+ except ImportError:
+ continue
+ dbm._modules[name] = mod
+ yield mod
#
# Clean up all scratch databases we might have created during testing
@@ -40,8 +42,20 @@ class AnyDBMTestCase(unittest.TestCase):
'g': b'intended',
}
- def __init__(self, *args):
- unittest.TestCase.__init__(self, *args)
+ def init_db(self):
+ f = dbm.open(_fname, 'n')
+ for k in self._dict:
+ f[k.encode("ascii")] = self._dict[k]
+ f.close()
+
+ def keys_helper(self, f):
+ keys = sorted(k.decode("ascii") for k in f.keys())
+ dkeys = sorted(self._dict.keys())
+ self.assertEqual(keys, dkeys)
+ return keys
+
+ def test_error(self):
+ self.assert_(issubclass(self.module.error, IOError))
def test_anydbm_creation(self):
f = dbm.open(_fname, 'c')
@@ -83,22 +97,11 @@ class AnyDBMTestCase(unittest.TestCase):
for key in self._dict:
self.assertEqual(self._dict[key], f[key.encode("ascii")])
- def init_db(self):
- f = dbm.open(_fname, 'n')
- for k in self._dict:
- f[k.encode("ascii")] = self._dict[k]
- f.close()
-
- def keys_helper(self, f):
- keys = sorted(k.decode("ascii") for k in f.keys())
- dkeys = sorted(self._dict.keys())
- self.assertEqual(keys, dkeys)
- return keys
-
def tearDown(self):
delete_files()
def setUp(self):
+ dbm._defaultmod = self.module
delete_files()
@@ -137,11 +140,11 @@ class WhichDBTestCase(unittest.TestCase):
def test_main():
- try:
- for module in dbm_iterator():
- test.support.run_unittest(AnyDBMTestCase, WhichDBTestCase)
- finally:
- delete_files()
+ classes = [WhichDBTestCase]
+ for mod in dbm_iterator():
+ classes.append(type("TestCase-" + mod.__name__, (AnyDBMTestCase,),
+ {'module': mod}))
+ test.support.run_unittest(*classes)
if __name__ == "__main__":
test_main()
diff --git a/Modules/_dbmmodule.c b/Modules/_dbmmodule.c
index ddfd4cd..7e80381 100644
--- a/Modules/_dbmmodule.c
+++ b/Modules/_dbmmodule.c
@@ -401,7 +401,8 @@ init_dbm(void) {
return;
d = PyModule_GetDict(m);
if (DbmError == NULL)
- DbmError = PyErr_NewException("_dbm.error", NULL, NULL);
+ DbmError = PyErr_NewException("_dbm.error",
+ PyExc_IOError, NULL);
s = PyUnicode_FromString(which_dbm);
if (s != NULL) {
PyDict_SetItemString(d, "library", s);
diff --git a/Modules/_gdbmmodule.c b/Modules/_gdbmmodule.c
index 6c75819..abc8837 100644
--- a/Modules/_gdbmmodule.c
+++ b/Modules/_gdbmmodule.c
@@ -523,7 +523,7 @@ init_gdbm(void) {
if (m == NULL)
return;
d = PyModule_GetDict(m);
- DbmError = PyErr_NewException("_gdbm.error", NULL, NULL);
+ DbmError = PyErr_NewException("_gdbm.error", PyExc_IOError, NULL);
if (DbmError != NULL) {
PyDict_SetItemString(d, "error", DbmError);
s = PyUnicode_FromString(dbmmodule_open_flags);