# pylint: disable-msg=C0111

import re, time, traceback
import common
from autotest_lib.client.common_lib import global_config

# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
RECONNECT_FOREVER = object()

# DB-API exception class names that _copy_exceptions() mirrors from a database
# module onto a backend, and from a backend onto a DatabaseConnection.
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
# DatabaseConnection option name -> key used in the global config, for the
# options whose names differ; all other options use the same name in both.
_GLOBAL_CONFIG_NAMES = {
    'username' : 'user',
    'db_name' : 'database',
}

def _copy_exceptions(source, destination):
    """Copy the exception classes named in _DB_EXCEPTIONS.

    Each class is read from `source` and set as an attribute of
    `destination`. If `source` lacks one of them, its DatabaseError is used
    in its place.

    @param source: object (module or backend) providing the exception classes.
    @param destination: object on which the exception attributes are set.
    """
    for exception_name in _DB_EXCEPTIONS:
        try:
            setattr(destination, exception_name,
                    getattr(source, exception_name))
        except AttributeError:
            # Under the django backend:
            # Django 1.3 does not have OperationalError and ProgrammingError.
            # Let's just mock these classes with the base DatabaseError.
            setattr(destination, exception_name,
                    getattr(source, 'DatabaseError'))


class _GenericBackend(object):
    """Base class wrapping a DB-API style database module.

    Subclasses implement connect(), which must set self._connection and
    self._cursor; disconnect() and execute() below operate on those.
    """
    def __init__(self, database_module):
        """
        @param database_module: module (or module-like object) providing the
                DB-API exception classes and, unless connect() is overridden
                to do otherwise, a connect() function.
        """
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount from the most recent execute() call.
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open a connection and cursor; must be implemented by subclasses.

        Implementations are assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection, if any, and drop the connection and cursor."""
        if self._connection:
            self._connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """Run `query` on the cursor and return all result rows.

        @param query: SQL query string.
        @param parameters: sequence of query parameters; None means none.
        @return: the result of cursor.fetchall().
        """
        if parameters is None:
            parameters = ()
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()


class _MySqlBackend(_GenericBackend):
    """Backend for MySQL using the MySQLdb module."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Convert a boolean to an integer string ('1' or '0').

        @param boolean: value to convert.
        @param conversion_dict: accepted to fit the converter signature;
                unused.
        """
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Connect to MySQL with autocommit enabled.

        Booleans are encoded with convert_boolean() unless MySQLdb already
        provides a converter for bool.
        """
        import MySQLdb.converters
        # Copy the module-level conversions mapping so that registering our
        # bool converter does not mutate state shared by every MySQLdb user
        # in the process.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()


class _SqliteBackend(_GenericBackend):
    """Backend for SQLite using pysqlite2, falling back to stdlib sqlite3."""
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # A word boundary (instead of requiring a preceding whitespace char)
        # also matches the call right after punctuation such as '(' or ','.
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open the SQLite database `db_name` (a file name or ':memory:').

        host, username and password are accepted for interface compatibility
        and ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate a MySQL-style `query` for SQLite, then execute it.

        '%s' placeholders become '?', and LAST_INSERT_ID() becomes
        LAST_INSERT_ROWID().
        """
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)


class _DjangoBackend(_GenericBackend):
    """Backend that runs queries on django.db.connection."""
    def __init__(self):
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Reuse Django's connection; all connection arguments are ignored."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Execute `query`, then call commit_unless_managed().

        The commit call runs in a `finally`, so it happens even if the query
        raises.
        """
        # NOTE(review): transaction.commit_unless_managed() no longer exists
        # in newer Django releases; confirm the Django version in use.
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()


# db_type value accepted by DatabaseConnection -> backend class.
_BACKEND_MAP = {
    'mysql': _MySqlBackend,
    'sqlite': _SqliteBackend,
    'django': _DjangoBackend,
}


class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports mysql, sqlite and
    django backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting before
      giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section holding the
                database name, or None.
        @param debug: if True, print every query before executing it.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value, use_afe_setting=False):
        """Get value of given option from global config.

        @param name: Name of the config.
        @param provided_value: Value being provided to override the one from
                global config.
        @param use_afe_setting: Force to use the settings in AFE, default is
                False.
        @return: provided_value if not None; otherwise the value from the
                AUTOTEST_WEB section (when use_afe_setting) or from
                self.global_config_section; otherwise, when there is no
                section, the current value of attribute `name` (or None).
        """
        # TODO(dshi): This function returns the option value depends on multiple
        # conditions. The value of `provided_value` has highest priority, then
        # the code checks if use_afe_setting is True, if that's the case, force
        # to use settings in AUTOTEST_WEB. At last the value is retrieved from
        # specified global config section.
        # The logic is too complicated for a generic function named like
        # _get_option. Ideally we want to make it clear from caller that it
        # wants to get database credential from one of the 3 ways:
        # 1. Use the credential from given config section
        # 2. Use the credential from AUTOTEST_WEB section
        # 3. Use the credential provided by caller.
        if provided_value is not None:
            return provided_value
        section = ('AUTOTEST_WEB' if use_afe_setting else
                   self.global_config_section)
        if section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    section, global_config_name)

        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve the database settings and store them on self.

        db_name is taken from the argument if given, otherwise from
        self.global_config_section.

        If none of db_type, host, username or password is given, all four
        are read from the AFE database settings (AUTOTEST_WEB). Otherwise
        each one is taken from its argument if given, else from
        self.global_config_section.

        When there is no section to read from, the current attribute value
        (or None) is kept.

        @param db_type: database type, default to None.
        @param host: database hostname, default to None.
        @param username: user name for database connection, default to None.
        @param password: database password, default to None.
        @param db_name: database name, default to None.
        """
        self.db_name = self._get_option('db_name', db_name)
        use_afe_setting = not bool(db_type or host or username or password)

        # Database credential can be provided by the caller, as passed in from
        # function connect.
        self.db_type = self._get_option('db_type', db_type, use_afe_setting)
        self.host = self._get_option('host', host, use_afe_setting)
        self.username = self._get_option('username', username, use_afe_setting)
        self.password = self._get_option('password', password, use_afe_setting)


    def _get_backend(self, db_type):
        """Instantiate the backend for `db_type`.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds max_reconnect_attempts.

        Never true when max_reconnect_attempts is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param if not None, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """Connect the backend, retrying on OperationalError.

        Between attempts it prints the traceback, sleeps reconnect_delay_sec
        and disconnects.

        @param try_reconnecting: overrides self.reconnect_enabled if not None.
        @raises OperationalError: re-raised when reconnecting is disabled or
                the maximum number of attempts has been exceeded.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.

        Any existing connection is closed first. The backend's exception
        classes are copied onto this object.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.

        On OperationalError with reconnecting enabled, the connection is
        re-established and the query is retried once.
        """
        if self.debug:
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the attributes named in _DATABASE_ATTRIBUTES."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary in-memory
        database.

        @param file_path: SQLite database file; ':memory:' by default.
        @param constructor_kwargs: forwarded to the class constructor.
        @return: a connected instance with reconnecting disabled.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database


class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary substitution regexps to each query
    string. Useful for SQLite testing.
    """
    def __init__(self, translators):
        """
        @param translators: list of callables to apply to each query
                string (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        """
        super(TranslatingDatabase, self).__init__()
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Apply every translator to `query` in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator