summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--database/__init__.py0
-rw-r--r--database/common.py8
-rw-r--r--database/database_connection.py250
-rw-r--r--database/database_connection_unittest.py186
-rw-r--r--frontend/migrations/001_initial_db.py4
-rwxr-xr-xmigrate/migrate.py105
-rw-r--r--migrate/migrate_unittest.py50
-rw-r--r--scheduler/monitor_db_unittest.py8
-rwxr-xr-xtko/migrations/001_initial_db.py4
9 files changed, 500 insertions, 115 deletions
diff --git a/database/__init__.py b/database/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/database/__init__.py
diff --git a/database/common.py b/database/common.py
new file mode 100644
index 00000000..9941b190
--- /dev/null
+++ b/database/common.py
@@ -0,0 +1,8 @@
+import os, sys
+dirname = os.path.dirname(sys.modules[__name__].__file__)
+autotest_dir = os.path.abspath(os.path.join(dirname, ".."))
+client_dir = os.path.join(autotest_dir, "client")
+sys.path.insert(0, client_dir)
+import setup_modules
+sys.path.pop(0)
+setup_modules.setup(base_path=autotest_dir, root_module_name="autotest_lib")
diff --git a/database/database_connection.py b/database/database_connection.py
new file mode 100644
index 00000000..6d617287
--- /dev/null
+++ b/database/database_connection.py
@@ -0,0 +1,250 @@
+import traceback, time
+import common
+from autotest_lib.client.common_lib import global_config
+
+RECONNECT_FOREVER = object()
+
+_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
+_GLOBAL_CONFIG_NAMES = {
+ 'username' : 'user',
+ 'db_name' : 'database',
+}
+
+def _copy_exceptions(source, destination):
+ for exception_name in _DB_EXCEPTIONS:
+ setattr(destination, exception_name, getattr(source, exception_name))
+
+
+class _GenericBackend(object):
+ def __init__(self, database_module):
+ self._database_module = database_module
+ self._connection = None
+ self._cursor = None
+ self.rowcount = None
+ _copy_exceptions(database_module, self)
+
+
+ def connect(self, host=None, username=None, password=None, db_name=None):
+ """
+ This is assumed to enable autocommit.
+ """
+ raise NotImplementedError
+
+
+ def disconnect(self):
+ if self._connection:
+ self._connection.close()
+ self._connection = None
+ self._cursor = None
+
+
+ def execute(self, query, arguments=None):
+ self._cursor.execute(query, arguments)
+ self.rowcount = self._cursor.rowcount
+ return self._cursor.fetchall()
+
+
+ def get_exception_details(exception):
+ return ExceptionDetails.UNKNOWN
+
+
+class _MySqlBackend(_GenericBackend):
+ def __init__(self):
+ import MySQLdb
+ super(_MySqlBackend, self).__init__(MySQLdb)
+
+
+ @staticmethod
+ def convert_boolean(boolean, conversion_dict):
+ 'Convert booleans to integer strings'
+ return str(int(boolean))
+
+
+ def connect(self, host=None, username=None, password=None, db_name=None):
+ import MySQLdb.converters
+ convert_dict = MySQLdb.converters.conversions
+ convert_dict.setdefault(bool, self.convert_boolean)
+
+ self._connection = self._database_module.connect(
+ host=host, user=username, passwd=password, db=db_name,
+ conv=convert_dict)
+ self._connection.autocommit(True)
+ self._cursor = self._connection.cursor()
+
+
+ def get_exception_details(exception):
+ pass
+
+
+class _SqliteBackend(_GenericBackend):
+ def __init__(self):
+ from pysqlite2 import dbapi2
+ super(_SqliteBackend, self).__init__(dbapi2)
+
+
+ def connect(self, host=None, username=None, password=None, db_name=None):
+ self._connection = self._database_module.connect(db_name)
+ self._connection.isolation_level = None # enable autocommit
+ self._cursor = self._connection.cursor()
+
+
+ def execute(self, query, arguments=None):
+ # pysqlite2 uses paramstyle=qmark
+ # TODO: make this more sophisticated if necessary
+ query = query.replace('%s', '?')
+ return super(_SqliteBackend, self).execute(query, arguments)
+
+
+_BACKEND_MAP = {
+ 'mysql' : _MySqlBackend,
+ 'sqlite' : _SqliteBackend,
+}
+
+
+class DatabaseConnection(object):
+ """
+ Generic wrapper for a database connection. Supports both mysql and sqlite
+ backends.
+
+ Public attributes:
+ * reconnect_enabled: if True, when an OperationalError occurs the class will
+ try to reconnect to the database automatically.
+ * reconnect_delay_sec: seconds to wait before reconnecting
+ * max_reconnect_attempts: maximum number of time to try reconnecting before
+ giving up. Setting to RECONNECT_FOREVER removes the limit.
+ * rowcount - will hold cursor.rowcount after each call to execute().
+ * global_config_section - the section in which to find DB information. this
+ should be passed to the constructor, not set later, and may be None, in
+ which case information must be passed to connect().
+ """
+ _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
+ 'db_name')
+
+ def __init__(self, global_config_section=None):
+ self.global_config_section = global_config_section
+ self._backend = None
+ self.rowcount = None
+
+ # reconnect defaults
+ self.reconnect_enabled = True
+ self.reconnect_delay_sec = 20
+ self.max_reconnect_attempts = 10
+
+ self._read_options()
+
+
+ def _get_option(self, name, provided_value):
+ if provided_value is not None:
+ return provided_value
+ if self.global_config_section:
+ global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
+ return global_config.global_config.get_config_value(
+ self.global_config_section, global_config_name)
+ return getattr(self, name, None)
+
+
+ def _read_options(self, db_type=None, host=None, username=None,
+ password=None, db_name=None):
+ self.db_type = self._get_option('db_type', db_type)
+ self.host = self._get_option('host', host)
+ self.username = self._get_option('username', username)
+ self.password = self._get_option('password', password)
+ self.db_name = self._get_option('db_name', db_name)
+
+
+ def _get_backend(self, db_type):
+ if db_type not in _BACKEND_MAP:
+ raise ValueError('Invalid database type: %s, should be one of %s' %
+ (db_type, ', '.join(_BACKEND_MAP.keys())))
+ backend_class = _BACKEND_MAP[db_type]
+ return backend_class()
+
+
+ def _reached_max_attempts(self, num_attempts):
+ return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
+ num_attempts > self.max_reconnect_attempts)
+
+
+ def _is_reconnect_enabled(self, supplied_param):
+ if supplied_param is not None:
+ return supplied_param
+ return self.reconnect_enabled
+
+
+ def _connect_backend(self, try_reconnecting=None):
+ num_attempts = 0
+ while True:
+ try:
+ self._backend.connect(host=self.host, username=self.username,
+ password=self.password,
+ db_name=self.db_name)
+ return
+ except self._backend.OperationalError:
+ num_attempts += 1
+ if not self._is_reconnect_enabled(try_reconnecting):
+ raise
+ if self._reached_max_attempts(num_attempts):
+ raise
+ traceback.print_exc()
+ print ("Can't connect to database; reconnecting in %s sec" %
+ self.reconnect_delay_sec)
+ time.sleep(self.reconnect_delay_sec)
+ self.disconnect()
+
+
+ def connect(self, db_type=None, host=None, username=None, password=None,
+ db_name=None, try_reconnecting=None):
+ """
+ Parameters passed to this function will override defaults from global
+ config. try_reconnecting, if passed, will override
+ self.reconnect_enabled.
+ """
+ self.disconnect()
+ self._read_options(db_type, host, username, password, db_name)
+
+ self._backend = self._get_backend(self.db_type)
+ _copy_exceptions(self._backend, self)
+ self._connect_backend(try_reconnecting)
+
+
+ def disconnect(self):
+ if self._backend:
+ self._backend.disconnect()
+
+
+ def execute(self, query, parameters=None, try_reconnecting=None):
+ """
+ Execute a query and return cursor.fetchall(). try_reconnecting, if
+ passed, will override self.reconnect_enabled.
+ """
+ # _connect_backend() contains a retry loop, so don't loop here
+ try:
+ results = self._backend.execute(query, parameters)
+ except self._backend.OperationalError:
+ if not self._is_reconnect_enabled(try_reconnecting):
+ raise
+ traceback.print_exc()
+ print ("MYSQL connection died; reconnecting")
+ self.disconnect()
+ self._connect_backend(try_reconnecting)
+ results = self._backend.execute(query, parameters)
+
+ self.rowcount = self._backend.rowcount
+ return results
+
+
+ def get_database_info(self):
+ return dict((attribute, getattr(self, attribute))
+ for attribute in self._DATABASE_ATTRIBUTES)
+
+
+ @classmethod
+ def get_test_database(cls):
+ """
+ Factory method returning a DatabaseConnection for a temporary in-memory
+ database.
+ """
+ database = cls()
+ database.reconnect_enabled = False
+ database.connect(db_type='sqlite', db_name=':memory:')
+ return database
diff --git a/database/database_connection_unittest.py b/database/database_connection_unittest.py
new file mode 100644
index 00000000..95df738a
--- /dev/null
+++ b/database/database_connection_unittest.py
@@ -0,0 +1,186 @@
+#!/usr/bin/python2.4
+
+import unittest, time
+import MySQLdb
+import pysqlite2.dbapi2
+import common
+from autotest_lib.client.common_lib import global_config
+from autotest_lib.client.common_lib.test_utils import mock
+from autotest_lib.database import database_connection
+
+_CONFIG_SECTION = 'TKO'
+_HOST = 'myhost'
+_USER = 'myuser'
+_PASS = 'mypass'
+_DB_NAME = 'mydb'
+_DB_TYPE = 'mydbtype'
+
+_CONNECT_KWARGS = dict(host=_HOST, username=_USER, password=_PASS,
+ db_name=_DB_NAME)
+_RECONNECT_DELAY = 10
+
+class FakeDatabaseError(Exception):
+ pass
+
+
+class DatabaseConnectionTest(unittest.TestCase):
+ def setUp(self):
+ self.god = mock.mock_god()
+ self.god.stub_function(time, 'sleep')
+
+
+ def tearDown(self):
+ global_config.global_config.reset_config_values()
+ self.god.unstub_all()
+
+
+ def _get_database_connection(self, config_section=_CONFIG_SECTION):
+ if config_section == _CONFIG_SECTION:
+ self._override_config()
+ db = database_connection.DatabaseConnection(config_section)
+
+ self._fake_backend = self.god.create_mock_class(
+ database_connection._GenericBackend, 'fake_backend')
+ for exception in database_connection._DB_EXCEPTIONS:
+ setattr(self._fake_backend, exception, FakeDatabaseError)
+ self._fake_backend.rowcount = 0
+
+ def get_fake_backend(db_type):
+ self._db_type = db_type
+ return self._fake_backend
+ self.god.stub_with(db, '_get_backend', get_fake_backend)
+
+ db.reconnect_delay_sec = _RECONNECT_DELAY
+ return db
+
+
+ def _override_config(self):
+ c = global_config.global_config
+ c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
+ c.override_config_value(_CONFIG_SECTION, 'user', _USER)
+ c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
+ c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
+ c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)
+
+
+ def test_connect(self):
+ db = self._get_database_connection(config_section=None)
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+
+ db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
+ password=_PASS, db_name=_DB_NAME)
+
+ self.assertEquals(self._db_type, _DB_TYPE)
+ self.god.check_playback()
+
+
+ def test_global_config(self):
+ db = self._get_database_connection()
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+
+ db.connect()
+
+ self.assertEquals(self._db_type, _DB_TYPE)
+ self.god.check_playback()
+
+
+ def _expect_reconnect(self, fail=False):
+ self._fake_backend.disconnect.expect_call()
+ call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+ if fail:
+ call.and_raises(FakeDatabaseError())
+
+
+ def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
+ FakeDatabaseError())
+ for i in xrange(num_reconnects):
+ time.sleep.expect_call(_RECONNECT_DELAY)
+ if i < num_reconnects - 1:
+ self._expect_reconnect(fail=True)
+ else:
+ self._expect_reconnect(fail=fail_last)
+
+
+ def test_connect_retry(self):
+ db = self._get_database_connection()
+ self._expect_fail_and_reconnect(1)
+
+ db.connect()
+ self.god.check_playback()
+
+ self._fake_backend.disconnect.expect_call()
+ self._expect_fail_and_reconnect(0)
+ self.assertRaises(FakeDatabaseError, db.connect,
+ try_reconnecting=False)
+ self.god.check_playback()
+
+ db.reconnect_enabled = False
+ self._fake_backend.disconnect.expect_call()
+ self._expect_fail_and_reconnect(0)
+ self.assertRaises(FakeDatabaseError, db.connect)
+ self.god.check_playback()
+
+
+ def test_max_reconnect(self):
+ db = self._get_database_connection()
+ db.max_reconnect_attempts = 5
+ self._expect_fail_and_reconnect(5, fail_last=True)
+
+ self.assertRaises(FakeDatabaseError, db.connect)
+ self.god.check_playback()
+
+
+ def test_reconnect_forever(self):
+ db = self._get_database_connection()
+ db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
+ self._expect_fail_and_reconnect(30)
+
+ db.connect()
+ self.god.check_playback()
+
+
+ def _simple_connect(self, db):
+ self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
+ db.connect()
+ self.god.check_playback()
+
+
+ def test_disconnect(self):
+ db = self._get_database_connection()
+ self._simple_connect(db)
+ self._fake_backend.disconnect.expect_call()
+
+ db.disconnect()
+ self.god.check_playback()
+
+
+ def test_execute(self):
+ db = self._get_database_connection()
+ self._simple_connect(db)
+ params = object()
+ self._fake_backend.execute.expect_call('query', params)
+
+ db.execute('query', params)
+ self.god.check_playback()
+
+
+ def test_execute_retry(self):
+ db = self._get_database_connection()
+ self._simple_connect(db)
+ self._fake_backend.execute.expect_call('query', None).and_raises(
+ FakeDatabaseError())
+ self._expect_reconnect()
+ self._fake_backend.execute.expect_call('query', None)
+
+ db.execute('query')
+ self.god.check_playback()
+
+ self._fake_backend.execute.expect_call('query', None).and_raises(
+ FakeDatabaseError())
+ self.assertRaises(FakeDatabaseError, db.execute, 'query',
+ try_reconnecting=False)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/frontend/migrations/001_initial_db.py b/frontend/migrations/001_initial_db.py
index 7b78ffa4..15495f71 100644
--- a/frontend/migrations/001_initial_db.py
+++ b/frontend/migrations/001_initial_db.py
@@ -6,8 +6,8 @@ required_tables = ('acl_groups', 'acl_groups_hosts', 'acl_groups_users',
def migrate_up(manager):
- manager.execute("SHOW TABLES")
- tables = [row[0] for row in manager.cursor.fetchall()]
+ rows = manager.execute("SHOW TABLES")
+ tables = [row[0] for row in rows]
db_initialized = True
for table in required_tables:
if table not in tables:
diff --git a/migrate/migrate.py b/migrate/migrate.py
index 241d2194..37c78577 100755
--- a/migrate/migrate.py
+++ b/migrate/migrate.py
@@ -5,6 +5,7 @@ import MySQLdb, MySQLdb.constants.ER
from optparse import OptionParser
import common
from autotest_lib.client.common_lib import global_config
+from autotest_lib.database import database_connection
MIGRATE_TABLE = 'migrate_info'
@@ -37,50 +38,29 @@ class MigrationManager(object):
cursor = None
migrations_dir = None
- def __init__(self, database, migrations_dir=None, force=False):
- self.database = database
+ def __init__(self, database_connection, migrations_dir=None, force=False):
+ self._database = database_connection
self.force = force
+ self._set_migrations_dir(migrations_dir)
+
+
+ def _set_migrations_dir(self, migrations_dir=None):
+ config_section = self._database.global_config_section
if migrations_dir is None:
migrations_dir = os.path.abspath(
- _MIGRATIONS_DIRS.get(database, _DEFAULT_MIGRATIONS_DIR))
+ _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
self.migrations_dir = migrations_dir
sys.path.append(migrations_dir)
- assert os.path.exists(migrations_dir)
-
- self.db_host = None
- self.db_name = None
- self.username = None
- self.password = None
-
-
- def read_db_info(self):
- # grab the config file and parse for info
- c = global_config.global_config
- self.db_host = c.get_config_value(self.database, "host")
- self.db_name = c.get_config_value(self.database, "database")
- self.username = c.get_config_value(self.database, "user")
- self.password = c.get_config_value(self.database, "password")
-
+ assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"
- def connect(self, host, db_name, username, password):
- return MySQLdb.connect(host=host, db=db_name, user=username,
- passwd=password)
-
- def open_connection(self):
- self.connection = self.connect(self.db_host, self.db_name,
- self.username, self.password)
- self.connection.autocommit(True)
- self.cursor = self.connection.cursor()
-
-
- def close_connection(self):
- self.connection.close()
+ def _get_db_name(self):
+ return self._database.get_database_info()['db_name']
def execute(self, query, *parameters):
#print 'SQL:', query % parameters
- return self.cursor.execute(query, parameters)
+ return self._database.execute(query, parameters)
def execute_script(self, script):
@@ -95,11 +75,10 @@ class MigrationManager(object):
try:
self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
return True
- except MySQLdb.ProgrammingError, exc:
- error_code, _ = exc.args
- if error_code == MySQLdb.constants.ER.NO_SUCH_TABLE:
- return False
- raise
+ except self._database.DatabaseError, exc:
+ # we can't check for more specifics due to differences between DB
+ # backends (we can't even check for a subclass of DatabaseError)
+ return False
def create_migrate_table(self):
@@ -109,21 +88,20 @@ class MigrationManager(object):
else:
self.execute("DELETE FROM %s" % MIGRATE_TABLE)
self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
- assert self.cursor.rowcount == 1
+ assert self._database.rowcount == 1
def set_db_version(self, version):
assert isinstance(version, int)
self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
version)
- assert self.cursor.rowcount == 1
+ assert self._database.rowcount == 1
def get_db_version(self):
if not self.check_migrate_table_exists():
return 0
- self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
- rows = self.cursor.fetchall()
+ rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
if len(rows) == 0:
return 0
assert len(rows) == 1 and len(rows[0]) == 1
@@ -190,30 +168,28 @@ class MigrationManager(object):
def initialize_test_db(self):
- self.read_db_info()
- test_db_name = 'test_' + self.db_name
+ db_name = self._get_db_name()
+ test_db_name = 'test_' + db_name
# first, connect to no DB so we can create a test DB
- self.db_name = ''
- self.open_connection()
+ self._database.connect(db_name='')
print 'Creating test DB', test_db_name
self.execute('CREATE DATABASE ' + test_db_name)
- self.close_connection()
+ self._database.disconnect()
# now connect to the test DB
- self.db_name = test_db_name
- self.open_connection()
+ self._database.connect(db_name=test_db_name)
def remove_test_db(self):
print 'Removing test DB'
- self.execute('DROP DATABASE ' + self.db_name)
+ self.execute('DROP DATABASE ' + self._get_db_name())
+ # reset connection back to real DB
+ self._database.disconnect()
+ self._database.connect()
def get_mysql_args(self):
- return ('-u %(user)s -p%(password)s -h %(host)s %(db)s' % {
- 'user' : self.username,
- 'password' : self.password,
- 'host' : self.db_host,
- 'db' : self.db_name})
+ return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
+ self._database.get_database_info())
def migrate_to_version_or_latest(self, version):
@@ -224,9 +200,7 @@ class MigrationManager(object):
def do_sync_db(self, version=None):
- self.read_db_info()
- self.open_connection()
- print 'Migration starting for database', self.db_name
+ print 'Migration starting for database', self._get_db_name()
self.migrate_to_version_or_latest(version)
print 'Migration complete'
@@ -237,7 +211,7 @@ class MigrationManager(object):
"""
self.initialize_test_db()
try:
- print 'Starting migration test on DB', self.db_name
+ print 'Starting migration test on DB', self._get_db_name()
self.migrate_to_version_or_latest(version)
# show schema to the user
os.system('mysqldump %s --no-data=true '
@@ -253,28 +227,24 @@ class MigrationManager(object):
Create a fresh DB, copy the existing DB to it, and then
try to synchronize it.
"""
- self.read_db_info()
- self.open_connection()
db_version = self.get_db_version()
- self.close_connection()
# don't do anything if we're already at the latest version
if db_version == self.get_latest_version():
print 'Skipping simulation, already at latest version'
return
# get existing data
- self.read_db_info()
print 'Dumping existing data'
dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
- os.close(dump_fd)
os.system('mysqldump %s >%s' %
(self.get_mysql_args(), dump_file))
# fill in test DB
self.initialize_test_db()
print 'Filling in test DB'
os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
+ os.close(dump_fd)
os.remove(dump_file)
try:
- print 'Starting migration test on DB', self.db_name
+ print 'Starting migration test on DB', self._get_db_name()
self.migrate_to_version_or_latest(version)
finally:
self.remove_test_db()
@@ -299,7 +269,10 @@ def main():
parser.add_option("-f", "--force", help="don't ask for confirmation",
action="store_true")
(options, args) = parser.parse_args()
- manager = MigrationManager(options.database, force=options.force)
+ database = database_connection.DatabaseConnection(options.database)
+ database.reconnect_enabled = False
+ database.connect()
+ manager = MigrationManager(database, force=options.force)
if len(args) > 0:
if len(args) > 1:
diff --git a/migrate/migrate_unittest.py b/migrate/migrate_unittest.py
index d33dccd9..72389d6b 100644
--- a/migrate/migrate_unittest.py
+++ b/migrate/migrate_unittest.py
@@ -1,10 +1,11 @@
#!/usr/bin/python2.4
-import unittest
+import unittest, tempfile, os
import MySQLdb
import migrate
import common
from autotest_lib.client.common_lib import global_config
+from autotest_lib.database import database_connection
# Which section of the global config to pull info from. We won't actually use
# that DB, we'll use the corresponding test DB (test_<db name>).
@@ -54,18 +55,8 @@ MIGRATIONS = [DummyMigration(n) for n in xrange(1, NUM_MIGRATIONS + 1)]
class TestableMigrationManager(migrate.MigrationManager):
- def __init__(self, database, migrations_dir=None):
- self.database = database
- self.migrations_dir = migrations_dir
- self.db_host = None
- self.db_name = None
- self.username = None
- self.password = None
-
-
- def read_db_info(self):
- migrate.MigrationManager.read_db_info(self)
- self.db_name = 'test_' + self.db_name
+ def _set_migrations_dir(self, migrations_dir=None):
+ pass
def get_migrations(self, minimum_version=None, maximum_version=None):
@@ -75,39 +66,16 @@ class TestableMigrationManager(migrate.MigrationManager):
class MigrateManagerTest(unittest.TestCase):
- config = global_config.global_config
- host = config.get_config_value(CONFIG_DB, 'host')
- db_name = 'test_' + config.get_config_value(CONFIG_DB, 'database')
- user = config.get_config_value(CONFIG_DB, 'user')
- password = config.get_config_value(CONFIG_DB, 'password')
-
- def do_sql(self, sql):
- self.con = MySQLdb.connect(host=self.host, user=self.user,
- passwd=self.password)
- self.con.autocommit(True)
- self.cur = self.con.cursor()
- try:
- self.cur.execute(sql)
- finally:
- self.con.close()
-
-
- def remove_db(self):
- self.do_sql('DROP DATABASE ' + self.db_name)
-
-
def setUp(self):
- self.do_sql('CREATE DATABASE ' + self.db_name)
- try:
- self.manager = TestableMigrationManager(CONFIG_DB)
- except MySQLdb.OperationalError:
- self.remove_db()
- raise
+ self._database = (
+ database_connection.DatabaseConnection.get_test_database())
+ self._database.connect()
+ self.manager = TestableMigrationManager(self._database)
DummyMigration.clear_migrations_done()
def tearDown(self):
- self.remove_db()
+ self._database.disconnect()
def test_sync(self):
diff --git a/scheduler/monitor_db_unittest.py b/scheduler/monitor_db_unittest.py
index 7eab7307..11ed7db1 100644
--- a/scheduler/monitor_db_unittest.py
+++ b/scheduler/monitor_db_unittest.py
@@ -6,6 +6,7 @@ import common
from autotest_lib.client.common_lib import global_config, host_protections
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.migrate import migrate
+from autotest_lib.database import database_connection
import monitor_db
@@ -102,10 +103,9 @@ class BaseDispatcherTest(unittest.TestCase):
self._do_query('CREATE DATABASE ' + self._db_name)
self._disconnect_from_db()
- migration_dir = os.path.join(os.path.dirname(__file__),
- '..', 'frontend', 'migrations')
- manager = migrate.MigrationManager('AUTOTEST_WEB', migration_dir,
- force=True)
+ database = database_connection.DatabaseConnection('AUTOTEST_WEB')
+ database.connect(db_name=self._db_name)
+ manager = migrate.MigrationManager(database, force=True)
manager.do_sync_db()
self._connect_to_db(self._db_name)
diff --git a/tko/migrations/001_initial_db.py b/tko/migrations/001_initial_db.py
index 8b964753..c67f3c2f 100755
--- a/tko/migrations/001_initial_db.py
+++ b/tko/migrations/001_initial_db.py
@@ -4,8 +4,8 @@ required_tables = ('machines', 'jobs', 'patches', 'tests', 'test_attributes',
'iteration_result')
def migrate_up(manager):
- manager.execute("SHOW TABLES")
- tables = [row[0] for row in manager.cursor.fetchall()]
+ rows = manager.execute("SHOW TABLES")
+ tables = [row[0] for row in rows]
db_initialized = True
for table in required_tables:
if table not in tables: