1*9c5db199SXin Li""" 2*9c5db199SXin LiExtensions to Django's model logic. 3*9c5db199SXin Li""" 4*9c5db199SXin Li 5*9c5db199SXin Liimport django.core.exceptions 6*9c5db199SXin Liimport django.db.models.sql.where 7*9c5db199SXin Liimport six 8*9c5db199SXin Lifrom autotest_lib.client.common_lib import error 9*9c5db199SXin Lifrom autotest_lib.frontend.afe import rdb_model_extensions 10*9c5db199SXin Lifrom django.db import connection, connections 11*9c5db199SXin Lifrom django.db import models as dbmodels 12*9c5db199SXin Lifrom django.db import transaction 13*9c5db199SXin Lifrom django.db.models.sql import query 14*9c5db199SXin Li 15*9c5db199SXin Li 16*9c5db199SXin Liclass ValidationError(django.core.exceptions.ValidationError): 17*9c5db199SXin Li """\ 18*9c5db199SXin Li Data validation error in adding or updating an object. The associated 19*9c5db199SXin Li value is a dictionary mapping field names to error strings. 20*9c5db199SXin Li """ 21*9c5db199SXin Li 22*9c5db199SXin Lidef _quote_name(name): 23*9c5db199SXin Li """Shorthand for connection.ops.quote_name().""" 24*9c5db199SXin Li return connection.ops.quote_name(name) 25*9c5db199SXin Li 26*9c5db199SXin Li 27*9c5db199SXin Liclass LeasedHostManager(dbmodels.Manager): 28*9c5db199SXin Li """Query manager for unleased, unlocked hosts. 29*9c5db199SXin Li """ 30*9c5db199SXin Li def get_query_set(self): 31*9c5db199SXin Li return (super(LeasedHostManager, self).get_query_set().filter( 32*9c5db199SXin Li leased=0, locked=0)) 33*9c5db199SXin Li 34*9c5db199SXin Li 35*9c5db199SXin Liclass ExtendedManager(dbmodels.Manager): 36*9c5db199SXin Li """\ 37*9c5db199SXin Li Extended manager supporting subquery filtering. 38*9c5db199SXin Li """ 39*9c5db199SXin Li 40*9c5db199SXin Li class CustomQuery(query.Query): 41*9c5db199SXin Li """A custom query""" 42*9c5db199SXin Li 43*9c5db199SXin Li def __init__(self, *args, **kwargs): 44*9c5db199SXin Li super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs) 45*9c5db199SXin Li self._custom_joins = [] 46*9c5db199SXin Li 47*9c5db199SXin Li 48*9c5db199SXin Li def clone(self, klass=None, **kwargs): 49*9c5db199SXin Li """Clones the query and returns the clone.""" 50*9c5db199SXin Li obj = super(ExtendedManager.CustomQuery, self).clone(klass) 51*9c5db199SXin Li obj._custom_joins = list(self._custom_joins) 52*9c5db199SXin Li return obj 53*9c5db199SXin Li 54*9c5db199SXin Li 55*9c5db199SXin Li def combine(self, rhs, connector): 56*9c5db199SXin Li """Combines query with another query.""" 57*9c5db199SXin Li super(ExtendedManager.CustomQuery, self).combine(rhs, connector) 58*9c5db199SXin Li if hasattr(rhs, '_custom_joins'): 59*9c5db199SXin Li self._custom_joins.extend(rhs._custom_joins) 60*9c5db199SXin Li 61*9c5db199SXin Li 62*9c5db199SXin Li def add_custom_join(self, table, condition, join_type, 63*9c5db199SXin Li condition_values=(), alias=None): 64*9c5db199SXin Li """Adds a custom join to the query.""" 65*9c5db199SXin Li if alias is None: 66*9c5db199SXin Li alias = table 67*9c5db199SXin Li join_dict = dict(table=table, 68*9c5db199SXin Li condition=condition, 69*9c5db199SXin Li condition_values=condition_values, 70*9c5db199SXin Li join_type=join_type, 71*9c5db199SXin Li alias=alias) 72*9c5db199SXin Li self._custom_joins.append(join_dict) 73*9c5db199SXin Li 74*9c5db199SXin Li 75*9c5db199SXin Li @classmethod 76*9c5db199SXin Li def convert_query(self, query_set): 77*9c5db199SXin Li """ 78*9c5db199SXin Li Convert the query set's "query" attribute to a CustomQuery. 79*9c5db199SXin Li """ 80*9c5db199SXin Li # Make a copy of the query set 81*9c5db199SXin Li query_set = query_set.all() 82*9c5db199SXin Li query_set.query = query_set.query.clone( 83*9c5db199SXin Li klass=ExtendedManager.CustomQuery, 84*9c5db199SXin Li _custom_joins=[]) 85*9c5db199SXin Li return query_set 86*9c5db199SXin Li 87*9c5db199SXin Li 88*9c5db199SXin Li class _WhereClause(object): 89*9c5db199SXin Li """Object allowing us to inject arbitrary SQL into Django queries. 90*9c5db199SXin Li 91*9c5db199SXin Li By using this instead of extra(where=...), we can still freely combine 92*9c5db199SXin Li queries with & and |. 93*9c5db199SXin Li """ 94*9c5db199SXin Li def __init__(self, clause, values=()): 95*9c5db199SXin Li self._clause = clause 96*9c5db199SXin Li self._values = values 97*9c5db199SXin Li 98*9c5db199SXin Li 99*9c5db199SXin Li def as_sql(self, qn=None, connection=None): 100*9c5db199SXin Li """Converts the clause to SQL and returns it.""" 101*9c5db199SXin Li return self._clause, self._values 102*9c5db199SXin Li 103*9c5db199SXin Li 104*9c5db199SXin Li def relabel_aliases(self, change_map): 105*9c5db199SXin Li """Does nothing.""" 106*9c5db199SXin Li return 107*9c5db199SXin Li 108*9c5db199SXin Li 109*9c5db199SXin Li def add_join(self, query_set, join_table, join_key, join_condition='', 110*9c5db199SXin Li join_condition_values=(), join_from_key=None, alias=None, 111*9c5db199SXin Li suffix='', exclude=False, force_left_join=False): 112*9c5db199SXin Li """Add a join to query_set. 113*9c5db199SXin Li 114*9c5db199SXin Li Join looks like this: 115*9c5db199SXin Li (INNER|LEFT) JOIN <join_table> AS <alias> 116*9c5db199SXin Li ON (<this table>.<join_from_key> = <join_table>.<join_key> 117*9c5db199SXin Li and <join_condition>) 118*9c5db199SXin Li 119*9c5db199SXin Li @param join_table table to join to 120*9c5db199SXin Li @param join_key field referencing back to this model to use for the join 121*9c5db199SXin Li @param join_condition extra condition for the ON clause of the join 122*9c5db199SXin Li @param join_condition_values values to substitute into join_condition 123*9c5db199SXin Li @param join_from_key column on this model to join from. 124*9c5db199SXin Li @param alias alias to use for for join 125*9c5db199SXin Li @param suffix suffix to add to join_table for the join alias, if no 126*9c5db199SXin Li alias is provided 127*9c5db199SXin Li @param exclude if true, exclude rows that match this join (will use a 128*9c5db199SXin Li LEFT OUTER JOIN and an appropriate WHERE condition) 129*9c5db199SXin Li @param force_left_join - if true, a LEFT OUTER JOIN will be used 130*9c5db199SXin Li instead of an INNER JOIN regardless of other options 131*9c5db199SXin Li """ 132*9c5db199SXin Li join_from_table = query_set.model._meta.db_table 133*9c5db199SXin Li if join_from_key is None: 134*9c5db199SXin Li join_from_key = self.model._meta.pk.name 135*9c5db199SXin Li if alias is None: 136*9c5db199SXin Li alias = join_table + suffix 137*9c5db199SXin Li full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) 138*9c5db199SXin Li full_join_condition = '%s = %s.%s' % (full_join_key, 139*9c5db199SXin Li _quote_name(join_from_table), 140*9c5db199SXin Li _quote_name(join_from_key)) 141*9c5db199SXin Li if join_condition: 142*9c5db199SXin Li full_join_condition += ' AND (' + join_condition + ')' 143*9c5db199SXin Li if exclude or force_left_join: 144*9c5db199SXin Li join_type = query_set.query.LOUTER 145*9c5db199SXin Li else: 146*9c5db199SXin Li join_type = query_set.query.INNER 147*9c5db199SXin Li 148*9c5db199SXin Li query_set = self.CustomQuery.convert_query(query_set) 149*9c5db199SXin Li query_set.query.add_custom_join(join_table, 150*9c5db199SXin Li full_join_condition, 151*9c5db199SXin Li join_type, 152*9c5db199SXin Li condition_values=join_condition_values, 153*9c5db199SXin Li alias=alias) 154*9c5db199SXin Li 155*9c5db199SXin Li if exclude: 156*9c5db199SXin Li query_set = query_set.extra(where=[full_join_key + ' IS NULL']) 157*9c5db199SXin Li 158*9c5db199SXin Li return query_set 159*9c5db199SXin Li 160*9c5db199SXin Li 161*9c5db199SXin Li def _info_for_many_to_one_join(self, field, join_to_query, alias): 162*9c5db199SXin Li """ 163*9c5db199SXin Li @param field: the ForeignKey field on the related model 164*9c5db199SXin Li @param join_to_query: the query over the related model that we're 165*9c5db199SXin Li joining to 166*9c5db199SXin Li @param alias: alias of joined table 167*9c5db199SXin Li """ 168*9c5db199SXin Li info = {} 169*9c5db199SXin Li rhs_table = join_to_query.model._meta.db_table 170*9c5db199SXin Li info['rhs_table'] = rhs_table 171*9c5db199SXin Li info['rhs_column'] = field.column 172*9c5db199SXin Li info['lhs_column'] = field.rel.get_related_field().column 173*9c5db199SXin Li rhs_where = join_to_query.query.where 174*9c5db199SXin Li rhs_where.relabel_aliases({rhs_table: alias}) 175*9c5db199SXin Li compiler = join_to_query.query.get_compiler(using=join_to_query.db) 176*9c5db199SXin Li initial_clause, values = compiler.as_sql() 177*9c5db199SXin Li # initial_clause is compiled from `join_to_query`, which is a SELECT 178*9c5db199SXin Li # query returns at most one record. For it to be used in WHERE clause, 179*9c5db199SXin Li # it must be converted to a boolean value using EXISTS. 180*9c5db199SXin Li all_clauses = ('EXISTS (%s)' % initial_clause,) 181*9c5db199SXin Li if hasattr(join_to_query.query, 'extra_where'): 182*9c5db199SXin Li all_clauses += join_to_query.query.extra_where 183*9c5db199SXin Li info['where_clause'] = ( 184*9c5db199SXin Li ' AND '.join('(%s)' % clause for clause in all_clauses)) 185*9c5db199SXin Li info['values'] = values 186*9c5db199SXin Li return info 187*9c5db199SXin Li 188*9c5db199SXin Li 189*9c5db199SXin Li def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, 190*9c5db199SXin Li m2m_is_on_this_model): 191*9c5db199SXin Li """ 192*9c5db199SXin Li @param m2m_field: a Django field representing the M2M relationship. 193*9c5db199SXin Li It uses a pivot table with the following structure: 194*9c5db199SXin Li this model table <---> M2M pivot table <---> joined model table 195*9c5db199SXin Li @param join_to_query: the query over the related model that we're 196*9c5db199SXin Li joining to. 197*9c5db199SXin Li @param alias: alias of joined table 198*9c5db199SXin Li """ 199*9c5db199SXin Li if m2m_is_on_this_model: 200*9c5db199SXin Li # referenced field on this model 201*9c5db199SXin Li lhs_id_field = self.model._meta.pk 202*9c5db199SXin Li # foreign key on the pivot table referencing lhs_id_field 203*9c5db199SXin Li m2m_lhs_column = m2m_field.m2m_column_name() 204*9c5db199SXin Li # foreign key on the pivot table referencing rhd_id_field 205*9c5db199SXin Li m2m_rhs_column = m2m_field.m2m_reverse_name() 206*9c5db199SXin Li # referenced field on related model 207*9c5db199SXin Li rhs_id_field = m2m_field.rel.get_related_field() 208*9c5db199SXin Li else: 209*9c5db199SXin Li lhs_id_field = m2m_field.rel.get_related_field() 210*9c5db199SXin Li m2m_lhs_column = m2m_field.m2m_reverse_name() 211*9c5db199SXin Li m2m_rhs_column = m2m_field.m2m_column_name() 212*9c5db199SXin Li rhs_id_field = join_to_query.model._meta.pk 213*9c5db199SXin Li 214*9c5db199SXin Li info = {} 215*9c5db199SXin Li info['rhs_table'] = m2m_field.m2m_db_table() 216*9c5db199SXin Li info['rhs_column'] = m2m_lhs_column 217*9c5db199SXin Li info['lhs_column'] = lhs_id_field.column 218*9c5db199SXin Li 219*9c5db199SXin Li # select the ID of related models relevant to this join. we can only do 220*9c5db199SXin Li # a single join, so we need to gather this information up front and 221*9c5db199SXin Li # include it in the join condition. 222*9c5db199SXin Li rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) 223*9c5db199SXin Li assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' 224*9c5db199SXin Li 'match a single related object.') 225*9c5db199SXin Li rhs_id = rhs_ids[0] 226*9c5db199SXin Li 227*9c5db199SXin Li info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), 228*9c5db199SXin Li _quote_name(m2m_rhs_column), 229*9c5db199SXin Li rhs_id) 230*9c5db199SXin Li info['values'] = () 231*9c5db199SXin Li return info 232*9c5db199SXin Li 233*9c5db199SXin Li 234*9c5db199SXin Li def join_custom_field(self, query_set, join_to_query, alias, 235*9c5db199SXin Li left_join=True): 236*9c5db199SXin Li """Join to a related model to create a custom field in the given query. 237*9c5db199SXin Li 238*9c5db199SXin Li This method is used to construct a custom field on the given query based 239*9c5db199SXin Li on a many-valued relationsip. join_to_query should be a simple query 240*9c5db199SXin Li (no joins) on the related model which returns at most one related row 241*9c5db199SXin Li per instance of this model. 242*9c5db199SXin Li 243*9c5db199SXin Li For many-to-one relationships, the joined table contains the matching 244*9c5db199SXin Li row from the related model it one is related, NULL otherwise. 245*9c5db199SXin Li 246*9c5db199SXin Li For many-to-many relationships, the joined table contains the matching 247*9c5db199SXin Li row if it's related, NULL otherwise. 248*9c5db199SXin Li """ 249*9c5db199SXin Li relationship_type, field = self.determine_relationship( 250*9c5db199SXin Li join_to_query.model) 251*9c5db199SXin Li 252*9c5db199SXin Li if relationship_type == self.MANY_TO_ONE: 253*9c5db199SXin Li info = self._info_for_many_to_one_join(field, join_to_query, alias) 254*9c5db199SXin Li elif relationship_type == self.M2M_ON_RELATED_MODEL: 255*9c5db199SXin Li info = self._info_for_many_to_many_join( 256*9c5db199SXin Li m2m_field=field, join_to_query=join_to_query, alias=alias, 257*9c5db199SXin Li m2m_is_on_this_model=False) 258*9c5db199SXin Li elif relationship_type ==self.M2M_ON_THIS_MODEL: 259*9c5db199SXin Li info = self._info_for_many_to_many_join( 260*9c5db199SXin Li m2m_field=field, join_to_query=join_to_query, alias=alias, 261*9c5db199SXin Li m2m_is_on_this_model=True) 262*9c5db199SXin Li 263*9c5db199SXin Li return self.add_join(query_set, info['rhs_table'], info['rhs_column'], 264*9c5db199SXin Li join_from_key=info['lhs_column'], 265*9c5db199SXin Li join_condition=info['where_clause'], 266*9c5db199SXin Li join_condition_values=info['values'], 267*9c5db199SXin Li alias=alias, 268*9c5db199SXin Li force_left_join=left_join) 269*9c5db199SXin Li 270*9c5db199SXin Li 271*9c5db199SXin Li def add_where(self, query_set, where, values=()): 272*9c5db199SXin Li """Adds a where clause to the query_set.""" 273*9c5db199SXin Li query_set = query_set.all() 274*9c5db199SXin Li query_set.query.where.add(self._WhereClause(where, values), 275*9c5db199SXin Li django.db.models.sql.where.AND) 276*9c5db199SXin Li return query_set 277*9c5db199SXin Li 278*9c5db199SXin Li 279*9c5db199SXin Li def _get_quoted_field(self, table, field): 280*9c5db199SXin Li return _quote_name(table) + '.' + _quote_name(field) 281*9c5db199SXin Li 282*9c5db199SXin Li 283*9c5db199SXin Li def get_key_on_this_table(self, key_field=None): 284*9c5db199SXin Li if key_field is None: 285*9c5db199SXin Li # default to primary key 286*9c5db199SXin Li key_field = self.model._meta.pk.column 287*9c5db199SXin Li return self._get_quoted_field(self.model._meta.db_table, key_field) 288*9c5db199SXin Li 289*9c5db199SXin Li 290*9c5db199SXin Li def escape_user_sql(self, sql): 291*9c5db199SXin Li """Escapes % in sql.""" 292*9c5db199SXin Li return sql.replace('%', '%%') 293*9c5db199SXin Li 294*9c5db199SXin Li 295*9c5db199SXin Li def _custom_select_query(self, query_set, selects): 296*9c5db199SXin Li """Execute a custom select query. 297*9c5db199SXin Li 298*9c5db199SXin Li @param query_set: query set as returned by query_objects. 299*9c5db199SXin Li @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id. 300*9c5db199SXin Li 301*9c5db199SXin Li @returns: Result of the query as returned by cursor.fetchall(). 302*9c5db199SXin Li """ 303*9c5db199SXin Li compiler = query_set.query.get_compiler(using=query_set.db) 304*9c5db199SXin Li sql, params = compiler.as_sql() 305*9c5db199SXin Li from_ = sql[sql.find(' FROM'):] 306*9c5db199SXin Li 307*9c5db199SXin Li if query_set.query.distinct: 308*9c5db199SXin Li distinct = 'DISTINCT ' 309*9c5db199SXin Li else: 310*9c5db199SXin Li distinct = '' 311*9c5db199SXin Li 312*9c5db199SXin Li sql_query = ('SELECT ' + distinct + ','.join(selects) + from_) 313*9c5db199SXin Li # Chose the connection that's responsible for this type of object 314*9c5db199SXin Li cursor = connections[query_set.db].cursor() 315*9c5db199SXin Li cursor.execute(sql_query, params) 316*9c5db199SXin Li return cursor.fetchall() 317*9c5db199SXin Li 318*9c5db199SXin Li 319*9c5db199SXin Li def _is_relation_to(self, field, model_class): 320*9c5db199SXin Li return field.rel and field.rel.to is model_class 321*9c5db199SXin Li 322*9c5db199SXin Li 323*9c5db199SXin Li MANY_TO_ONE = object() 324*9c5db199SXin Li M2M_ON_RELATED_MODEL = object() 325*9c5db199SXin Li M2M_ON_THIS_MODEL = object() 326*9c5db199SXin Li 327*9c5db199SXin Li def determine_relationship(self, related_model): 328*9c5db199SXin Li """ 329*9c5db199SXin Li Determine the relationship between this model and related_model. 330*9c5db199SXin Li 331*9c5db199SXin Li related_model must have some sort of many-valued relationship to this 332*9c5db199SXin Li manager's model. 333*9c5db199SXin Li @returns (relationship_type, field), where relationship_type is one of 334*9c5db199SXin Li MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field 335*9c5db199SXin Li is the Django field object for the relationship. 336*9c5db199SXin Li """ 337*9c5db199SXin Li # look for a foreign key field on related_model relating to this model 338*9c5db199SXin Li for field in related_model._meta.fields: 339*9c5db199SXin Li if self._is_relation_to(field, self.model): 340*9c5db199SXin Li return self.MANY_TO_ONE, field 341*9c5db199SXin Li 342*9c5db199SXin Li # look for an M2M field on related_model relating to this model 343*9c5db199SXin Li for field in related_model._meta.many_to_many: 344*9c5db199SXin Li if self._is_relation_to(field, self.model): 345*9c5db199SXin Li return self.M2M_ON_RELATED_MODEL, field 346*9c5db199SXin Li 347*9c5db199SXin Li # maybe this model has the many-to-many field 348*9c5db199SXin Li for field in self.model._meta.many_to_many: 349*9c5db199SXin Li if self._is_relation_to(field, related_model): 350*9c5db199SXin Li return self.M2M_ON_THIS_MODEL, field 351*9c5db199SXin Li 352*9c5db199SXin Li raise ValueError('%s has no relation to %s' % 353*9c5db199SXin Li (related_model, self.model)) 354*9c5db199SXin Li 355*9c5db199SXin Li 356*9c5db199SXin Li def _get_pivot_iterator(self, base_objects_by_id, related_model): 357*9c5db199SXin Li """ 358*9c5db199SXin Li Determine the relationship between this model and related_model, and 359*9c5db199SXin Li return a pivot iterator. 360*9c5db199SXin Li @param base_objects_by_id: dict of instances of this model indexed by 361*9c5db199SXin Li their IDs 362*9c5db199SXin Li @returns a pivot iterator, which yields a tuple (base_object, 363*9c5db199SXin Li related_object) for each relationship between a base object and a 364*9c5db199SXin Li related object. all base_object instances come from base_objects_by_id. 365*9c5db199SXin Li Note -- this depends on Django model internals. 366*9c5db199SXin Li """ 367*9c5db199SXin Li relationship_type, field = self.determine_relationship(related_model) 368*9c5db199SXin Li if relationship_type == self.MANY_TO_ONE: 369*9c5db199SXin Li return self._many_to_one_pivot(base_objects_by_id, 370*9c5db199SXin Li related_model, field) 371*9c5db199SXin Li elif relationship_type == self.M2M_ON_RELATED_MODEL: 372*9c5db199SXin Li return self._many_to_many_pivot( 373*9c5db199SXin Li base_objects_by_id, related_model, field.m2m_db_table(), 374*9c5db199SXin Li field.m2m_reverse_name(), field.m2m_column_name()) 375*9c5db199SXin Li else: 376*9c5db199SXin Li assert relationship_type == self.M2M_ON_THIS_MODEL 377*9c5db199SXin Li return self._many_to_many_pivot( 378*9c5db199SXin Li base_objects_by_id, related_model, field.m2m_db_table(), 379*9c5db199SXin Li field.m2m_column_name(), field.m2m_reverse_name()) 380*9c5db199SXin Li 381*9c5db199SXin Li 382*9c5db199SXin Li def _many_to_one_pivot(self, base_objects_by_id, related_model, 383*9c5db199SXin Li foreign_key_field): 384*9c5db199SXin Li """ 385*9c5db199SXin Li @returns a pivot iterator - see _get_pivot_iterator() 386*9c5db199SXin Li """ 387*9c5db199SXin Li filter_data = {foreign_key_field.name + '__pk__in': 388*9c5db199SXin Li base_objects_by_id.keys()} 389*9c5db199SXin Li for related_object in related_model.objects.filter(**filter_data): 390*9c5db199SXin Li # lookup base object in the dict, rather than grabbing it from the 391*9c5db199SXin Li # related object. we need to return instances from the dict, not 392*9c5db199SXin Li # fresh instances of the same models (and grabbing model instances 393*9c5db199SXin Li # from the related models incurs a DB query each time). 394*9c5db199SXin Li base_object_id = getattr(related_object, foreign_key_field.attname) 395*9c5db199SXin Li base_object = base_objects_by_id[base_object_id] 396*9c5db199SXin Li yield base_object, related_object 397*9c5db199SXin Li 398*9c5db199SXin Li 399*9c5db199SXin Li def _query_pivot_table(self, base_objects_by_id, pivot_table, 400*9c5db199SXin Li pivot_from_field, pivot_to_field, related_model): 401*9c5db199SXin Li """ 402*9c5db199SXin Li @param id_list list of IDs of self.model objects to include 403*9c5db199SXin Li @param pivot_table the name of the pivot table 404*9c5db199SXin Li @param pivot_from_field a field name on pivot_table referencing 405*9c5db199SXin Li self.model 406*9c5db199SXin Li @param pivot_to_field a field name on pivot_table referencing the 407*9c5db199SXin Li related model. 408*9c5db199SXin Li @param related_model the related model 409*9c5db199SXin Li 410*9c5db199SXin Li @returns pivot list of IDs (base_id, related_id) 411*9c5db199SXin Li """ 412*9c5db199SXin Li query = """ 413*9c5db199SXin Li SELECT %(from_field)s, %(to_field)s 414*9c5db199SXin Li FROM %(table)s 415*9c5db199SXin Li WHERE %(from_field)s IN (%(id_list)s) 416*9c5db199SXin Li """ % dict(from_field=pivot_from_field, 417*9c5db199SXin Li to_field=pivot_to_field, 418*9c5db199SXin Li table=pivot_table, 419*9c5db199SXin Li id_list=','.join( 420*9c5db199SXin Li str(id_) 421*9c5db199SXin Li for id_ in six.iterkeys(base_objects_by_id))) 422*9c5db199SXin Li 423*9c5db199SXin Li # Chose the connection that's responsible for this type of object 424*9c5db199SXin Li # The databases for related_model and the current model will always 425*9c5db199SXin Li # be the same, related_model is just easier to obtain here because 426*9c5db199SXin Li # self is only a ExtendedManager, not the object. 427*9c5db199SXin Li cursor = connections[related_model.objects.db].cursor() 428*9c5db199SXin Li cursor.execute(query) 429*9c5db199SXin Li return cursor.fetchall() 430*9c5db199SXin Li 431*9c5db199SXin Li 432*9c5db199SXin Li def _many_to_many_pivot(self, base_objects_by_id, related_model, 433*9c5db199SXin Li pivot_table, pivot_from_field, pivot_to_field): 434*9c5db199SXin Li """ 435*9c5db199SXin Li @param pivot_table: see _query_pivot_table 436*9c5db199SXin Li @param pivot_from_field: see _query_pivot_table 437*9c5db199SXin Li @param pivot_to_field: see _query_pivot_table 438*9c5db199SXin Li @returns a pivot iterator - see _get_pivot_iterator() 439*9c5db199SXin Li """ 440*9c5db199SXin Li id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table, 441*9c5db199SXin Li pivot_from_field, pivot_to_field, 442*9c5db199SXin Li related_model) 443*9c5db199SXin Li 444*9c5db199SXin Li all_related_ids = list(set(related_id for base_id, related_id 445*9c5db199SXin Li in id_pivot)) 446*9c5db199SXin Li related_objects_by_id = related_model.objects.in_bulk(all_related_ids) 447*9c5db199SXin Li 448*9c5db199SXin Li for base_id, related_id in id_pivot: 449*9c5db199SXin Li yield base_objects_by_id[base_id], related_objects_by_id[related_id] 450*9c5db199SXin Li 451*9c5db199SXin Li 452*9c5db199SXin Li def populate_relationships(self, base_objects, related_model, 453*9c5db199SXin Li related_list_name): 454*9c5db199SXin Li """ 455*9c5db199SXin Li For each instance of this model in base_objects, add a field named 456*9c5db199SXin Li related_list_name listing all the related objects of type related_model. 457*9c5db199SXin Li related_model must be in a many-to-one or many-to-many relationship with 458*9c5db199SXin Li this model. 459*9c5db199SXin Li @param base_objects - list of instances of this model 460*9c5db199SXin Li @param related_model - model class related to this model 461*9c5db199SXin Li @param related_list_name - attribute name in which to store the related 462*9c5db199SXin Li object list. 463*9c5db199SXin Li """ 464*9c5db199SXin Li if not base_objects: 465*9c5db199SXin Li # if we don't bail early, we'll get a SQL error later 466*9c5db199SXin Li return 467*9c5db199SXin Li 468*9c5db199SXin Li # The default maximum value of a host parameter number in SQLite is 999. 469*9c5db199SXin Li # Exceed this will get a DatabaseError later. 470*9c5db199SXin Li batch_size = 900 471*9c5db199SXin Li for i in range(0, len(base_objects), batch_size): 472*9c5db199SXin Li base_objects_batch = base_objects[i:i + batch_size] 473*9c5db199SXin Li base_objects_by_id = dict((base_object._get_pk_val(), base_object) 474*9c5db199SXin Li for base_object in base_objects_batch) 475*9c5db199SXin Li pivot_iterator = self._get_pivot_iterator(base_objects_by_id, 476*9c5db199SXin Li related_model) 477*9c5db199SXin Li 478*9c5db199SXin Li for base_object in base_objects_batch: 479*9c5db199SXin Li setattr(base_object, related_list_name, []) 480*9c5db199SXin Li 481*9c5db199SXin Li for base_object, related_object in pivot_iterator: 482*9c5db199SXin Li getattr(base_object, related_list_name).append(related_object) 483*9c5db199SXin Li 484*9c5db199SXin Li 485*9c5db199SXin Liclass ModelWithInvalidQuerySet(dbmodels.query.QuerySet): 486*9c5db199SXin Li """ 487*9c5db199SXin Li QuerySet that handles delete() properly for models with an "invalid" bit 488*9c5db199SXin Li """ 489*9c5db199SXin Li def delete(self): 490*9c5db199SXin Li """Deletes the QuerySet.""" 491*9c5db199SXin Li for model in self: 492*9c5db199SXin Li model.delete() 493*9c5db199SXin Li 494*9c5db199SXin Li 495*9c5db199SXin Liclass ModelWithInvalidManager(ExtendedManager): 496*9c5db199SXin Li """ 497*9c5db199SXin Li Manager for objects with an "invalid" bit 498*9c5db199SXin Li """ 499*9c5db199SXin Li def get_query_set(self): 500*9c5db199SXin Li return ModelWithInvalidQuerySet(self.model) 501*9c5db199SXin Li 502*9c5db199SXin Li 503*9c5db199SXin Liclass ValidObjectsManager(ModelWithInvalidManager): 504*9c5db199SXin Li """ 505*9c5db199SXin Li Manager returning only objects with invalid=False. 506*9c5db199SXin Li """ 507*9c5db199SXin Li def get_query_set(self): 508*9c5db199SXin Li queryset = super(ValidObjectsManager, self).get_query_set() 509*9c5db199SXin Li return queryset.filter(invalid=False) 510*9c5db199SXin Li 511*9c5db199SXin Li 512*9c5db199SXin Liclass ModelExtensions(rdb_model_extensions.ModelValidators): 513*9c5db199SXin Li """\ 514*9c5db199SXin Li Mixin with convenience functions for models, built on top of 515*9c5db199SXin Li the model validators in rdb_model_extensions. 516*9c5db199SXin Li """ 517*9c5db199SXin Li # TODO: at least some of these functions really belong in a custom 518*9c5db199SXin Li # Manager class 519*9c5db199SXin Li 520*9c5db199SXin Li 521*9c5db199SXin Li SERIALIZATION_LINKS_TO_FOLLOW = set() 522*9c5db199SXin Li """ 523*9c5db199SXin Li To be able to send jobs and hosts to shards, it's necessary to find their 524*9c5db199SXin Li dependencies. 525*9c5db199SXin Li The most generic approach for this would be to traverse all relationships 526*9c5db199SXin Li to other objects recursively. This would list all objects that are related 527*9c5db199SXin Li in any way. 528*9c5db199SXin Li But this approach finds too many objects: If a host should be transferred, 529*9c5db199SXin Li all it's relationships would be traversed. This would find an acl group. 530*9c5db199SXin Li If then the acl group's relationships are traversed, the relationship 531*9c5db199SXin Li would be followed backwards and many other hosts would be found. 532*9c5db199SXin Li 533*9c5db199SXin Li This mapping tells that algorithm which relations to follow explicitly. 534*9c5db199SXin Li """ 535*9c5db199SXin Li 536*9c5db199SXin Li 537*9c5db199SXin Li SERIALIZATION_LINKS_TO_KEEP = set() 538*9c5db199SXin Li """This set stores foreign keys which we don't want to follow, but 539*9c5db199SXin Li still want to include in the serialized dictionary. For 540*9c5db199SXin Li example, we follow the relationship `Host.hostattribute_set`, 541*9c5db199SXin Li but we do not want to follow `HostAttributes.host_id` back to 542*9c5db199SXin Li to Host, which would otherwise lead to a circle. However, we still 543*9c5db199SXin Li like to serialize HostAttribute.`host_id`.""" 544*9c5db199SXin Li 545*9c5db199SXin Li SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set() 546*9c5db199SXin Li """ 547*9c5db199SXin Li On deserializion, if the object to persist already exists, local fields 548*9c5db199SXin Li will only be updated, if their name is in this set. 549*9c5db199SXin Li """ 550*9c5db199SXin Li 551*9c5db199SXin Li 552*9c5db199SXin Li @classmethod 553*9c5db199SXin Li def convert_human_readable_values(cls, data, to_human_readable=False): 554*9c5db199SXin Li """\ 555*9c5db199SXin Li Performs conversions on user-supplied field data, to make it 556*9c5db199SXin Li easier for users to pass human-readable data. 557*9c5db199SXin Li 558*9c5db199SXin Li For all fields that have choice sets, convert their values 559*9c5db199SXin Li from human-readable strings to enum values, if necessary. This 560*9c5db199SXin Li allows users to pass strings instead of the corresponding 561*9c5db199SXin Li integer values. 562*9c5db199SXin Li 563*9c5db199SXin Li For all foreign key fields, call smart_get with the supplied 564*9c5db199SXin Li data. This allows the user to pass either an ID value or 565*9c5db199SXin Li the name of the object as a string. 566*9c5db199SXin Li 567*9c5db199SXin Li If to_human_readable=True, perform the inverse - i.e. convert 568*9c5db199SXin Li numeric values to human readable values. 569*9c5db199SXin Li 570*9c5db199SXin Li This method modifies data in-place. 571*9c5db199SXin Li """ 572*9c5db199SXin Li field_dict = cls.get_field_dict() 573*9c5db199SXin Li for field_name in data: 574*9c5db199SXin Li if field_name not in field_dict or data[field_name] is None: 575*9c5db199SXin Li continue 576*9c5db199SXin Li field_obj = field_dict[field_name] 577*9c5db199SXin Li # convert enum values 578*9c5db199SXin Li if field_obj.choices: 579*9c5db199SXin Li for choice_data in field_obj.choices: 580*9c5db199SXin Li # choice_data is (value, name) 581*9c5db199SXin Li if to_human_readable: 582*9c5db199SXin Li from_val, to_val = choice_data 583*9c5db199SXin Li else: 584*9c5db199SXin Li to_val, from_val = choice_data 585*9c5db199SXin Li if from_val == data[field_name]: 586*9c5db199SXin Li data[field_name] = to_val 587*9c5db199SXin Li break 588*9c5db199SXin Li # convert foreign key values 589*9c5db199SXin Li elif field_obj.rel: 590*9c5db199SXin Li dest_obj = field_obj.rel.to.smart_get(data[field_name], 591*9c5db199SXin Li valid_only=False) 592*9c5db199SXin Li if to_human_readable: 593*9c5db199SXin Li # parameterized_jobs do not have a name_field 594*9c5db199SXin Li if (field_name != 'parameterized_job' and 595*9c5db199SXin Li dest_obj.name_field is not None): 596*9c5db199SXin Li data[field_name] = getattr(dest_obj, 597*9c5db199SXin Li dest_obj.name_field) 598*9c5db199SXin Li else: 599*9c5db199SXin Li data[field_name] = dest_obj 600*9c5db199SXin Li 601*9c5db199SXin Li 602*9c5db199SXin Li 603*9c5db199SXin Li 604*9c5db199SXin Li def _validate_unique(self): 605*9c5db199SXin Li """\ 606*9c5db199SXin Li Validate that unique fields are unique. Django manipulators do 607*9c5db199SXin Li this too, but they're a huge pain to use manually. Trust me. 608*9c5db199SXin Li """ 609*9c5db199SXin Li errors = {} 610*9c5db199SXin Li cls = type(self) 611*9c5db199SXin Li field_dict = self.get_field_dict() 612*9c5db199SXin Li manager = cls.get_valid_manager() 613*9c5db199SXin Li for field_name, field_obj in six.iteritems(field_dict): 614*9c5db199SXin Li if not field_obj.unique: 615*9c5db199SXin Li continue 616*9c5db199SXin Li 617*9c5db199SXin Li value = getattr(self, field_name) 618*9c5db199SXin Li if value is None and field_obj.auto_created: 619*9c5db199SXin Li # don't bother checking autoincrement fields about to be 620*9c5db199SXin Li # generated 621*9c5db199SXin Li continue 622*9c5db199SXin Li 623*9c5db199SXin Li existing_objs = manager.filter(**{field_name : value}) 624*9c5db199SXin Li num_existing = existing_objs.count() 625*9c5db199SXin Li 626*9c5db199SXin Li if num_existing == 0: 627*9c5db199SXin Li continue 628*9c5db199SXin Li if num_existing == 1 and existing_objs[0].id == self.id: 629*9c5db199SXin Li continue 630*9c5db199SXin Li errors[field_name] = ( 631*9c5db199SXin Li 'This value must be unique (%s)' % (value)) 632*9c5db199SXin Li return errors 633*9c5db199SXin Li 634*9c5db199SXin Li 635*9c5db199SXin Li def _validate(self): 636*9c5db199SXin Li """ 637*9c5db199SXin Li First coerces all fields on this instance to their proper Python types. 638*9c5db199SXin Li Then runs validation on every field. Returns a dictionary of 639*9c5db199SXin Li field_name -> error_list. 640*9c5db199SXin Li 641*9c5db199SXin Li Based on validate() from django.db.models.Model in Django 0.96, which 642*9c5db199SXin Li was removed in Django 1.0. It should reappear in a later version. See: 643*9c5db199SXin Li http://code.djangoproject.com/ticket/6845 644*9c5db199SXin Li """ 645*9c5db199SXin Li error_dict = {} 646*9c5db199SXin Li for f in self._meta.fields: 647*9c5db199SXin Li try: 648*9c5db199SXin Li python_value = f.to_python( 649*9c5db199SXin Li getattr(self, f.attname, f.get_default())) 650*9c5db199SXin Li except django.core.exceptions.ValidationError as e: 651*9c5db199SXin Li error_dict[f.name] = str(e) 652*9c5db199SXin Li continue 653*9c5db199SXin Li 654*9c5db199SXin Li if not f.blank and not python_value: 655*9c5db199SXin Li error_dict[f.name] = 'This field is required.' 656*9c5db199SXin Li continue 657*9c5db199SXin Li 658*9c5db199SXin Li setattr(self, f.attname, python_value) 659*9c5db199SXin Li 660*9c5db199SXin Li return error_dict 661*9c5db199SXin Li 662*9c5db199SXin Li 663*9c5db199SXin Li def do_validate(self): 664*9c5db199SXin Li """Validate fields.""" 665*9c5db199SXin Li errors = self._validate() 666*9c5db199SXin Li unique_errors = self._validate_unique() 667*9c5db199SXin Li for field_name, error in six.iteritems(unique_errors): 668*9c5db199SXin Li errors.setdefault(field_name, error) 669*9c5db199SXin Li if errors: 670*9c5db199SXin Li raise ValidationError(errors) 671*9c5db199SXin Li 672*9c5db199SXin Li 673*9c5db199SXin Li # actually (externally) useful methods follow 674*9c5db199SXin Li 675*9c5db199SXin Li @classmethod 676*9c5db199SXin Li def add_object(cls, data={}, **kwargs): 677*9c5db199SXin Li """\ 678*9c5db199SXin Li Returns a new object created with the given data (a dictionary 679*9c5db199SXin Li mapping field names to values). Merges any extra keyword args 680*9c5db199SXin Li into data. 681*9c5db199SXin Li """ 682*9c5db199SXin Li data = dict(data) 683*9c5db199SXin Li data.update(kwargs) 684*9c5db199SXin Li data = cls.prepare_data_args(data) 685*9c5db199SXin Li cls.convert_human_readable_values(data) 686*9c5db199SXin Li data = cls.provide_default_values(data) 687*9c5db199SXin Li 688*9c5db199SXin Li obj = cls(**data) 689*9c5db199SXin Li obj.do_validate() 690*9c5db199SXin Li obj.save() 691*9c5db199SXin Li return obj 692*9c5db199SXin Li 693*9c5db199SXin Li 694*9c5db199SXin Li def update_object(self, data={}, **kwargs): 695*9c5db199SXin Li """\ 696*9c5db199SXin Li Updates the object with the given data (a dictionary mapping 697*9c5db199SXin Li field names to values). Merges any extra keyword args into 698*9c5db199SXin Li data. 699*9c5db199SXin Li """ 700*9c5db199SXin Li data = dict(data) 701*9c5db199SXin Li data.update(kwargs) 702*9c5db199SXin Li data = self.prepare_data_args(data) 703*9c5db199SXin Li self.convert_human_readable_values(data) 704*9c5db199SXin Li for field_name, value in six.iteritems(data): 705*9c5db199SXin Li setattr(self, field_name, value) 706*9c5db199SXin Li self.do_validate() 707*9c5db199SXin Li self.save() 708*9c5db199SXin Li 709*9c5db199SXin Li 710*9c5db199SXin Li # see query_objects() 711*9c5db199SXin Li _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by', 712*9c5db199SXin Li 'extra_args', 'extra_where', 'no_distinct') 713*9c5db199SXin Li 714*9c5db199SXin Li 715*9c5db199SXin Li @classmethod 716*9c5db199SXin Li def _extract_special_params(cls, filter_data): 717*9c5db199SXin Li """ 718*9c5db199SXin Li @returns a tuple of dicts (special_params, regular_filters), where 719*9c5db199SXin Li special_params contains the parameters we handle specially and 720*9c5db199SXin Li regular_filters is the remaining data to be handled by Django. 721*9c5db199SXin Li """ 722*9c5db199SXin Li regular_filters = dict(filter_data) 723*9c5db199SXin Li special_params = {} 724*9c5db199SXin Li for key in cls._SPECIAL_FILTER_KEYS: 725*9c5db199SXin Li if key in regular_filters: 726*9c5db199SXin Li special_params[key] = regular_filters.pop(key) 727*9c5db199SXin Li return special_params, regular_filters 728*9c5db199SXin Li 729*9c5db199SXin Li 730*9c5db199SXin Li @classmethod 731*9c5db199SXin Li def apply_presentation(cls, query, filter_data): 732*9c5db199SXin Li """ 733*9c5db199SXin Li Apply presentation parameters -- sorting and paging -- to the given 734*9c5db199SXin Li query. 735*9c5db199SXin Li @returns new query with presentation applied 736*9c5db199SXin Li """ 737*9c5db199SXin Li special_params, _ = cls._extract_special_params(filter_data) 738*9c5db199SXin Li sort_by = special_params.get('sort_by', None) 739*9c5db199SXin Li if sort_by: 740*9c5db199SXin Li assert isinstance(sort_by, list) or isinstance(sort_by, tuple) 741*9c5db199SXin Li query = query.extra(order_by=sort_by) 742*9c5db199SXin Li 743*9c5db199SXin Li query_start = special_params.get('query_start', None) 744*9c5db199SXin Li query_limit = special_params.get('query_limit', None) 745*9c5db199SXin Li if query_start is not None: 746*9c5db199SXin Li if query_limit is None: 747*9c5db199SXin Li raise ValueError('Cannot pass query_start without query_limit') 748*9c5db199SXin Li # query_limit is passed as a page size 749*9c5db199SXin Li query_limit += query_start 750*9c5db199SXin Li return query[query_start:query_limit] 751*9c5db199SXin Li 752*9c5db199SXin Li 753*9c5db199SXin Li @classmethod 754*9c5db199SXin Li def query_objects(cls, filter_data, valid_only=True, initial_query=None, 755*9c5db199SXin Li apply_presentation=True): 756*9c5db199SXin Li """\ 757*9c5db199SXin Li Returns a QuerySet object for querying the given model_class 758*9c5db199SXin Li with the given filter_data. Optional special arguments in 759*9c5db199SXin Li filter_data include: 760*9c5db199SXin Li -query_start: index of first return to return 761*9c5db199SXin Li -query_limit: maximum number of results to return 762*9c5db199SXin Li -sort_by: list of fields to sort on. prefixing a '-' onto a 763*9c5db199SXin Li field name changes the sort to descending order. 764*9c5db199SXin Li -extra_args: keyword args to pass to query.extra() (see Django 765*9c5db199SXin Li DB layer documentation) 766*9c5db199SXin Li -extra_where: extra WHERE clause to append 767*9c5db199SXin Li -no_distinct: if True, a DISTINCT will not be added to the SELECT 768*9c5db199SXin Li """ 769*9c5db199SXin Li special_params, regular_filters = cls._extract_special_params( 770*9c5db199SXin Li filter_data) 771*9c5db199SXin Li 772*9c5db199SXin Li if initial_query is None: 773*9c5db199SXin Li if valid_only: 774*9c5db199SXin Li initial_query = cls.get_valid_manager() 775*9c5db199SXin Li else: 776*9c5db199SXin Li initial_query = cls.objects 777*9c5db199SXin Li 778*9c5db199SXin Li query = initial_query.filter(**regular_filters) 779*9c5db199SXin Li 780*9c5db199SXin Li use_distinct = not special_params.get('no_distinct', False) 781*9c5db199SXin Li if use_distinct: 782*9c5db199SXin Li query = query.distinct() 783*9c5db199SXin Li 784*9c5db199SXin Li extra_args = special_params.get('extra_args', {}) 785*9c5db199SXin Li extra_where = special_params.get('extra_where', None) 786*9c5db199SXin Li if extra_where: 787*9c5db199SXin Li # escape %'s 788*9c5db199SXin Li extra_where = cls.objects.escape_user_sql(extra_where) 789*9c5db199SXin Li extra_args.setdefault('where', []).append(extra_where) 790*9c5db199SXin Li if extra_args: 791*9c5db199SXin Li query = query.extra(**extra_args) 792*9c5db199SXin Li # TODO: Use readonly connection for these queries. 793*9c5db199SXin Li # This has been disabled, because it's not used anyway, as the 794*9c5db199SXin Li # configured readonly user is the same as the real user anyway. 795*9c5db199SXin Li 796*9c5db199SXin Li if apply_presentation: 797*9c5db199SXin Li query = cls.apply_presentation(query, filter_data) 798*9c5db199SXin Li 799*9c5db199SXin Li return query 800*9c5db199SXin Li 801*9c5db199SXin Li 802*9c5db199SXin Li @classmethod 803*9c5db199SXin Li def query_count(cls, filter_data, initial_query=None): 804*9c5db199SXin Li """\ 805*9c5db199SXin Li Like query_objects, but retreive only the count of results. 806*9c5db199SXin Li """ 807*9c5db199SXin Li filter_data.pop('query_start', None) 808*9c5db199SXin Li filter_data.pop('query_limit', None) 809*9c5db199SXin Li query = cls.query_objects(filter_data, initial_query=initial_query) 810*9c5db199SXin Li return query.count() 811*9c5db199SXin Li 812*9c5db199SXin Li 813*9c5db199SXin Li @classmethod 814*9c5db199SXin Li def clean_object_dicts(cls, field_dicts): 815*9c5db199SXin Li """\ 816*9c5db199SXin Li Take a list of dicts corresponding to object (as returned by 817*9c5db199SXin Li query.values()) and clean the data to be more suitable for 818*9c5db199SXin Li returning to the user. 819*9c5db199SXin Li """ 820*9c5db199SXin Li for field_dict in field_dicts: 821*9c5db199SXin Li cls.clean_foreign_keys(field_dict) 822*9c5db199SXin Li cls._convert_booleans(field_dict) 823*9c5db199SXin Li cls.convert_human_readable_values(field_dict, 824*9c5db199SXin Li to_human_readable=True) 825*9c5db199SXin Li 826*9c5db199SXin Li 827*9c5db199SXin Li @classmethod 828*9c5db199SXin Li def list_objects(cls, filter_data, initial_query=None): 829*9c5db199SXin Li """\ 830*9c5db199SXin Li Like query_objects, but return a list of dictionaries. 831*9c5db199SXin Li """ 832*9c5db199SXin Li query = cls.query_objects(filter_data, initial_query=initial_query) 833*9c5db199SXin Li extra_fields = query.query.extra_select.keys() 834*9c5db199SXin Li field_dicts = [model_object.get_object_dict(extra_fields=extra_fields) 835*9c5db199SXin Li for model_object in query] 836*9c5db199SXin Li return field_dicts 837*9c5db199SXin Li 838*9c5db199SXin Li 839*9c5db199SXin Li @classmethod 840*9c5db199SXin Li def smart_get(cls, id_or_name, valid_only=True): 841*9c5db199SXin Li """\ 842*9c5db199SXin Li smart_get(integer) -> get object by ID 843*9c5db199SXin Li smart_get(string) -> get object by name_field 844*9c5db199SXin Li """ 845*9c5db199SXin Li if valid_only: 846*9c5db199SXin Li manager = cls.get_valid_manager() 847*9c5db199SXin Li else: 848*9c5db199SXin Li manager = cls.objects 849*9c5db199SXin Li 850*9c5db199SXin Li if isinstance(id_or_name, six.integer_types): 851*9c5db199SXin Li return manager.get(pk=id_or_name) 852*9c5db199SXin Li if isinstance(id_or_name, six.string_types) and hasattr( 853*9c5db199SXin Li cls, 'name_field'): 854*9c5db199SXin Li return manager.get(**{cls.name_field : id_or_name}) 855*9c5db199SXin Li raise ValueError( 856*9c5db199SXin Li 'Invalid positional argument: %s (%s)' % (id_or_name, 857*9c5db199SXin Li type(id_or_name))) 858*9c5db199SXin Li 859*9c5db199SXin Li 860*9c5db199SXin Li @classmethod 861*9c5db199SXin Li def smart_get_bulk(cls, id_or_name_list): 862*9c5db199SXin Li """Like smart_get, but for a list of ids or names""" 863*9c5db199SXin Li invalid_inputs = [] 864*9c5db199SXin Li result_objects = [] 865*9c5db199SXin Li for id_or_name in id_or_name_list: 866*9c5db199SXin Li try: 867*9c5db199SXin Li result_objects.append(cls.smart_get(id_or_name)) 868*9c5db199SXin Li except cls.DoesNotExist: 869*9c5db199SXin Li invalid_inputs.append(id_or_name) 870*9c5db199SXin Li if invalid_inputs: 871*9c5db199SXin Li raise cls.DoesNotExist('The following %ss do not exist: %s' 872*9c5db199SXin Li % (cls.__name__.lower(), 873*9c5db199SXin Li ', '.join(invalid_inputs))) 874*9c5db199SXin Li return result_objects 875*9c5db199SXin Li 876*9c5db199SXin Li 877*9c5db199SXin Li def get_object_dict(self, extra_fields=None): 878*9c5db199SXin Li """\ 879*9c5db199SXin Li Return a dictionary mapping fields to this object's values. @param 880*9c5db199SXin Li extra_fields: list of extra attribute names to include, in addition to 881*9c5db199SXin Li the fields defined on this object. 882*9c5db199SXin Li """ 883*9c5db199SXin Li fields = self.get_field_dict().keys() 884*9c5db199SXin Li if extra_fields: 885*9c5db199SXin Li fields += extra_fields 886*9c5db199SXin Li object_dict = dict((field_name, getattr(self, field_name)) 887*9c5db199SXin Li for field_name in fields) 888*9c5db199SXin Li self.clean_object_dicts([object_dict]) 889*9c5db199SXin Li self._postprocess_object_dict(object_dict) 890*9c5db199SXin Li return object_dict 891*9c5db199SXin Li 892*9c5db199SXin Li 893*9c5db199SXin Li def _postprocess_object_dict(self, object_dict): 894*9c5db199SXin Li """For subclasses to override.""" 895*9c5db199SXin Li pass 896*9c5db199SXin Li 897*9c5db199SXin Li 898*9c5db199SXin Li @classmethod 899*9c5db199SXin Li def get_valid_manager(cls): 900*9c5db199SXin Li return cls.objects 901*9c5db199SXin Li 902*9c5db199SXin Li 903*9c5db199SXin Li def _record_attributes(self, attributes): 904*9c5db199SXin Li """ 905*9c5db199SXin Li See on_attribute_changed. 906*9c5db199SXin Li """ 907*9c5db199SXin Li assert not isinstance(attributes, six.string_types) 908*9c5db199SXin Li self._recorded_attributes = dict((attribute, getattr(self, attribute)) 909*9c5db199SXin Li for attribute in attributes) 910*9c5db199SXin Li 911*9c5db199SXin Li 912*9c5db199SXin Li def _check_for_updated_attributes(self): 913*9c5db199SXin Li """ 914*9c5db199SXin Li See on_attribute_changed. 915*9c5db199SXin Li """ 916*9c5db199SXin Li for attribute, original_value in six.iteritems( 917*9c5db199SXin Li self._recorded_attributes): 918*9c5db199SXin Li new_value = getattr(self, attribute) 919*9c5db199SXin Li if original_value != new_value: 920*9c5db199SXin Li self.on_attribute_changed(attribute, original_value) 921*9c5db199SXin Li self._record_attributes(self._recorded_attributes.keys()) 922*9c5db199SXin Li 923*9c5db199SXin Li 924*9c5db199SXin Li def on_attribute_changed(self, attribute, old_value): 925*9c5db199SXin Li """ 926*9c5db199SXin Li Called whenever an attribute is updated. To be overridden. 927*9c5db199SXin Li 928*9c5db199SXin Li To use this method, you must: 929*9c5db199SXin Li * call _record_attributes() from __init__() (after making the super 930*9c5db199SXin Li call) with a list of attributes for which you want to be notified upon 931*9c5db199SXin Li change. 932*9c5db199SXin Li * call _check_for_updated_attributes() from save(). 933*9c5db199SXin Li """ 934*9c5db199SXin Li pass 935*9c5db199SXin Li 936*9c5db199SXin Li 937*9c5db199SXin Li def serialize(self, include_dependencies=True): 938*9c5db199SXin Li """Serializes the object with dependencies. 939*9c5db199SXin Li 940*9c5db199SXin Li The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies 941*9c5db199SXin Li this function will serialize with the object. 942*9c5db199SXin Li 943*9c5db199SXin Li @param include_dependencies: Whether or not to follow relations to 944*9c5db199SXin Li objects this object depends on. 945*9c5db199SXin Li This parameter is used when uploading 946*9c5db199SXin Li jobs from a shard to the main, as the 947*9c5db199SXin Li main already has all the dependent 948*9c5db199SXin Li objects. 949*9c5db199SXin Li 950*9c5db199SXin Li @returns: Dictionary representation of the object. 951*9c5db199SXin Li """ 952*9c5db199SXin Li serialized = {} 953*9c5db199SXin Li for field in self._meta.concrete_model._meta.local_fields: 954*9c5db199SXin Li if field.rel is None: 955*9c5db199SXin Li serialized[field.name] = field._get_val_from_obj(self) 956*9c5db199SXin Li elif field.name in self.SERIALIZATION_LINKS_TO_KEEP: 957*9c5db199SXin Li # attname will contain "_id" suffix for foreign keys, 958*9c5db199SXin Li # e.g. HostAttribute.host will be serialized as 'host_id'. 959*9c5db199SXin Li # Use it for easy deserialization. 960*9c5db199SXin Li serialized[field.attname] = field._get_val_from_obj(self) 961*9c5db199SXin Li 962*9c5db199SXin Li if include_dependencies: 963*9c5db199SXin Li for link in self.SERIALIZATION_LINKS_TO_FOLLOW: 964*9c5db199SXin Li serialized[link] = self._serialize_relation(link) 965*9c5db199SXin Li 966*9c5db199SXin Li return serialized 967*9c5db199SXin Li 968*9c5db199SXin Li 969*9c5db199SXin Li def _serialize_relation(self, link): 970*9c5db199SXin Li """Serializes dependent objects given the name of the relation. 971*9c5db199SXin Li 972*9c5db199SXin Li @param link: Name of the relation to take objects from. 973*9c5db199SXin Li 974*9c5db199SXin Li @returns For To-Many relationships a list of the serialized related 975*9c5db199SXin Li objects, for To-One relationships the serialized related object. 976*9c5db199SXin Li """ 977*9c5db199SXin Li try: 978*9c5db199SXin Li attr = getattr(self, link) 979*9c5db199SXin Li except AttributeError: 980*9c5db199SXin Li # One-To-One relationships that point to None may raise this 981*9c5db199SXin Li return None 982*9c5db199SXin Li 983*9c5db199SXin Li if attr is None: 984*9c5db199SXin Li return None 985*9c5db199SXin Li if hasattr(attr, 'all'): 986*9c5db199SXin Li return [obj.serialize() for obj in attr.all()] 987*9c5db199SXin Li return attr.serialize() 988*9c5db199SXin Li 989*9c5db199SXin Li 990*9c5db199SXin Li @classmethod 991*9c5db199SXin Li def _split_local_from_foreign_values(cls, data): 992*9c5db199SXin Li """This splits local from foreign values in a serialized object. 993*9c5db199SXin Li 994*9c5db199SXin Li @param data: The serialized object. 995*9c5db199SXin Li 996*9c5db199SXin Li @returns A tuple of two lists, both containing tuples in the form 997*9c5db199SXin Li (link_name, link_value). The first list contains all links 998*9c5db199SXin Li for local fields, the second one contains those for foreign 999*9c5db199SXin Li fields/objects. 1000*9c5db199SXin Li """ 1001*9c5db199SXin Li links_to_local_values, links_to_related_values = [], [] 1002*9c5db199SXin Li for link, value in six.iteritems(data): 1003*9c5db199SXin Li if link in cls.SERIALIZATION_LINKS_TO_FOLLOW: 1004*9c5db199SXin Li # It's a foreign key 1005*9c5db199SXin Li links_to_related_values.append((link, value)) 1006*9c5db199SXin Li else: 1007*9c5db199SXin Li # It's a local attribute or a foreign key 1008*9c5db199SXin Li # we don't want to follow. 1009*9c5db199SXin Li links_to_local_values.append((link, value)) 1010*9c5db199SXin Li return links_to_local_values, links_to_related_values 1011*9c5db199SXin Li 1012*9c5db199SXin Li 1013*9c5db199SXin Li @classmethod 1014*9c5db199SXin Li def _filter_update_allowed_fields(cls, data): 1015*9c5db199SXin Li """Filters data and returns only files that updates are allowed on. 1016*9c5db199SXin Li 1017*9c5db199SXin Li This is i.e. needed for syncing aborted bits from the main to shards. 1018*9c5db199SXin Li 1019*9c5db199SXin Li Local links are only allowed to be updated, if they are in 1020*9c5db199SXin Li SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1021*9c5db199SXin Li Overwriting existing values is allowed in order to be able to sync i.e. 1022*9c5db199SXin Li the aborted bit from the main to a shard. 1023*9c5db199SXin Li 1024*9c5db199SXin Li The allowlisting mechanism is in place to prevent overwriting local 1025*9c5db199SXin Li status: If all fields were overwritten, jobs would be completely be 1026*9c5db199SXin Li set back to their original (unstarted) state. 1027*9c5db199SXin Li 1028*9c5db199SXin Li @param data: List with tuples of the form (link_name, link_value), as 1029*9c5db199SXin Li returned by _split_local_from_foreign_values. 1030*9c5db199SXin Li 1031*9c5db199SXin Li @returns List of the same format as data, but only containing data for 1032*9c5db199SXin Li fields that updates are allowed on. 1033*9c5db199SXin Li """ 1034*9c5db199SXin Li return [pair for pair in data 1035*9c5db199SXin Li if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE] 1036*9c5db199SXin Li 1037*9c5db199SXin Li 1038*9c5db199SXin Li @classmethod 1039*9c5db199SXin Li def delete_matching_record(cls, **filter_args): 1040*9c5db199SXin Li """Delete records matching the filter. 1041*9c5db199SXin Li 1042*9c5db199SXin Li @param filter_args: Arguments for the django filter 1043*9c5db199SXin Li used to locate the record to delete. 1044*9c5db199SXin Li """ 1045*9c5db199SXin Li try: 1046*9c5db199SXin Li existing_record = cls.objects.get(**filter_args) 1047*9c5db199SXin Li except cls.DoesNotExist: 1048*9c5db199SXin Li return 1049*9c5db199SXin Li existing_record.delete() 1050*9c5db199SXin Li 1051*9c5db199SXin Li 1052*9c5db199SXin Li def _deserialize_local(self, data): 1053*9c5db199SXin Li """Set local attributes from a list of tuples. 1054*9c5db199SXin Li 1055*9c5db199SXin Li @param data: List of tuples like returned by 1056*9c5db199SXin Li _split_local_from_foreign_values. 1057*9c5db199SXin Li """ 1058*9c5db199SXin Li if not data: 1059*9c5db199SXin Li return 1060*9c5db199SXin Li 1061*9c5db199SXin Li for link, value in data: 1062*9c5db199SXin Li setattr(self, link, value) 1063*9c5db199SXin Li # Overwridden save() methods are prone to errors, so don't execute them. 1064*9c5db199SXin Li # This is because: 1065*9c5db199SXin Li # - the overwritten methods depend on ACL groups that don't yet exist 1066*9c5db199SXin Li # and don't handle errors 1067*9c5db199SXin Li # - the overwritten methods think this object already exists in the db 1068*9c5db199SXin Li # because the id is already set 1069*9c5db199SXin Li super(type(self), self).save() 1070*9c5db199SXin Li 1071*9c5db199SXin Li 1072*9c5db199SXin Li def _deserialize_relations(self, data): 1073*9c5db199SXin Li """Set foreign attributes from a list of tuples. 1074*9c5db199SXin Li 1075*9c5db199SXin Li This deserialized the related objects using their own deserialize() 1076*9c5db199SXin Li function and then sets the relation. 1077*9c5db199SXin Li 1078*9c5db199SXin Li @param data: List of tuples like returned by 1079*9c5db199SXin Li _split_local_from_foreign_values. 1080*9c5db199SXin Li """ 1081*9c5db199SXin Li for link, value in data: 1082*9c5db199SXin Li self._deserialize_relation(link, value) 1083*9c5db199SXin Li # See comment in _deserialize_local 1084*9c5db199SXin Li super(type(self), self).save() 1085*9c5db199SXin Li 1086*9c5db199SXin Li 1087*9c5db199SXin Li @classmethod 1088*9c5db199SXin Li def get_record(cls, data): 1089*9c5db199SXin Li """Retrieve a record with the data in the given input arg. 1090*9c5db199SXin Li 1091*9c5db199SXin Li @param data: A dictionary containing the information to use in a query 1092*9c5db199SXin Li for data. If child models have different constraints of 1093*9c5db199SXin Li uniqueness they should override this model. 1094*9c5db199SXin Li 1095*9c5db199SXin Li @return: An object with matching data. 1096*9c5db199SXin Li 1097*9c5db199SXin Li @raises DoesNotExist: If a record with the given data doesn't exist. 1098*9c5db199SXin Li """ 1099*9c5db199SXin Li return cls.objects.get(id=data['id']) 1100*9c5db199SXin Li 1101*9c5db199SXin Li 1102*9c5db199SXin Li @classmethod 1103*9c5db199SXin Li def deserialize(cls, data): 1104*9c5db199SXin Li """Recursively deserializes and saves an object with it's dependencies. 1105*9c5db199SXin Li 1106*9c5db199SXin Li This takes the result of the serialize method and creates objects 1107*9c5db199SXin Li in the database that are just like the original. 1108*9c5db199SXin Li 1109*9c5db199SXin Li If an object of the same type with the same id already exists, it's 1110*9c5db199SXin Li local values will be left untouched, unless they are explicitly 1111*9c5db199SXin Li allowlisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1112*9c5db199SXin Li 1113*9c5db199SXin Li Deserialize will always recursively propagate to all related objects 1114*9c5db199SXin Li present in data though. 1115*9c5db199SXin Li I.e. this is necessary to add users to an already existing acl-group. 1116*9c5db199SXin Li 1117*9c5db199SXin Li @param data: Representation of an object and its dependencies, as 1118*9c5db199SXin Li returned by serialize. 1119*9c5db199SXin Li 1120*9c5db199SXin Li @returns: The object represented by data if it didn't exist before, 1121*9c5db199SXin Li otherwise the object that existed before and has the same type 1122*9c5db199SXin Li and id as the one described by data. 1123*9c5db199SXin Li """ 1124*9c5db199SXin Li if data is None: 1125*9c5db199SXin Li return None 1126*9c5db199SXin Li 1127*9c5db199SXin Li local, related = cls._split_local_from_foreign_values(data) 1128*9c5db199SXin Li try: 1129*9c5db199SXin Li instance = cls.get_record(data) 1130*9c5db199SXin Li local = cls._filter_update_allowed_fields(local) 1131*9c5db199SXin Li except cls.DoesNotExist: 1132*9c5db199SXin Li instance = cls() 1133*9c5db199SXin Li 1134*9c5db199SXin Li instance._deserialize_local(local) 1135*9c5db199SXin Li instance._deserialize_relations(related) 1136*9c5db199SXin Li 1137*9c5db199SXin Li return instance 1138*9c5db199SXin Li 1139*9c5db199SXin Li 1140*9c5db199SXin Li def _check_update_from_shard(self, shard, updated_serialized, 1141*9c5db199SXin Li *args, **kwargs): 1142*9c5db199SXin Li """Check if an update sent from a shard is legitimate. 1143*9c5db199SXin Li 1144*9c5db199SXin Li @raises error.UnallowedRecordsSentToMain if an update is not 1145*9c5db199SXin Li legitimate. 1146*9c5db199SXin Li """ 1147*9c5db199SXin Li raise NotImplementedError( 1148*9c5db199SXin Li '_check_update_from_shard must be implemented by subclass %s ' 1149*9c5db199SXin Li 'for type %s' % type(self)) 1150*9c5db199SXin Li 1151*9c5db199SXin Li 1152*9c5db199SXin Li @transaction.commit_on_success 1153*9c5db199SXin Li def update_from_serialized(self, serialized): 1154*9c5db199SXin Li """Updates local fields of an existing object from a serialized form. 1155*9c5db199SXin Li 1156*9c5db199SXin Li This is different than the normal deserialize() in the way that it 1157*9c5db199SXin Li does update local values, which deserialize doesn't, but doesn't 1158*9c5db199SXin Li recursively propagate to related objects, which deserialize() does. 1159*9c5db199SXin Li 1160*9c5db199SXin Li The use case of this function is to update job records on the main 1161*9c5db199SXin Li after the jobs have been executed on a shard, as the main is not 1162*9c5db199SXin Li interested in updates for users, labels, specialtasks, etc. 1163*9c5db199SXin Li 1164*9c5db199SXin Li @param serialized: Representation of an object and its dependencies, as 1165*9c5db199SXin Li returned by serialize. 1166*9c5db199SXin Li 1167*9c5db199SXin Li @raises ValueError: if serialized contains related objects, i.e. not 1168*9c5db199SXin Li only local fields. 1169*9c5db199SXin Li """ 1170*9c5db199SXin Li local, related = ( 1171*9c5db199SXin Li self._split_local_from_foreign_values(serialized)) 1172*9c5db199SXin Li if related: 1173*9c5db199SXin Li raise ValueError('Serialized must not contain foreign ' 1174*9c5db199SXin Li 'objects: %s' % related) 1175*9c5db199SXin Li 1176*9c5db199SXin Li self._deserialize_local(local) 1177*9c5db199SXin Li 1178*9c5db199SXin Li 1179*9c5db199SXin Li def custom_deserialize_relation(self, link, data): 1180*9c5db199SXin Li """Allows overriding the deserialization behaviour by subclasses.""" 1181*9c5db199SXin Li raise NotImplementedError( 1182*9c5db199SXin Li 'custom_deserialize_relation must be implemented by subclass %s ' 1183*9c5db199SXin Li 'for relation %s' % (type(self), link)) 1184*9c5db199SXin Li 1185*9c5db199SXin Li 1186*9c5db199SXin Li def _deserialize_relation(self, link, data): 1187*9c5db199SXin Li """Deserializes related objects and sets references on this object. 1188*9c5db199SXin Li 1189*9c5db199SXin Li Relations that point to a list of objects are handled automatically. 1190*9c5db199SXin Li For many-to-one or one-to-one relations custom_deserialize_relation 1191*9c5db199SXin Li must be overridden by the subclass. 1192*9c5db199SXin Li 1193*9c5db199SXin Li Related objects are deserialized using their deserialize() method. 1194*9c5db199SXin Li Thereby they and their dependencies are created if they don't exist 1195*9c5db199SXin Li and saved to the database. 1196*9c5db199SXin Li 1197*9c5db199SXin Li @param link: Name of the relation. 1198*9c5db199SXin Li @param data: Serialized representation of the related object(s). 1199*9c5db199SXin Li This means a list of dictionaries for to-many relations, 1200*9c5db199SXin Li just a dictionary for to-one relations. 1201*9c5db199SXin Li """ 1202*9c5db199SXin Li field = getattr(self, link) 1203*9c5db199SXin Li 1204*9c5db199SXin Li if field and hasattr(field, 'all'): 1205*9c5db199SXin Li self._deserialize_2m_relation(link, data, field.model) 1206*9c5db199SXin Li else: 1207*9c5db199SXin Li self.custom_deserialize_relation(link, data) 1208*9c5db199SXin Li 1209*9c5db199SXin Li 1210*9c5db199SXin Li def _deserialize_2m_relation(self, link, data, related_class): 1211*9c5db199SXin Li """Deserialize related objects for one to-many relationship. 1212*9c5db199SXin Li 1213*9c5db199SXin Li @param link: Name of the relation. 1214*9c5db199SXin Li @param data: Serialized representation of the related objects. 1215*9c5db199SXin Li This is a list with of dictionaries. 1216*9c5db199SXin Li @param related_class: A class representing a django model, with which 1217*9c5db199SXin Li this class has a one-to-many relationship. 1218*9c5db199SXin Li """ 1219*9c5db199SXin Li relation_set = getattr(self, link) 1220*9c5db199SXin Li if related_class == self.get_attribute_model(): 1221*9c5db199SXin Li # When deserializing a model together with 1222*9c5db199SXin Li # its attributes, clear all the exising attributes to ensure 1223*9c5db199SXin Li # db consistency. Note 'update' won't be sufficient, as we also 1224*9c5db199SXin Li # want to remove any attributes that no longer exist in |data|. 1225*9c5db199SXin Li # 1226*9c5db199SXin Li # core_filters is a dictionary of filters, defines how 1227*9c5db199SXin Li # RelatedMangager would query for the 1-to-many relationship. E.g. 1228*9c5db199SXin Li # Host.objects.get( 1229*9c5db199SXin Li # id=20).hostattribute_set.core_filters = {host_id:20} 1230*9c5db199SXin Li # We use it to delete objects related to the current object. 1231*9c5db199SXin Li related_class.objects.filter(**relation_set.core_filters).delete() 1232*9c5db199SXin Li for serialized in data: 1233*9c5db199SXin Li relation_set.add(related_class.deserialize(serialized)) 1234*9c5db199SXin Li 1235*9c5db199SXin Li 1236*9c5db199SXin Li @classmethod 1237*9c5db199SXin Li def get_attribute_model(cls): 1238*9c5db199SXin Li """Return the attribute model. 1239*9c5db199SXin Li 1240*9c5db199SXin Li Subclass with attribute-like model should override this to 1241*9c5db199SXin Li return the attribute model class. This method will be 1242*9c5db199SXin Li called by _deserialize_2m_relation to determine whether 1243*9c5db199SXin Li to clear the one-to-many relations first on deserialization of object. 1244*9c5db199SXin Li """ 1245*9c5db199SXin Li return None 1246*9c5db199SXin Li 1247*9c5db199SXin Li 1248*9c5db199SXin Liclass ModelWithInvalid(ModelExtensions): 1249*9c5db199SXin Li """ 1250*9c5db199SXin Li Overrides model methods save() and delete() to support invalidation in 1251*9c5db199SXin Li place of actual deletion. Subclasses must have a boolean "invalid" 1252*9c5db199SXin Li field. 1253*9c5db199SXin Li """ 1254*9c5db199SXin Li 1255*9c5db199SXin Li def save(self, *args, **kwargs): 1256*9c5db199SXin Li """Saves the model""" 1257*9c5db199SXin Li first_time = (self.id is None) 1258*9c5db199SXin Li if first_time: 1259*9c5db199SXin Li # see if this object was previously added and invalidated 1260*9c5db199SXin Li my_name = getattr(self, self.name_field) 1261*9c5db199SXin Li filters = {self.name_field : my_name, 'invalid' : True} 1262*9c5db199SXin Li try: 1263*9c5db199SXin Li old_object = self.__class__.objects.get(**filters) 1264*9c5db199SXin Li self.resurrect_object(old_object) 1265*9c5db199SXin Li except self.DoesNotExist: 1266*9c5db199SXin Li # no existing object 1267*9c5db199SXin Li pass 1268*9c5db199SXin Li 1269*9c5db199SXin Li super(ModelWithInvalid, self).save(*args, **kwargs) 1270*9c5db199SXin Li 1271*9c5db199SXin Li 1272*9c5db199SXin Li def resurrect_object(self, old_object): 1273*9c5db199SXin Li """ 1274*9c5db199SXin Li Called when self is about to be saved for the first time and is actually 1275*9c5db199SXin Li "undeleting" a previously deleted object. Can be overridden by 1276*9c5db199SXin Li subclasses to copy data as desired from the deleted entry (but this 1277*9c5db199SXin Li superclass implementation must normally be called). 1278*9c5db199SXin Li """ 1279*9c5db199SXin Li self.id = old_object.id 1280*9c5db199SXin Li 1281*9c5db199SXin Li 1282*9c5db199SXin Li def clean_object(self): 1283*9c5db199SXin Li """ 1284*9c5db199SXin Li This method is called when an object is marked invalid. 1285*9c5db199SXin Li Subclasses should override this to clean up relationships that 1286*9c5db199SXin Li should no longer exist if the object were deleted. 1287*9c5db199SXin Li """ 1288*9c5db199SXin Li pass 1289*9c5db199SXin Li 1290*9c5db199SXin Li 1291*9c5db199SXin Li def delete(self): 1292*9c5db199SXin Li """Deletes the model""" 1293*9c5db199SXin Li self.invalid = self.invalid 1294*9c5db199SXin Li assert not self.invalid 1295*9c5db199SXin Li self.invalid = True 1296*9c5db199SXin Li self.save() 1297*9c5db199SXin Li self.clean_object() 1298*9c5db199SXin Li 1299*9c5db199SXin Li 1300*9c5db199SXin Li @classmethod 1301*9c5db199SXin Li def get_valid_manager(cls): 1302*9c5db199SXin Li return cls.valid_objects 1303*9c5db199SXin Li 1304*9c5db199SXin Li 1305*9c5db199SXin Li class Manipulator(object): 1306*9c5db199SXin Li """ 1307*9c5db199SXin Li Force default manipulators to look only at valid objects - 1308*9c5db199SXin Li otherwise they will match against invalid objects when checking 1309*9c5db199SXin Li uniqueness. 1310*9c5db199SXin Li """ 1311*9c5db199SXin Li @classmethod 1312*9c5db199SXin Li def _prepare(cls, model): 1313*9c5db199SXin Li super(ModelWithInvalid.Manipulator, cls)._prepare(model) 1314*9c5db199SXin Li cls.manager = model.valid_objects 1315*9c5db199SXin Li 1316*9c5db199SXin Li 1317*9c5db199SXin Liclass ModelWithAttributes(object): 1318*9c5db199SXin Li """ 1319*9c5db199SXin Li Mixin class for models that have an attribute model associated with them. 1320*9c5db199SXin Li The attribute model is assumed to have its value field named "value". 1321*9c5db199SXin Li """ 1322*9c5db199SXin Li 1323*9c5db199SXin Li def _get_attribute_model_and_args(self, attribute): 1324*9c5db199SXin Li """ 1325*9c5db199SXin Li Subclasses should override this to return a tuple (attribute_model, 1326*9c5db199SXin Li keyword_args), where attribute_model is a model class and keyword_args 1327*9c5db199SXin Li is a dict of args to pass to attribute_model.objects.get() to get an 1328*9c5db199SXin Li instance of the given attribute on this object. 1329*9c5db199SXin Li """ 1330*9c5db199SXin Li raise NotImplementedError 1331*9c5db199SXin Li 1332*9c5db199SXin Li 1333*9c5db199SXin Li def _is_replaced_by_static_attribute(self, attribute): 1334*9c5db199SXin Li """ 1335*9c5db199SXin Li Subclasses could override this to indicate whether it has static 1336*9c5db199SXin Li attributes. 1337*9c5db199SXin Li """ 1338*9c5db199SXin Li return False 1339*9c5db199SXin Li 1340*9c5db199SXin Li 1341*9c5db199SXin Li def set_attribute(self, attribute, value): 1342*9c5db199SXin Li if self._is_replaced_by_static_attribute(attribute): 1343*9c5db199SXin Li raise error.UnmodifiableAttributeException( 1344*9c5db199SXin Li 'Failed to set attribute "%s" for host "%s" since it ' 1345*9c5db199SXin Li 'is static. Use go/chromeos-skylab-inventory-tools to ' 1346*9c5db199SXin Li 'modify this attribute.' % (attribute, self.hostname)) 1347*9c5db199SXin Li 1348*9c5db199SXin Li attribute_model, get_args = self._get_attribute_model_and_args( 1349*9c5db199SXin Li attribute) 1350*9c5db199SXin Li attribute_object, _ = attribute_model.objects.get_or_create(**get_args) 1351*9c5db199SXin Li attribute_object.value = value 1352*9c5db199SXin Li attribute_object.save() 1353*9c5db199SXin Li 1354*9c5db199SXin Li 1355*9c5db199SXin Li def delete_attribute(self, attribute): 1356*9c5db199SXin Li """Deletes an attribute""" 1357*9c5db199SXin Li if self._is_replaced_by_static_attribute(attribute): 1358*9c5db199SXin Li raise error.UnmodifiableAttributeException( 1359*9c5db199SXin Li 'Failed to delete attribute "%s" for host "%s" since it ' 1360*9c5db199SXin Li 'is static. Use go/chromeos-skylab-inventory-tools to ' 1361*9c5db199SXin Li 'modify this attribute.' % (attribute, self.hostname)) 1362*9c5db199SXin Li 1363*9c5db199SXin Li attribute_model, get_args = self._get_attribute_model_and_args( 1364*9c5db199SXin Li attribute) 1365*9c5db199SXin Li try: 1366*9c5db199SXin Li attribute_model.objects.get(**get_args).delete() 1367*9c5db199SXin Li except attribute_model.DoesNotExist: 1368*9c5db199SXin Li pass 1369*9c5db199SXin Li 1370*9c5db199SXin Li 1371*9c5db199SXin Li def set_or_delete_attribute(self, attribute, value): 1372*9c5db199SXin Li if value is None: 1373*9c5db199SXin Li self.delete_attribute(attribute) 1374*9c5db199SXin Li else: 1375*9c5db199SXin Li self.set_attribute(attribute, value) 1376*9c5db199SXin Li 1377*9c5db199SXin Li 1378*9c5db199SXin Liclass ModelWithHashManager(dbmodels.Manager): 1379*9c5db199SXin Li """Manager for use with the ModelWithHash abstract model class""" 1380*9c5db199SXin Li 1381*9c5db199SXin Li def create(self, **kwargs): 1382*9c5db199SXin Li """Always raises exception.""" 1383*9c5db199SXin Li raise Exception('ModelWithHash manager should use get_or_create() ' 1384*9c5db199SXin Li 'instead of create()') 1385*9c5db199SXin Li 1386*9c5db199SXin Li 1387*9c5db199SXin Li def get_or_create(self, **kwargs): 1388*9c5db199SXin Li kwargs['the_hash'] = self.model._compute_hash(**kwargs) 1389*9c5db199SXin Li return super(ModelWithHashManager, self).get_or_create(**kwargs) 1390*9c5db199SXin Li 1391*9c5db199SXin Li 1392*9c5db199SXin Liclass ModelWithHash(dbmodels.Model): 1393*9c5db199SXin Li """Superclass with methods for dealing with a hash column""" 1394*9c5db199SXin Li 1395*9c5db199SXin Li the_hash = dbmodels.CharField(max_length=40, unique=True) 1396*9c5db199SXin Li 1397*9c5db199SXin Li objects = ModelWithHashManager() 1398*9c5db199SXin Li 1399*9c5db199SXin Li class Meta: 1400*9c5db199SXin Li """Overrides dbmodels.Model.Meta.""" 1401*9c5db199SXin Li abstract = True 1402*9c5db199SXin Li 1403*9c5db199SXin Li 1404*9c5db199SXin Li @classmethod 1405*9c5db199SXin Li def _compute_hash(cls, **kwargs): 1406*9c5db199SXin Li raise NotImplementedError('Subclasses must override _compute_hash()') 1407*9c5db199SXin Li 1408*9c5db199SXin Li 1409*9c5db199SXin Li def save(self, force_insert=False, **kwargs): 1410*9c5db199SXin Li """Prevents saving the model in most cases 1411*9c5db199SXin Li 1412*9c5db199SXin Li We want these models to be immutable, so the generic save() operation 1413*9c5db199SXin Li will not work. These models should be instantiated through their the 1414*9c5db199SXin Li model.objects.get_or_create() method instead. 1415*9c5db199SXin Li 1416*9c5db199SXin Li The exception is that save(force_insert=True) will be allowed, since 1417*9c5db199SXin Li that creates a new row. However, the preferred way to make instances of 1418*9c5db199SXin Li these models is through the get_or_create() method. 1419*9c5db199SXin Li """ 1420*9c5db199SXin Li if not force_insert: 1421*9c5db199SXin Li # Allow a forced insert to happen; if it's a duplicate, the unique 1422*9c5db199SXin Li # constraint will catch it later anyways 1423*9c5db199SXin Li raise Exception('ModelWithHash is immutable') 1424*9c5db199SXin Li super(ModelWithHash, self).save(force_insert=force_insert, **kwargs) 1425