diff options
-rw-r--r-- | database/__init__.py | 0 | ||||
-rw-r--r-- | database/common.py | 8 | ||||
-rw-r--r-- | database/database_connection.py | 250 | ||||
-rw-r--r-- | database/database_connection_unittest.py | 186 | ||||
-rw-r--r-- | frontend/migrations/001_initial_db.py | 4 | ||||
-rwxr-xr-x | migrate/migrate.py | 105 | ||||
-rw-r--r-- | migrate/migrate_unittest.py | 50 | ||||
-rw-r--r-- | scheduler/monitor_db_unittest.py | 8 | ||||
-rwxr-xr-x | tko/migrations/001_initial_db.py | 4 |
9 files changed, 500 insertions, 115 deletions
diff --git a/database/__init__.py b/database/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/database/__init__.py diff --git a/database/common.py b/database/common.py new file mode 100644 index 00000000..9941b190 --- /dev/null +++ b/database/common.py @@ -0,0 +1,8 @@ +import os, sys +dirname = os.path.dirname(sys.modules[__name__].__file__) +autotest_dir = os.path.abspath(os.path.join(dirname, "..")) +client_dir = os.path.join(autotest_dir, "client") +sys.path.insert(0, client_dir) +import setup_modules +sys.path.pop(0) +setup_modules.setup(base_path=autotest_dir, root_module_name="autotest_lib") diff --git a/database/database_connection.py b/database/database_connection.py new file mode 100644 index 00000000..6d617287 --- /dev/null +++ b/database/database_connection.py @@ -0,0 +1,250 @@ +import traceback, time +import common +from autotest_lib.client.common_lib import global_config + +RECONNECT_FOREVER = object() + +_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError') +_GLOBAL_CONFIG_NAMES = { + 'username' : 'user', + 'db_name' : 'database', +} + +def _copy_exceptions(source, destination): + for exception_name in _DB_EXCEPTIONS: + setattr(destination, exception_name, getattr(source, exception_name)) + + +class _GenericBackend(object): + def __init__(self, database_module): + self._database_module = database_module + self._connection = None + self._cursor = None + self.rowcount = None + _copy_exceptions(database_module, self) + + + def connect(self, host=None, username=None, password=None, db_name=None): + """ + This is assumed to enable autocommit. + """ + raise NotImplementedError + + + def disconnect(self): + if self._connection: + self._connection.close() + self._connection = None + self._cursor = None + + + def execute(self, query, arguments=None): + self._cursor.execute(query, arguments) + self.rowcount = self._cursor.rowcount + return self._cursor.fetchall() + + + def get_exception_details(exception): + return ExceptionDetails.UNKNOWN + + +class _MySqlBackend(_GenericBackend): + def __init__(self): + import MySQLdb + super(_MySqlBackend, self).__init__(MySQLdb) + + + @staticmethod + def convert_boolean(boolean, conversion_dict): + 'Convert booleans to integer strings' + return str(int(boolean)) + + + def connect(self, host=None, username=None, password=None, db_name=None): + import MySQLdb.converters + convert_dict = MySQLdb.converters.conversions + convert_dict.setdefault(bool, self.convert_boolean) + + self._connection = self._database_module.connect( + host=host, user=username, passwd=password, db=db_name, + conv=convert_dict) + self._connection.autocommit(True) + self._cursor = self._connection.cursor() + + + def get_exception_details(exception): + pass + + +class _SqliteBackend(_GenericBackend): + def __init__(self): + from pysqlite2 import dbapi2 + super(_SqliteBackend, self).__init__(dbapi2) + + + def connect(self, host=None, username=None, password=None, db_name=None): + self._connection = self._database_module.connect(db_name) + self._connection.isolation_level = None # enable autocommit + self._cursor = self._connection.cursor() + + + def execute(self, query, arguments=None): + # pysqlite2 uses paramstyle=qmark + # TODO: make this more sophisticated if necessary + query = query.replace('%s', '?') + return super(_SqliteBackend, self).execute(query, arguments) + + +_BACKEND_MAP = { + 'mysql' : _MySqlBackend, + 'sqlite' : _SqliteBackend, +} + + +class DatabaseConnection(object): + """ + Generic wrapper for a database connection. Supports both mysql and sqlite + backends. + + Public attributes: + * reconnect_enabled: if True, when an OperationalError occurs the class will + try to reconnect to the database automatically. + * reconnect_delay_sec: seconds to wait before reconnecting + * max_reconnect_attempts: maximum number of time to try reconnecting before + giving up. Setting to RECONNECT_FOREVER removes the limit. + * rowcount - will hold cursor.rowcount after each call to execute(). + * global_config_section - the section in which to find DB information. this + should be passed to the constructor, not set later, and may be None, in + which case information must be passed to connect(). + """ + _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password', + 'db_name') + + def __init__(self, global_config_section=None): + self.global_config_section = global_config_section + self._backend = None + self.rowcount = None + + # reconnect defaults + self.reconnect_enabled = True + self.reconnect_delay_sec = 20 + self.max_reconnect_attempts = 10 + + self._read_options() + + + def _get_option(self, name, provided_value): + if provided_value is not None: + return provided_value + if self.global_config_section: + global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name) + return global_config.global_config.get_config_value( + self.global_config_section, global_config_name) + return getattr(self, name, None) + + + def _read_options(self, db_type=None, host=None, username=None, + password=None, db_name=None): + self.db_type = self._get_option('db_type', db_type) + self.host = self._get_option('host', host) + self.username = self._get_option('username', username) + self.password = self._get_option('password', password) + self.db_name = self._get_option('db_name', db_name) + + + def _get_backend(self, db_type): + if db_type not in _BACKEND_MAP: + raise ValueError('Invalid database type: %s, should be one of %s' % + (db_type, ', '.join(_BACKEND_MAP.keys()))) + backend_class = _BACKEND_MAP[db_type] + return backend_class() + + + def _reached_max_attempts(self, num_attempts): + return (self.max_reconnect_attempts is not RECONNECT_FOREVER and + num_attempts > self.max_reconnect_attempts) + + + def _is_reconnect_enabled(self, supplied_param): + if supplied_param is not None: + return supplied_param + return self.reconnect_enabled + + + def _connect_backend(self, try_reconnecting=None): + num_attempts = 0 + while True: + try: + self._backend.connect(host=self.host, username=self.username, + password=self.password, + db_name=self.db_name) + return + except self._backend.OperationalError: + num_attempts += 1 + if not self._is_reconnect_enabled(try_reconnecting): + raise + if self._reached_max_attempts(num_attempts): + raise + traceback.print_exc() + print ("Can't connect to database; reconnecting in %s sec" % + self.reconnect_delay_sec) + time.sleep(self.reconnect_delay_sec) + self.disconnect() + + + def connect(self, db_type=None, host=None, username=None, password=None, + db_name=None, try_reconnecting=None): + """ + Parameters passed to this function will override defaults from global + config. try_reconnecting, if passed, will override + self.reconnect_enabled. + """ + self.disconnect() + self._read_options(db_type, host, username, password, db_name) + + self._backend = self._get_backend(self.db_type) + _copy_exceptions(self._backend, self) + self._connect_backend(try_reconnecting) + + + def disconnect(self): + if self._backend: + self._backend.disconnect() + + + def execute(self, query, parameters=None, try_reconnecting=None): + """ + Execute a query and return cursor.fetchall(). try_reconnecting, if + passed, will override self.reconnect_enabled. + """ + # _connect_backend() contains a retry loop, so don't loop here + try: + results = self._backend.execute(query, parameters) + except self._backend.OperationalError: + if not self._is_reconnect_enabled(try_reconnecting): + raise + traceback.print_exc() + print ("MYSQL connection died; reconnecting") + self.disconnect() + self._connect_backend(try_reconnecting) + results = self._backend.execute(query, parameters) + + self.rowcount = self._backend.rowcount + return results + + + def get_database_info(self): + return dict((attribute, getattr(self, attribute)) + for attribute in self._DATABASE_ATTRIBUTES) + + + @classmethod + def get_test_database(cls): + """ + Factory method returning a DatabaseConnection for a temporary in-memory + database. + """ + database = cls() + database.reconnect_enabled = False + database.connect(db_type='sqlite', db_name=':memory:') + return database diff --git a/database/database_connection_unittest.py b/database/database_connection_unittest.py new file mode 100644 index 00000000..95df738a --- /dev/null +++ b/database/database_connection_unittest.py @@ -0,0 +1,186 @@ +#!/usr/bin/python2.4 + +import unittest, time +import MySQLdb +import pysqlite2.dbapi2 +import common +from autotest_lib.client.common_lib import global_config +from autotest_lib.client.common_lib.test_utils import mock +from autotest_lib.database import database_connection + +_CONFIG_SECTION = 'TKO' +_HOST = 'myhost' +_USER = 'myuser' +_PASS = 'mypass' +_DB_NAME = 'mydb' +_DB_TYPE = 'mydbtype' + +_CONNECT_KWARGS = dict(host=_HOST, username=_USER, password=_PASS, + db_name=_DB_NAME) +_RECONNECT_DELAY = 10 + +class FakeDatabaseError(Exception): + pass + + +class DatabaseConnectionTest(unittest.TestCase): + def setUp(self): + self.god = mock.mock_god() + self.god.stub_function(time, 'sleep') + + + def tearDown(self): + global_config.global_config.reset_config_values() + self.god.unstub_all() + + + def _get_database_connection(self, config_section=_CONFIG_SECTION): + if config_section == _CONFIG_SECTION: + self._override_config() + db = database_connection.DatabaseConnection(config_section) + + self._fake_backend = self.god.create_mock_class( + database_connection._GenericBackend, 'fake_backend') + for exception in database_connection._DB_EXCEPTIONS: + setattr(self._fake_backend, exception, FakeDatabaseError) + self._fake_backend.rowcount = 0 + + def get_fake_backend(db_type): + self._db_type = db_type + return self._fake_backend + self.god.stub_with(db, '_get_backend', get_fake_backend) + + db.reconnect_delay_sec = _RECONNECT_DELAY + return db + + + def _override_config(self): + c = global_config.global_config + c.override_config_value(_CONFIG_SECTION, 'host', _HOST) + c.override_config_value(_CONFIG_SECTION, 'user', _USER) + c.override_config_value(_CONFIG_SECTION, 'password', _PASS) + c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME) + c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE) + + + def test_connect(self): + db = self._get_database_connection(config_section=None) + self._fake_backend.connect.expect_call(**_CONNECT_KWARGS) + + db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER, + password=_PASS, db_name=_DB_NAME) + + self.assertEquals(self._db_type, _DB_TYPE) + self.god.check_playback() + + + def test_global_config(self): + db = self._get_database_connection() + self._fake_backend.connect.expect_call(**_CONNECT_KWARGS) + + db.connect() + + self.assertEquals(self._db_type, _DB_TYPE) + self.god.check_playback() + + + def _expect_reconnect(self, fail=False): + self._fake_backend.disconnect.expect_call() + call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS) + if fail: + call.and_raises(FakeDatabaseError()) + + + def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False): + self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises( + FakeDatabaseError()) + for i in xrange(num_reconnects): + time.sleep.expect_call(_RECONNECT_DELAY) + if i < num_reconnects - 1: + self._expect_reconnect(fail=True) + else: + self._expect_reconnect(fail=fail_last) + + + def test_connect_retry(self): + db = self._get_database_connection() + self._expect_fail_and_reconnect(1) + + db.connect() + self.god.check_playback() + + self._fake_backend.disconnect.expect_call() + self._expect_fail_and_reconnect(0) + self.assertRaises(FakeDatabaseError, db.connect, + try_reconnecting=False) + self.god.check_playback() + + db.reconnect_enabled = False + self._fake_backend.disconnect.expect_call() + self._expect_fail_and_reconnect(0) + self.assertRaises(FakeDatabaseError, db.connect) + self.god.check_playback() + + + def test_max_reconnect(self): + db = self._get_database_connection() + db.max_reconnect_attempts = 5 + self._expect_fail_and_reconnect(5, fail_last=True) + + self.assertRaises(FakeDatabaseError, db.connect) + self.god.check_playback() + + + def test_reconnect_forever(self): + db = self._get_database_connection() + db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER + self._expect_fail_and_reconnect(30) + + db.connect() + self.god.check_playback() + + + def _simple_connect(self, db): + self._fake_backend.connect.expect_call(**_CONNECT_KWARGS) + db.connect() + self.god.check_playback() + + + def test_disconnect(self): + db = self._get_database_connection() + self._simple_connect(db) + self._fake_backend.disconnect.expect_call() + + db.disconnect() + self.god.check_playback() + + + def test_execute(self): + db = self._get_database_connection() + self._simple_connect(db) + params = object() + self._fake_backend.execute.expect_call('query', params) + + db.execute('query', params) + self.god.check_playback() + + + def test_execute_retry(self): + db = self._get_database_connection() + self._simple_connect(db) + self._fake_backend.execute.expect_call('query', None).and_raises( + FakeDatabaseError()) + self._expect_reconnect() + self._fake_backend.execute.expect_call('query', None) + + db.execute('query') + self.god.check_playback() + + self._fake_backend.execute.expect_call('query', None).and_raises( + FakeDatabaseError()) + self.assertRaises(FakeDatabaseError, db.execute, 'query', + try_reconnecting=False) + + +if __name__ == '__main__': + unittest.main() diff --git a/frontend/migrations/001_initial_db.py b/frontend/migrations/001_initial_db.py index 7b78ffa4..15495f71 100644 --- a/frontend/migrations/001_initial_db.py +++ b/frontend/migrations/001_initial_db.py @@ -6,8 +6,8 @@ required_tables = ('acl_groups', 'acl_groups_hosts', 'acl_groups_users', def migrate_up(manager): - manager.execute("SHOW TABLES") - tables = [row[0] for row in manager.cursor.fetchall()] + rows = manager.execute("SHOW TABLES") + tables = [row[0] for row in rows] db_initialized = True for table in required_tables: if table not in tables: diff --git a/migrate/migrate.py b/migrate/migrate.py index 241d2194..37c78577 100755 --- a/migrate/migrate.py +++ b/migrate/migrate.py @@ -5,6 +5,7 @@ import MySQLdb, MySQLdb.constants.ER from optparse import OptionParser import common from autotest_lib.client.common_lib import global_config +from autotest_lib.database import database_connection MIGRATE_TABLE = 'migrate_info' @@ -37,50 +38,29 @@ class MigrationManager(object): cursor = None migrations_dir = None - def __init__(self, database, migrations_dir=None, force=False): - self.database = database + def __init__(self, database_connection, migrations_dir=None, force=False): + self._database = database_connection self.force = force + self._set_migrations_dir(migrations_dir) + + + def _set_migrations_dir(self, migrations_dir=None): + config_section = self._database.global_config_section if migrations_dir is None: migrations_dir = os.path.abspath( - _MIGRATIONS_DIRS.get(database, _DEFAULT_MIGRATIONS_DIR)) + _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR)) self.migrations_dir = migrations_dir sys.path.append(migrations_dir) - assert os.path.exists(migrations_dir) - - self.db_host = None - self.db_name = None - self.username = None - self.password = None - - - def read_db_info(self): - # grab the config file and parse for info - c = global_config.global_config - self.db_host = c.get_config_value(self.database, "host") - self.db_name = c.get_config_value(self.database, "database") - self.username = c.get_config_value(self.database, "user") - self.password = c.get_config_value(self.database, "password") - + assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist" - def connect(self, host, db_name, username, password): - return MySQLdb.connect(host=host, db=db_name, user=username, - passwd=password) - - def open_connection(self): - self.connection = self.connect(self.db_host, self.db_name, - self.username, self.password) - self.connection.autocommit(True) - self.cursor = self.connection.cursor() - - - def close_connection(self): - self.connection.close() + def _get_db_name(self): + return self._database.get_database_info()['db_name'] def execute(self, query, *parameters): #print 'SQL:', query % parameters - return self.cursor.execute(query, parameters) + return self._database.execute(query, parameters) def execute_script(self, script): @@ -95,11 +75,10 @@ class MigrationManager(object): try: self.execute("SELECT * FROM %s" % MIGRATE_TABLE) return True - except MySQLdb.ProgrammingError, exc: - error_code, _ = exc.args - if error_code == MySQLdb.constants.ER.NO_SUCH_TABLE: - return False - raise + except self._database.DatabaseError, exc: + # we can't check for more specifics due to differences between DB + # backends (we can't even check for a subclass of DatabaseError) + return False def create_migrate_table(self): @@ -109,21 +88,20 @@ class MigrationManager(object): else: self.execute("DELETE FROM %s" % MIGRATE_TABLE) self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE) - assert self.cursor.rowcount == 1 + assert self._database.rowcount == 1 def set_db_version(self, version): assert isinstance(version, int) self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE, version) - assert self.cursor.rowcount == 1 + assert self._database.rowcount == 1 def get_db_version(self): if not self.check_migrate_table_exists(): return 0 - self.execute("SELECT * FROM %s" % MIGRATE_TABLE) - rows = self.cursor.fetchall() + rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE) if len(rows) == 0: return 0 assert len(rows) == 1 and len(rows[0]) == 1 @@ -190,30 +168,28 @@ class MigrationManager(object): def initialize_test_db(self): - self.read_db_info() - test_db_name = 'test_' + self.db_name + db_name = self._get_db_name() + test_db_name = 'test_' + db_name # first, connect to no DB so we can create a test DB - self.db_name = '' - self.open_connection() + self._database.connect(db_name='') print 'Creating test DB', test_db_name self.execute('CREATE DATABASE ' + test_db_name) - self.close_connection() + self._database.disconnect() # now connect to the test DB - self.db_name = test_db_name - self.open_connection() + self._database.connect(db_name=test_db_name) def remove_test_db(self): print 'Removing test DB' - self.execute('DROP DATABASE ' + self.db_name) + self.execute('DROP DATABASE ' + self._get_db_name()) + # reset connection back to real DB + self._database.disconnect() + self._database.connect() def get_mysql_args(self): - return ('-u %(user)s -p%(password)s -h %(host)s %(db)s' % { - 'user' : self.username, - 'password' : self.password, - 'host' : self.db_host, - 'db' : self.db_name}) + return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' % + self._database.get_database_info()) def migrate_to_version_or_latest(self, version): @@ -224,9 +200,7 @@ class MigrationManager(object): def do_sync_db(self, version=None): - self.read_db_info() - self.open_connection() - print 'Migration starting for database', self.db_name + print 'Migration starting for database', self._get_db_name() self.migrate_to_version_or_latest(version) print 'Migration complete' @@ -237,7 +211,7 @@ class MigrationManager(object): """ self.initialize_test_db() try: - print 'Starting migration test on DB', self.db_name + print 'Starting migration test on DB', self._get_db_name() self.migrate_to_version_or_latest(version) # show schema to the user os.system('mysqldump %s --no-data=true ' @@ -253,28 +227,24 @@ class MigrationManager(object): Create a fresh DB, copy the existing DB to it, and then try to synchronize it. """ - self.read_db_info() - self.open_connection() db_version = self.get_db_version() - self.close_connection() # don't do anything if we're already at the latest version if db_version == self.get_latest_version(): print 'Skipping simulation, already at latest version' return # get existing data - self.read_db_info() print 'Dumping existing data' dump_fd, dump_file = tempfile.mkstemp('.migrate_dump') - os.close(dump_fd) os.system('mysqldump %s >%s' % (self.get_mysql_args(), dump_file)) # fill in test DB self.initialize_test_db() print 'Filling in test DB' os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file)) + os.close(dump_fd) os.remove(dump_file) try: - print 'Starting migration test on DB', self.db_name + print 'Starting migration test on DB', self._get_db_name() self.migrate_to_version_or_latest(version) finally: self.remove_test_db() @@ -299,7 +269,10 @@ def main(): parser.add_option("-f", "--force", help="don't ask for confirmation", action="store_true") (options, args) = parser.parse_args() - manager = MigrationManager(options.database, force=options.force) + database = database_connection.DatabaseConnection(options.database) + database.reconnect_enabled = False + database.connect() + manager = MigrationManager(database, force=options.force) if len(args) > 0: if len(args) > 1: diff --git a/migrate/migrate_unittest.py b/migrate/migrate_unittest.py index d33dccd9..72389d6b 100644 --- a/migrate/migrate_unittest.py +++ b/migrate/migrate_unittest.py @@ -1,10 +1,11 @@ #!/usr/bin/python2.4 -import unittest +import unittest, tempfile, os import MySQLdb import migrate import common from autotest_lib.client.common_lib import global_config +from autotest_lib.database import database_connection # Which section of the global config to pull info from. We won't actually use # that DB, we'll use the corresponding test DB (test_<db name>). @@ -54,18 +55,8 @@ MIGRATIONS = [DummyMigration(n) for n in xrange(1, NUM_MIGRATIONS + 1)] class TestableMigrationManager(migrate.MigrationManager): - def __init__(self, database, migrations_dir=None): - self.database = database - self.migrations_dir = migrations_dir - self.db_host = None - self.db_name = None - self.username = None - self.password = None - - - def read_db_info(self): - migrate.MigrationManager.read_db_info(self) - self.db_name = 'test_' + self.db_name + def _set_migrations_dir(self, migrations_dir=None): + pass def get_migrations(self, minimum_version=None, maximum_version=None): @@ -75,39 +66,16 @@ class TestableMigrationManager(migrate.MigrationManager): class MigrateManagerTest(unittest.TestCase): - config = global_config.global_config - host = config.get_config_value(CONFIG_DB, 'host') - db_name = 'test_' + config.get_config_value(CONFIG_DB, 'database') - user = config.get_config_value(CONFIG_DB, 'user') - password = config.get_config_value(CONFIG_DB, 'password') - - def do_sql(self, sql): - self.con = MySQLdb.connect(host=self.host, user=self.user, - passwd=self.password) - self.con.autocommit(True) - self.cur = self.con.cursor() - try: - self.cur.execute(sql) - finally: - self.con.close() - - - def remove_db(self): - self.do_sql('DROP DATABASE ' + self.db_name) - - def setUp(self): - self.do_sql('CREATE DATABASE ' + self.db_name) - try: - self.manager = TestableMigrationManager(CONFIG_DB) - except MySQLdb.OperationalError: - self.remove_db() - raise + self._database = ( + database_connection.DatabaseConnection.get_test_database()) + self._database.connect() + self.manager = TestableMigrationManager(self._database) DummyMigration.clear_migrations_done() def tearDown(self): - self.remove_db() + self._database.disconnect() def test_sync(self): diff --git a/scheduler/monitor_db_unittest.py b/scheduler/monitor_db_unittest.py index 7eab7307..11ed7db1 100644 --- a/scheduler/monitor_db_unittest.py +++ b/scheduler/monitor_db_unittest.py @@ -6,6 +6,7 @@ import common from autotest_lib.client.common_lib import global_config, host_protections from autotest_lib.client.common_lib.test_utils import mock from autotest_lib.migrate import migrate +from autotest_lib.database import database_connection import monitor_db @@ -102,10 +103,9 @@ class BaseDispatcherTest(unittest.TestCase): self._do_query('CREATE DATABASE ' + self._db_name) self._disconnect_from_db() - migration_dir = os.path.join(os.path.dirname(__file__), - '..', 'frontend', 'migrations') - manager = migrate.MigrationManager('AUTOTEST_WEB', migration_dir, - force=True) + database = database_connection.DatabaseConnection('AUTOTEST_WEB') + database.connect(db_name=self._db_name) + manager = migrate.MigrationManager(database, force=True) manager.do_sync_db() self._connect_to_db(self._db_name) diff --git a/tko/migrations/001_initial_db.py b/tko/migrations/001_initial_db.py index 8b964753..c67f3c2f 100755 --- a/tko/migrations/001_initial_db.py +++ b/tko/migrations/001_initial_db.py @@ -4,8 +4,8 @@ required_tables = ('machines', 'jobs', 'patches', 'tests', 'test_attributes', 'iteration_result') def migrate_up(manager): - manager.execute("SHOW TABLES") - tables = [row[0] for row in manager.cursor.fetchall()] + rows = manager.execute("SHOW TABLES") + tables = [row[0] for row in rows] db_initialized = True for table in required_tables: if table not in tables: |