summaryrefslogtreecommitdiff
path: root/database/database_connection.py
blob: 53903c9e3b16ee2bec0d44cf6c5a50b0bf674d4d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import re, time, traceback
import common
from autotest_lib.client.common_lib import global_config

# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
# It is compared by identity, so it must be a unique object.
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that are mirrored from a backend
# module onto wrapper objects by _copy_exceptions().
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')

# Option names whose global-config key differs from the attribute name;
# any option not listed here is looked up under its own name.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')

def _copy_exceptions(source, destination):
    """
    Mirror the exception classes named in _DB_EXCEPTIONS from source onto
    destination as attributes of the same name.

    Any name that source lacks is bound to source.DatabaseError instead.
    Under the django backend, Django 1.3 has no OperationalError or
    ProgrammingError, so the base DatabaseError stands in for them.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, name)
        except AttributeError:
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, name, exception_class)


class _GenericBackend(object):
    """
    Base class for the thin DB-API wrappers used by DatabaseConnection.

    Subclasses implement connect(), which must set self._connection and
    self._cursor.  This class supplies disconnect() and execute() on top of
    that connection/cursor pair, and copies the DB-API exception classes of
    database_module onto the instance.
    """
    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = self._cursor = None
        # Mirrors cursor.rowcount after each execute().
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the connection; implementations are assumed to enable
        autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection (if any) and forget the connection/cursor."""
        connection = self._connection
        if connection:
            connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """
        Run query with parameters (None means no parameters), record the
        cursor's rowcount and return cursor.fetchall().
        """
        cursor = self._cursor
        cursor.execute(query, () if parameters is None else parameters)
        self.rowcount = cursor.rowcount
        return cursor.fetchall()


class _MySqlBackend(_GenericBackend):
    """MySQL backend built on the MySQLdb DB-API module."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """
        Convert booleans to integer strings ('1' / '0').

        conversion_dict is unused; presumably it is accepted because MySQLdb
        passes its conversion table to encoder functions.
        """
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Connect to MySQL with autocommit enabled."""
        import MySQLdb.converters
        # Work on a copy: MySQLdb.converters.conversions is the module-wide
        # default table, and mutating it in place would change conversion
        # behaviour for every other MySQLdb connection in the process.  The
        # copy is handed to this connection explicitly via conv=.
        convert_dict = MySQLdb.converters.conversions.copy()
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()


class _SqliteBackend(_GenericBackend):
    """
    SQLite backend, used for (non-concurrent) unit tests.

    Adapts MySQL-style queries for SQLite: '%s' placeholders become '?' and
    LAST_INSERT_ID() becomes LAST_INSERT_ROWID().
    """
    def __init__(self):
        # Prefer the standalone pysqlite2 package when it is installed;
        # otherwise fall back to the standard library's sqlite3 module,
        # which provides the same DB-API interface.
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # \b (instead of a mandatory leading whitespace character) also
        # matches LAST_INSERT_ID() right after '(', ',' or '=' and at the very
        # start of the query, while still not matching inside a longer name.
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open db_name (a file path or ':memory:'); other args are ignored."""
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        # SQLite's DB-API module uses paramstyle=qmark.
        # TODO: make this more sophisticated if necessary (this blind
        # replace also rewrites '%s' occurring inside string literals).
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID().  Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)


class _DjangoBackend(_GenericBackend):
    """
    Backend that reuses Django's own database connection instead of opening
    a separate one.
    """
    def __init__(self):
        # Only connection and transaction are used.  The previously imported
        # django.db.backend name was never referenced, and importing an
        # unused name only added a way for this import to fail.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Adopt Django's connection; all arguments are ignored."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute via the generic implementation, then call
        transaction.commit_unless_managed() whether or not the query raised.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()


# Maps a db_type option value to the backend class implementing it;
# DatabaseConnection._get_backend() instantiates the chosen class.
_BACKEND_MAP = dict([
    ('mysql', _MySqlBackend),
    ('sqlite', _SqliteBackend),
    ('django', _DjangoBackend),
])


class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  Supports the mysql, sqlite
    and django backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up.  Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    * db_type, host, username, password, db_name - the resolved connection
      options (see _read_options()).
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve a single connection option.

        An explicitly provided (non-None) value wins.  Otherwise, if
        global_config_section is set, the value is read from that section
        (under the key given by _GLOBAL_CONFIG_NAMES, defaulting to name);
        failing that, the current attribute value (or None) is kept.
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve and store all connection options via _get_option()."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for db_type.

        @raises ValueError: db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            # Sorted so the message does not depend on dict ordering.
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(sorted(_BACKEND_MAP))))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """
        True once num_attempts (failed connection attempts so far) exceeds
        max_reconnect_attempts, i.e. the initial attempt plus
        max_reconnect_attempts retries have all failed.  Never true when the
        limit is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Per-call override (if not None) or else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect self._backend using the current options.

        On OperationalError, retries after sleeping reconnect_delay_sec as
        long as reconnecting is enabled and the attempt limit has not been
        reached; otherwise the OperationalError is re-raised.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config.  try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's exception classes (e.g. self.OperationalError)
        # so callers can catch them without knowing the backend.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.

        On OperationalError (with reconnecting enabled) the connection is
        re-established and the query is retried exactly once.
        """
        if self.debug:
            # Function-call form, consistent with the other print calls in
            # this class (and valid on both Python 2 and 3).
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the connection options named in
        _DATABASE_ATTRIBUTES."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a connected instance on an SQLite database
        (in-memory by default) with automatic reconnection disabled.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database


class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary translators (e.g. regexp
    substitutions) to each query string before executing it.  Useful for
    SQLite testing.
    """
    def __init__(self, translators, **kwargs):
        """
        @param translators: list of callables to apply to each query
                string (in order).  Each accepts a query string and returns a
                (possibly) modified query string.
        @param kwargs: passed through to DatabaseConnection.__init__
                (global_config_section, debug), so these options can also be
                given to get_test_database().
        """
        super(TranslatingDatabase, self).__init__(**kwargs)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Translate query with every translator, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator