xref: /aosp_15_r20/external/autotest/database/database_connection.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1*9c5db199SXin Li# pylint: disable-msg=C0111
2*9c5db199SXin Li
3*9c5db199SXin Liimport re, time, traceback
4*9c5db199SXin Liimport common
5*9c5db199SXin Lifrom autotest_lib.client.common_lib import global_config
6*9c5db199SXin Li
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that _copy_exceptions() mirrors onto
# backends and connections.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)
# Option names whose key in the global config differs from the attribute name
# used by DatabaseConnection (looked up as _GLOBAL_CONFIG_NAMES.get(name, name)).
_GLOBAL_CONFIG_NAMES = dict(
    username='user',
    db_name='database',
)
14*9c5db199SXin Li
def _copy_exceptions(source, destination):
    """Mirror the DB-API exception classes of `source` onto `destination`.

    Every name in _DB_EXCEPTIONS is looked up on `source` and set as an
    attribute of `destination`.  A name that `source` lacks is replaced by
    `source.DatabaseError` (under the django backend, Django 1.3 does not
    provide OperationalError and ProgrammingError).

    @param source: object (module or backend) providing the exception classes.
    @param destination: object that receives them as attributes.
    @raises AttributeError: if `source` lacks both a name and DatabaseError.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, name)
        except AttributeError:
            # Substitute the base DB-API error class for the missing one.
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, name, exception_class)
26*9c5db199SXin Li
27*9c5db199SXin Li
class _GenericBackend(object):
    """Base class for the database backends used by DatabaseConnection.

    Wraps a DB-API style module plus one connection/cursor pair.  Subclasses
    implement connect(), which must set self._connection and self._cursor.
    """
    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount from the most recent execute().
        self.rowcount = None
        # Expose the module's DatabaseError/OperationalError/ProgrammingError
        # classes as attributes of this backend.
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the connection and cursor; must be overridden by subclasses.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection (if any) and drop the connection and cursor.

        The references are cleared even when close() raises, so the backend
        never keeps pointing at a half-closed connection; the exception from
        close() still propagates to the caller.
        """
        try:
            if self._connection:
                self._connection.close()
        finally:
            self._connection = None
            self._cursor = None


    def execute(self, query, parameters=None):
        """Execute `query` on the cursor and return cursor.fetchall().

        @param query: SQL string, passed unchanged to cursor.execute().
        @param parameters: query parameters; None is passed as an empty tuple.
        @returns the result of cursor.fetchall().  self.rowcount is updated
                 from cursor.rowcount.
        """
        if parameters is None:
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()
57*9c5db199SXin Li
58*9c5db199SXin Li
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL servers, built on the MySQLdb module."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Convert booleans to integer strings ('1' / '0').

        Registered for `bool` in the connection's conversion table.
        `conversion_dict` is unused; NOTE(review): it is presumably accepted
        because MySQLdb calls encoders with the value and the conversion
        mapping.
        """
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Connect to MySQL with autocommit enabled and open a cursor."""
        import MySQLdb.converters
        # Work on a private copy: MySQLdb.converters.conversions is the
        # module-wide default table, and adding entries to it in place would
        # change conversions for every other MySQLdb user in the process.
        convert_dict = dict(MySQLdb.converters.conversions)
        # Only install our bool encoder if the table lacks one.
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
80*9c5db199SXin Li        self._cursor = self._connection.cursor()
81*9c5db199SXin Li
82*9c5db199SXin Li
class _SqliteBackend(_GenericBackend):
    """Backend for SQLite (pysqlite2 when installed, otherwise sqlite3)."""

    # SQLite has no LAST_INSERT_ID(); execute() rewrites it.  A word boundary
    # (instead of a required leading whitespace character) also catches calls
    # that directly follow '(' or ',', e.g. "VALUES (LAST_INSERT_ID())".
    _last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)', re.IGNORECASE)


    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open the SQLite database at `db_name`; other arguments are ignored."""
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Adapt a MySQL-style query to SQLite, then execute it.

        Every '%s' becomes '?', and LAST_INSERT_ID() becomes
        LAST_INSERT_ROWID().
        """
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary (this also rewrites
        # '%s' occurring inside string literals).
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID().  Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
113*9c5db199SXin Li
114*9c5db199SXin Li
class _DjangoBackend(_GenericBackend):
    """Backend that reuses the connection managed by django.db."""
    def __init__(self):
        # Only `connection` and `transaction` are used by this backend.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Adopt django.db.connection; the arguments are ignored."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Execute via the base class, then call commit_unless_managed().

        The commit call runs in a `finally`, i.e. even when the query raised.
        NOTE(review): transaction.commit_unless_managed() is a legacy Django
        API; confirm it exists in the Django version in use.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
135*9c5db199SXin Li
136*9c5db199SXin Li
# Maps the `db_type` option to the backend class that implements it.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
142*9c5db199SXin Li
143*9c5db199SXin Li
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  The work is delegated to a
    backend selected by db_type: 'mysql', 'sqlite' or 'django'.

    Public attributes:
    * reconnect_enabled: when True, an OperationalError makes the class try
      to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before each reconnection attempt.
    * max_reconnect_attempts: how many times to retry connecting before
      giving up.  RECONNECT_FOREVER removes the limit.
    * rowcount: the backend's cursor.rowcount after the latest execute().
    * global_config_section: the global config section holding DB
      information.  Pass it to the constructor rather than setting it later.
      It may be None, in which case the information must be passed to
      connect().
    * debug: when True, every query is printed before it is executed.
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self.debug = debug
        self.rowcount = None
        self._backend = None

        # Reconnection policy defaults; callers may change them afterwards.
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value, use_afe_setting=False):
        """Resolve one connection option.

        Priority: `provided_value` if it is not None; otherwise the value from
        the AUTOTEST_WEB section (when `use_afe_setting`) or from
        self.global_config_section; with no section, the current value of the
        attribute `name` (or None).

        @param name: Name of the option (attribute name on this object).
        @param provided_value: Value overriding the global config one.
        @param use_afe_setting: Force reading from AUTOTEST_WEB, default False.
        """
        # TODO(dshi): This function returns the option value depends on multiple
        # conditions. The value of `provided_value` has highest priority, then
        # the code checks if use_afe_setting is True, if that's the case, force
        # to use settings in AUTOTEST_WEB. At last the value is retrieved from
        # specified global config section.
        # The logic is too complicated for a generic function named like
        # _get_option. Ideally we want to make it clear from caller that it
        # wants to get database credential from one of the 3 ways:
        # 1. Use the credential from given config section
        # 2. Use the credential from AUTOTEST_WEB section
        # 3. Use the credential provided by caller.
        if provided_value is not None:
            return provided_value
        if use_afe_setting:
            section = 'AUTOTEST_WEB'
        else:
            section = self.global_config_section
        if not section:
            return getattr(self, name, None)
        return global_config.global_config.get_config_value(
                section, _GLOBAL_CONFIG_NAMES.get(name, name))


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Fill in db_name, db_type, host, username and password.

        db_name comes from the arguments or self.global_config_section.  The
        credentials (db_type, host, username, password) come from the
        arguments; only if none of them is given are they read from the AFE
        database settings (AUTOTEST_WEB).

        @param db_type: database type, default to None.
        @param host: database hostname, default to None.
        @param username: user name for database connection, default to None.
        @param password: database password, default to None.
        @param db_name: database name, default to None.
        """
        self.db_name = self._get_option('db_name', db_name)
        # Fall back to AFE credentials only when the caller supplied none.
        use_afe_setting = not any((db_type, host, username, password))

        for attribute, value in (('db_type', db_type),
                                 ('host', host),
                                 ('username', username),
                                 ('password', password)):
            setattr(self, attribute,
                    self._get_option(attribute, value, use_afe_setting))


    def _get_backend(self, db_type):
        """Instantiate the backend class registered for `db_type`.

        @raises ValueError: if `db_type` is not a known backend.
        """
        backend_class = _BACKEND_MAP.get(db_type)
        if backend_class is None:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once `num_attempts` exceeds the reconnect limit (if any)."""
        if self.max_reconnect_attempts is RECONNECT_FOREVER:
            return False
        return num_attempts > self.max_reconnect_attempts


    def _is_reconnect_enabled(self, supplied_param):
        """Per-call override `supplied_param` wins unless it is None."""
        return (self.reconnect_enabled if supplied_param is None
                else supplied_param)


    def _connect_backend(self, try_reconnecting=None):
        """Connect the backend, retrying on OperationalError when allowed."""
        failures = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
            except self._backend.OperationalError:
                failures += 1
                if (not self._is_reconnect_enabled(try_reconnecting) or
                        self._reached_max_attempts(failures)):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()
            else:
                return


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        (Re)connect to the database.  Arguments override the defaults from
        the global config; try_reconnecting, if passed, overrides
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        backend = self._get_backend(self.db_type)
        self._backend = backend
        _copy_exceptions(backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if there is one."""
        backend = self._backend
        if backend:
            backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall().  On OperationalError
        the connection is re-established once and the query retried, unless
        reconnecting is disabled; try_reconnecting, if passed, overrides
        self.reconnect_enabled.
        """
        if self.debug:
            print('Executing %s, %s' % (query, parameters))
        # No loop here: _connect_backend() already retries on its own.
        try:
            rows = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            rows = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return rows


    def get_database_info(self):
        """Return a dict of the connection attributes in _DATABASE_ATTRIBUTES."""
        return {attribute: getattr(self, attribute)
                for attribute in self._DATABASE_ATTRIBUTES}


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning an instance of `cls` connected to the
        SQLite database at `file_path` (a temporary in-memory database by
        default), with automatic reconnection disabled.
        """
        test_db = cls(**constructor_kwargs)
        test_db.reconnect_enabled = False
        test_db.connect(db_type='sqlite', db_name=file_path)
        return test_db
333*9c5db199SXin Li
334*9c5db199SXin Li
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary query translators (for example
    regexp substitutions) to each query string before executing it.  Useful
    for SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: iterable of callables to apply to each query
                string (in order).  Each accepts a query string and returns a
                (possibly) modified query string.  It is iterated on every
                execute() call, so pass a re-iterable container such as a
                list rather than a one-shot iterator.
        @param global_config_section: forwarded to DatabaseConnection
                (default None, as before).
        @param debug: forwarded to DatabaseConnection (default False).
        """
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Apply every translator to `query` in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator
365*9c5db199SXin Li        return translator
366