xref: /aosp_15_r20/external/autotest/frontend/afe/model_logic.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1*9c5db199SXin Li"""
2*9c5db199SXin LiExtensions to Django's model logic.
3*9c5db199SXin Li"""
4*9c5db199SXin Li
5*9c5db199SXin Liimport django.core.exceptions
6*9c5db199SXin Liimport django.db.models.sql.where
7*9c5db199SXin Liimport six
8*9c5db199SXin Lifrom autotest_lib.client.common_lib import error
9*9c5db199SXin Lifrom autotest_lib.frontend.afe import rdb_model_extensions
10*9c5db199SXin Lifrom django.db import connection, connections
11*9c5db199SXin Lifrom django.db import models as dbmodels
12*9c5db199SXin Lifrom django.db import transaction
13*9c5db199SXin Lifrom django.db.models.sql import query
14*9c5db199SXin Li
15*9c5db199SXin Li
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Raised when data fails validation while adding or updating an object.

    The error's associated value is a dictionary that maps each offending
    field name to an error string.
    """
21*9c5db199SXin Li
def _quote_name(name):
    """Quote *name* as an SQL identifier using the default connection.

    Shorthand for ``connection.ops.quote_name()``.
    """
    quote = connection.ops.quote_name
    return quote(name)
25*9c5db199SXin Li
26*9c5db199SXin Li
class LeasedHostManager(dbmodels.Manager):
    """Query manager restricted to hosts that are neither leased nor locked.
    """
    def get_query_set(self):
        """Return the parent query set filtered to leased=0 and locked=0."""
        base_query_set = super(LeasedHostManager, self).get_query_set()
        return base_query_set.filter(leased=0, locked=0)
33*9c5db199SXin Li
34*9c5db199SXin Li
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    Besides the standard Manager API it provides helpers to add custom SQL
    joins and raw WHERE fragments to query sets, to run custom SELECTs over a
    query set's FROM/WHERE clauses, and to bulk-populate lists of related
    objects on model instances.
    """

    class CustomQuery(query.Query):
        """A Query that additionally carries a list of custom joins.

        Each entry of ``_custom_joins`` is a dict built by add_custom_join()
        with the keys ``table``, ``condition``, ``condition_values``,
        ``join_type`` and ``alias``.
        """

        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """Clones the query and returns the clone.

            Extra keyword arguments are forwarded to Query.clone() so that
            options Django passes when cloning are not silently dropped.  The
            custom joins of this query are always copied onto the clone
            (overriding any ``_custom_joins`` given in ``kwargs``).
            """
            obj = super(ExtendedManager.CustomQuery, self).clone(klass,
                                                                 **kwargs)
            # Copy the list so the clone and this query can gain joins
            # independently of each other.
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Combines this query with ``rhs`` in place.

            Custom joins carried by ``rhs`` (if any) are appended to ours.
            """
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """Adds a custom join to the query.

            @param table: name of the table to join.
            @param condition: SQL for the ON clause of the join.
            @param join_type: join type (e.g. the query's INNER or LOUTER).
            @param condition_values: values to substitute into condition.
            @param alias: alias of the joined table; defaults to ``table``.
            """
            if alias is None:
                alias = table
            self._custom_joins.append(dict(table=table,
                                           condition=condition,
                                           condition_values=condition_values,
                                           join_type=join_type,
                                           alias=alias))


        @classmethod
        def convert_query(cls, query_set):
            """
            Convert the query set's "query" attribute to a CustomQuery.

            @param query_set: the query set to convert.
            @returns a copy of query_set whose query is a CustomQuery.  If
                    the query already was a CustomQuery, its custom joins
                    are kept (see CustomQuery.clone()).
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(
                    klass=ExtendedManager.CustomQuery,
                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            """Converts the clause to SQL and returns it."""
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            """Does nothing: the raw SQL clause is not alias-aware."""
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from; defaults to
                the primary key column of this model.
        @param alias alias to use for for join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
        LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
        instead of an INNER JOIN regardless of other options
        @returns a new query set (query_set itself is not modified)
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            # This ends up in raw SQL, so it must be the DB column name (which
            # can differ from the field name, e.g. for a db_column override).
            join_from_key = self.model._meta.pk.column
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # With a LEFT JOIN, rows that had no match have NULL join columns.
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().

        NOTE: relabel_aliases() is called directly on join_to_query's WHERE
        node, so join_to_query may be modified by this call.
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        rhs_where = join_to_query.query.where
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        initial_clause, values = compiler.as_sql()
        # initial_clause is compiled from `join_to_query`, which is a SELECT
        # query returns at most one record. For it to be used in WHERE clause,
        # it must be converted to a boolean value using EXISTS.
        all_clauses = ('EXISTS (%s)' % initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                    ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to.  It must match exactly one related object.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                model, False if it is declared on the related model.
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join.  we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = list(join_to_query.values_list(rhs_id_field.attname,
                                                 flat=True))
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        # rhs_id comes straight from the database (an ID value), so it is
        # inlined into the join condition rather than bound as a parameter.
        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
                                               _quote_name(m2m_rhs_column),
                                               rhs_id)
        info['values'] = ()
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationsip.  join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model it one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @raises ValueError if join_to_query's model is not related to this
                model (raised by determine_relationship()).
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type == self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        else:
            # determine_relationship() returns one of the three relationship
            # constants or raises, so anything else here is an M2M relation.
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=(
                            relationship_type == self.M2M_ON_THIS_MODEL))

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def add_where(self, query_set, where, values=()):
        """Adds a where clause to the query_set.

        @param where: raw SQL condition, ANDed with the existing conditions.
        @param values: values to substitute into ``where``.
        @returns a new query set; query_set itself is left untouched.
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return ``table.field`` with both parts quoted as identifiers."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """Return the quoted ``table.column`` for key_field on this model.

        @param key_field: column name; defaults to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Escapes % in sql."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """Execute a custom select query.

        The query set's compiled SQL is reused from its FROM clause onwards,
        with the column list replaced by ``selects``.

        @param query_set: query set as returned by query_objects.
        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.

        @returns: Result of the query as returned by cursor.fetchall().
        """
        compiler = query_set.query.get_compiler(using=query_set.db)
        sql, params = compiler.as_sql()
        from_ = sql[sql.find(' FROM'):]

        distinct = 'DISTINCT ' if query_set.query.distinct else ''
        sql_query = 'SELECT ' + distinct + ','.join(selects) + from_

        # Chose the connection that's responsible for this type of object
        cursor = connections[query_set.db].cursor()
        try:
            cursor.execute(sql_query, params)
            return cursor.fetchall()
        finally:
            cursor.close()


    def _is_relation_to(self, field, model_class):
        """Truthy if field is a relation pointing at model_class."""
        return field.rel and field.rel.to is model_class


    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if no such relationship exists.
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type == self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field, related_model):
        """
        @param base_objects_by_id dict of self.model objects keyed by ID; only
        pivot rows referencing these IDs are returned
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
        self.model
        @param pivot_to_field a field name on pivot_table referencing the
        related model.
        @param related_model the related model

        @returns pivot list of IDs (base_id, related_id)
        """
        base_ids = list(base_objects_by_id)
        if not base_ids:
            # "IN ()" is invalid SQL, and nothing could match anyway.
            return []

        # Chose the connection that's responsible for this type of object
        # The databases for related_model and the current model will always
        # be the same, related_model is just easier to obtain here because
        # self is only a ExtendedManager, not the object.
        db_connection = connections[related_model.objects.db]
        quote = db_connection.ops.quote_name
        from_field = quote(pivot_from_field)
        # IDs are bound as query parameters rather than formatted into the
        # SQL text, so non-integer keys are handled correctly and safely.
        sql = 'SELECT %s, %s FROM %s WHERE %s IN (%s)' % (
                from_field, quote(pivot_to_field), quote(pivot_table),
                from_field, ','.join(['%s'] * len(base_ids)))

        cursor = db_connection.cursor()
        try:
            cursor.execute(sql, base_ids)
            return cursor.fetchall()
        finally:
            cursor.close()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field,
                                           related_model)

        all_related_ids = list({related_id for _, related_id in id_pivot})
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
        object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        # The default maximum value of a host parameter number in SQLite is 999.
        # Exceed this will get a DatabaseError later.
        batch_size = 900
        for i in range(0, len(base_objects), batch_size):
            base_objects_batch = base_objects[i:i + batch_size]
            base_objects_by_id = {base_object._get_pk_val(): base_object
                                  for base_object in base_objects_batch}
            pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                      related_model)

            for base_object in base_objects_batch:
                setattr(base_object, related_list_name, [])

            for base_object, related_object in pivot_iterator:
                getattr(base_object, related_list_name).append(related_object)
483*9c5db199SXin Li
484*9c5db199SXin Li
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet for models with an "invalid" bit.

    delete() calls each model instance's own delete() rather than the
    QuerySet's default bulk delete.
    """
    def delete(self):
        """Delete every object in the QuerySet, one instance at a time."""
        for instance in self:
            instance.delete()
493*9c5db199SXin Li
494*9c5db199SXin Li
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit
    """
    def get_query_set(self):
        """Return a ModelWithInvalidQuerySet over this manager's model.

        The manager's database alias (set e.g. via db_manager()) is passed
        through, as Django's default Manager.get_query_set() does, so queries
        go to the intended database instead of silently falling back to the
        router's choice.
        """
        return ModelWithInvalidQuerySet(self.model, using=self._db)
501*9c5db199SXin Li
502*9c5db199SXin Li
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager that only exposes objects whose invalid flag is False.
    """
    def get_query_set(self):
        """Return the parent manager's query set filtered to invalid=False."""
        parent = super(ValidObjectsManager, self)
        return parent.get_query_set().filter(invalid=False)
510*9c5db199SXin Li
511*9c5db199SXin Li
class ModelExtensions(rdb_model_extensions.ModelValidators):
    """\
    Mixin with convenience functions for models, built on top of
    the model validators in rdb_model_extensions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class


    # The SERIALIZATION_* sets below default to empty. The (de)serialization
    # methods read them through cls/self, so a concrete model can override
    # them to opt specific relations and fields in.
    SERIALIZATION_LINKS_TO_FOLLOW = set()
    """
    To be able to send jobs and hosts to shards, it's necessary to find their
    dependencies.
    The most generic approach for this would be to traverse all relationships
    to other objects recursively. This would list all objects that are related
    in any way.
    But this approach finds too many objects: If a host should be transferred,
    all its relationships would be traversed. This would find an acl group.
    If then the acl group's relationships are traversed, the relationship
    would be followed backwards and many other hosts would be found.

    This set tells that algorithm which relations to follow explicitly.
    Each entry is an attribute name on the model; serialize() stores the
    serialized related object(s) under that same name.
    """


    SERIALIZATION_LINKS_TO_KEEP = set()
    """This set stores foreign keys which we don't want to follow, but
    still want to include in the serialized dictionary. For
    example, we follow the relationship `Host.hostattribute_set`,
    but we do not want to follow `HostAttribute.host_id` back
    to Host, which would otherwise lead to a cycle. However, we still
    want to serialize `HostAttribute.host_id`. Entries are field names;
    the value is stored under the field's attname (e.g. 'host_id')."""

    SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
    """
    On deserialization, if the object to persist already exists, local fields
    will only be updated if their name is in this set.
    """
550*9c5db199SXin Li
551*9c5db199SXin Li
552*9c5db199SXin Li    @classmethod
553*9c5db199SXin Li    def convert_human_readable_values(cls, data, to_human_readable=False):
554*9c5db199SXin Li        """\
555*9c5db199SXin Li        Performs conversions on user-supplied field data, to make it
556*9c5db199SXin Li        easier for users to pass human-readable data.
557*9c5db199SXin Li
558*9c5db199SXin Li        For all fields that have choice sets, convert their values
559*9c5db199SXin Li        from human-readable strings to enum values, if necessary.  This
560*9c5db199SXin Li        allows users to pass strings instead of the corresponding
561*9c5db199SXin Li        integer values.
562*9c5db199SXin Li
563*9c5db199SXin Li        For all foreign key fields, call smart_get with the supplied
564*9c5db199SXin Li        data.  This allows the user to pass either an ID value or
565*9c5db199SXin Li        the name of the object as a string.
566*9c5db199SXin Li
567*9c5db199SXin Li        If to_human_readable=True, perform the inverse - i.e. convert
568*9c5db199SXin Li        numeric values to human readable values.
569*9c5db199SXin Li
570*9c5db199SXin Li        This method modifies data in-place.
571*9c5db199SXin Li        """
572*9c5db199SXin Li        field_dict = cls.get_field_dict()
573*9c5db199SXin Li        for field_name in data:
574*9c5db199SXin Li            if field_name not in field_dict or data[field_name] is None:
575*9c5db199SXin Li                continue
576*9c5db199SXin Li            field_obj = field_dict[field_name]
577*9c5db199SXin Li            # convert enum values
578*9c5db199SXin Li            if field_obj.choices:
579*9c5db199SXin Li                for choice_data in field_obj.choices:
580*9c5db199SXin Li                    # choice_data is (value, name)
581*9c5db199SXin Li                    if to_human_readable:
582*9c5db199SXin Li                        from_val, to_val = choice_data
583*9c5db199SXin Li                    else:
584*9c5db199SXin Li                        to_val, from_val = choice_data
585*9c5db199SXin Li                    if from_val == data[field_name]:
586*9c5db199SXin Li                        data[field_name] = to_val
587*9c5db199SXin Li                        break
588*9c5db199SXin Li            # convert foreign key values
589*9c5db199SXin Li            elif field_obj.rel:
590*9c5db199SXin Li                dest_obj = field_obj.rel.to.smart_get(data[field_name],
591*9c5db199SXin Li                                                      valid_only=False)
592*9c5db199SXin Li                if to_human_readable:
593*9c5db199SXin Li                    # parameterized_jobs do not have a name_field
594*9c5db199SXin Li                    if (field_name != 'parameterized_job' and
595*9c5db199SXin Li                        dest_obj.name_field is not None):
596*9c5db199SXin Li                        data[field_name] = getattr(dest_obj,
597*9c5db199SXin Li                                                   dest_obj.name_field)
598*9c5db199SXin Li                else:
599*9c5db199SXin Li                    data[field_name] = dest_obj
600*9c5db199SXin Li
601*9c5db199SXin Li
602*9c5db199SXin Li
603*9c5db199SXin Li
604*9c5db199SXin Li    def _validate_unique(self):
605*9c5db199SXin Li        """\
606*9c5db199SXin Li        Validate that unique fields are unique.  Django manipulators do
607*9c5db199SXin Li        this too, but they're a huge pain to use manually.  Trust me.
608*9c5db199SXin Li        """
609*9c5db199SXin Li        errors = {}
610*9c5db199SXin Li        cls = type(self)
611*9c5db199SXin Li        field_dict = self.get_field_dict()
612*9c5db199SXin Li        manager = cls.get_valid_manager()
613*9c5db199SXin Li        for field_name, field_obj in six.iteritems(field_dict):
614*9c5db199SXin Li            if not field_obj.unique:
615*9c5db199SXin Li                continue
616*9c5db199SXin Li
617*9c5db199SXin Li            value = getattr(self, field_name)
618*9c5db199SXin Li            if value is None and field_obj.auto_created:
619*9c5db199SXin Li                # don't bother checking autoincrement fields about to be
620*9c5db199SXin Li                # generated
621*9c5db199SXin Li                continue
622*9c5db199SXin Li
623*9c5db199SXin Li            existing_objs = manager.filter(**{field_name : value})
624*9c5db199SXin Li            num_existing = existing_objs.count()
625*9c5db199SXin Li
626*9c5db199SXin Li            if num_existing == 0:
627*9c5db199SXin Li                continue
628*9c5db199SXin Li            if num_existing == 1 and existing_objs[0].id == self.id:
629*9c5db199SXin Li                continue
630*9c5db199SXin Li            errors[field_name] = (
631*9c5db199SXin Li                'This value must be unique (%s)' % (value))
632*9c5db199SXin Li        return errors
633*9c5db199SXin Li
634*9c5db199SXin Li
635*9c5db199SXin Li    def _validate(self):
636*9c5db199SXin Li        """
637*9c5db199SXin Li        First coerces all fields on this instance to their proper Python types.
638*9c5db199SXin Li        Then runs validation on every field. Returns a dictionary of
639*9c5db199SXin Li        field_name -> error_list.
640*9c5db199SXin Li
641*9c5db199SXin Li        Based on validate() from django.db.models.Model in Django 0.96, which
642*9c5db199SXin Li        was removed in Django 1.0. It should reappear in a later version. See:
643*9c5db199SXin Li            http://code.djangoproject.com/ticket/6845
644*9c5db199SXin Li        """
645*9c5db199SXin Li        error_dict = {}
646*9c5db199SXin Li        for f in self._meta.fields:
647*9c5db199SXin Li            try:
648*9c5db199SXin Li                python_value = f.to_python(
649*9c5db199SXin Li                    getattr(self, f.attname, f.get_default()))
650*9c5db199SXin Li            except django.core.exceptions.ValidationError as e:
651*9c5db199SXin Li                error_dict[f.name] = str(e)
652*9c5db199SXin Li                continue
653*9c5db199SXin Li
654*9c5db199SXin Li            if not f.blank and not python_value:
655*9c5db199SXin Li                error_dict[f.name] = 'This field is required.'
656*9c5db199SXin Li                continue
657*9c5db199SXin Li
658*9c5db199SXin Li            setattr(self, f.attname, python_value)
659*9c5db199SXin Li
660*9c5db199SXin Li        return error_dict
661*9c5db199SXin Li
662*9c5db199SXin Li
663*9c5db199SXin Li    def do_validate(self):
664*9c5db199SXin Li        """Validate fields."""
665*9c5db199SXin Li        errors = self._validate()
666*9c5db199SXin Li        unique_errors = self._validate_unique()
667*9c5db199SXin Li        for field_name, error in six.iteritems(unique_errors):
668*9c5db199SXin Li            errors.setdefault(field_name, error)
669*9c5db199SXin Li        if errors:
670*9c5db199SXin Li            raise ValidationError(errors)
671*9c5db199SXin Li
672*9c5db199SXin Li
673*9c5db199SXin Li    # actually (externally) useful methods follow
674*9c5db199SXin Li
675*9c5db199SXin Li    @classmethod
676*9c5db199SXin Li    def add_object(cls, data={}, **kwargs):
677*9c5db199SXin Li        """\
678*9c5db199SXin Li        Returns a new object created with the given data (a dictionary
679*9c5db199SXin Li        mapping field names to values). Merges any extra keyword args
680*9c5db199SXin Li        into data.
681*9c5db199SXin Li        """
682*9c5db199SXin Li        data = dict(data)
683*9c5db199SXin Li        data.update(kwargs)
684*9c5db199SXin Li        data = cls.prepare_data_args(data)
685*9c5db199SXin Li        cls.convert_human_readable_values(data)
686*9c5db199SXin Li        data = cls.provide_default_values(data)
687*9c5db199SXin Li
688*9c5db199SXin Li        obj = cls(**data)
689*9c5db199SXin Li        obj.do_validate()
690*9c5db199SXin Li        obj.save()
691*9c5db199SXin Li        return obj
692*9c5db199SXin Li
693*9c5db199SXin Li
694*9c5db199SXin Li    def update_object(self, data={}, **kwargs):
695*9c5db199SXin Li        """\
696*9c5db199SXin Li        Updates the object with the given data (a dictionary mapping
697*9c5db199SXin Li        field names to values).  Merges any extra keyword args into
698*9c5db199SXin Li        data.
699*9c5db199SXin Li        """
700*9c5db199SXin Li        data = dict(data)
701*9c5db199SXin Li        data.update(kwargs)
702*9c5db199SXin Li        data = self.prepare_data_args(data)
703*9c5db199SXin Li        self.convert_human_readable_values(data)
704*9c5db199SXin Li        for field_name, value in six.iteritems(data):
705*9c5db199SXin Li            setattr(self, field_name, value)
706*9c5db199SXin Li        self.do_validate()
707*9c5db199SXin Li        self.save()
708*9c5db199SXin Li
709*9c5db199SXin Li
710*9c5db199SXin Li    # see query_objects()
711*9c5db199SXin Li    _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
712*9c5db199SXin Li                            'extra_args', 'extra_where', 'no_distinct')
713*9c5db199SXin Li
714*9c5db199SXin Li
715*9c5db199SXin Li    @classmethod
716*9c5db199SXin Li    def _extract_special_params(cls, filter_data):
717*9c5db199SXin Li        """
718*9c5db199SXin Li        @returns a tuple of dicts (special_params, regular_filters), where
719*9c5db199SXin Li        special_params contains the parameters we handle specially and
720*9c5db199SXin Li        regular_filters is the remaining data to be handled by Django.
721*9c5db199SXin Li        """
722*9c5db199SXin Li        regular_filters = dict(filter_data)
723*9c5db199SXin Li        special_params = {}
724*9c5db199SXin Li        for key in cls._SPECIAL_FILTER_KEYS:
725*9c5db199SXin Li            if key in regular_filters:
726*9c5db199SXin Li                special_params[key] = regular_filters.pop(key)
727*9c5db199SXin Li        return special_params, regular_filters
728*9c5db199SXin Li
729*9c5db199SXin Li
730*9c5db199SXin Li    @classmethod
731*9c5db199SXin Li    def apply_presentation(cls, query, filter_data):
732*9c5db199SXin Li        """
733*9c5db199SXin Li        Apply presentation parameters -- sorting and paging -- to the given
734*9c5db199SXin Li        query.
735*9c5db199SXin Li        @returns new query with presentation applied
736*9c5db199SXin Li        """
737*9c5db199SXin Li        special_params, _ = cls._extract_special_params(filter_data)
738*9c5db199SXin Li        sort_by = special_params.get('sort_by', None)
739*9c5db199SXin Li        if sort_by:
740*9c5db199SXin Li            assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
741*9c5db199SXin Li            query = query.extra(order_by=sort_by)
742*9c5db199SXin Li
743*9c5db199SXin Li        query_start = special_params.get('query_start', None)
744*9c5db199SXin Li        query_limit = special_params.get('query_limit', None)
745*9c5db199SXin Li        if query_start is not None:
746*9c5db199SXin Li            if query_limit is None:
747*9c5db199SXin Li                raise ValueError('Cannot pass query_start without query_limit')
748*9c5db199SXin Li            # query_limit is passed as a page size
749*9c5db199SXin Li            query_limit += query_start
750*9c5db199SXin Li        return query[query_start:query_limit]
751*9c5db199SXin Li
752*9c5db199SXin Li
753*9c5db199SXin Li    @classmethod
754*9c5db199SXin Li    def query_objects(cls, filter_data, valid_only=True, initial_query=None,
755*9c5db199SXin Li                      apply_presentation=True):
756*9c5db199SXin Li        """\
757*9c5db199SXin Li        Returns a QuerySet object for querying the given model_class
758*9c5db199SXin Li        with the given filter_data.  Optional special arguments in
759*9c5db199SXin Li        filter_data include:
760*9c5db199SXin Li        -query_start: index of first return to return
761*9c5db199SXin Li        -query_limit: maximum number of results to return
762*9c5db199SXin Li        -sort_by: list of fields to sort on.  prefixing a '-' onto a
763*9c5db199SXin Li         field name changes the sort to descending order.
764*9c5db199SXin Li        -extra_args: keyword args to pass to query.extra() (see Django
765*9c5db199SXin Li         DB layer documentation)
766*9c5db199SXin Li        -extra_where: extra WHERE clause to append
767*9c5db199SXin Li        -no_distinct: if True, a DISTINCT will not be added to the SELECT
768*9c5db199SXin Li        """
769*9c5db199SXin Li        special_params, regular_filters = cls._extract_special_params(
770*9c5db199SXin Li                filter_data)
771*9c5db199SXin Li
772*9c5db199SXin Li        if initial_query is None:
773*9c5db199SXin Li            if valid_only:
774*9c5db199SXin Li                initial_query = cls.get_valid_manager()
775*9c5db199SXin Li            else:
776*9c5db199SXin Li                initial_query = cls.objects
777*9c5db199SXin Li
778*9c5db199SXin Li        query = initial_query.filter(**regular_filters)
779*9c5db199SXin Li
780*9c5db199SXin Li        use_distinct = not special_params.get('no_distinct', False)
781*9c5db199SXin Li        if use_distinct:
782*9c5db199SXin Li            query = query.distinct()
783*9c5db199SXin Li
784*9c5db199SXin Li        extra_args = special_params.get('extra_args', {})
785*9c5db199SXin Li        extra_where = special_params.get('extra_where', None)
786*9c5db199SXin Li        if extra_where:
787*9c5db199SXin Li            # escape %'s
788*9c5db199SXin Li            extra_where = cls.objects.escape_user_sql(extra_where)
789*9c5db199SXin Li            extra_args.setdefault('where', []).append(extra_where)
790*9c5db199SXin Li        if extra_args:
791*9c5db199SXin Li            query = query.extra(**extra_args)
792*9c5db199SXin Li            # TODO: Use readonly connection for these queries.
793*9c5db199SXin Li            # This has been disabled, because it's not used anyway, as the
794*9c5db199SXin Li            # configured readonly user is the same as the real user anyway.
795*9c5db199SXin Li
796*9c5db199SXin Li        if apply_presentation:
797*9c5db199SXin Li            query = cls.apply_presentation(query, filter_data)
798*9c5db199SXin Li
799*9c5db199SXin Li        return query
800*9c5db199SXin Li
801*9c5db199SXin Li
802*9c5db199SXin Li    @classmethod
803*9c5db199SXin Li    def query_count(cls, filter_data, initial_query=None):
804*9c5db199SXin Li        """\
805*9c5db199SXin Li        Like query_objects, but retreive only the count of results.
806*9c5db199SXin Li        """
807*9c5db199SXin Li        filter_data.pop('query_start', None)
808*9c5db199SXin Li        filter_data.pop('query_limit', None)
809*9c5db199SXin Li        query = cls.query_objects(filter_data, initial_query=initial_query)
810*9c5db199SXin Li        return query.count()
811*9c5db199SXin Li
812*9c5db199SXin Li
813*9c5db199SXin Li    @classmethod
814*9c5db199SXin Li    def clean_object_dicts(cls, field_dicts):
815*9c5db199SXin Li        """\
816*9c5db199SXin Li        Take a list of dicts corresponding to object (as returned by
817*9c5db199SXin Li        query.values()) and clean the data to be more suitable for
818*9c5db199SXin Li        returning to the user.
819*9c5db199SXin Li        """
820*9c5db199SXin Li        for field_dict in field_dicts:
821*9c5db199SXin Li            cls.clean_foreign_keys(field_dict)
822*9c5db199SXin Li            cls._convert_booleans(field_dict)
823*9c5db199SXin Li            cls.convert_human_readable_values(field_dict,
824*9c5db199SXin Li                                              to_human_readable=True)
825*9c5db199SXin Li
826*9c5db199SXin Li
827*9c5db199SXin Li    @classmethod
828*9c5db199SXin Li    def list_objects(cls, filter_data, initial_query=None):
829*9c5db199SXin Li        """\
830*9c5db199SXin Li        Like query_objects, but return a list of dictionaries.
831*9c5db199SXin Li        """
832*9c5db199SXin Li        query = cls.query_objects(filter_data, initial_query=initial_query)
833*9c5db199SXin Li        extra_fields = query.query.extra_select.keys()
834*9c5db199SXin Li        field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
835*9c5db199SXin Li                       for model_object in query]
836*9c5db199SXin Li        return field_dicts
837*9c5db199SXin Li
838*9c5db199SXin Li
839*9c5db199SXin Li    @classmethod
840*9c5db199SXin Li    def smart_get(cls, id_or_name, valid_only=True):
841*9c5db199SXin Li        """\
842*9c5db199SXin Li        smart_get(integer) -> get object by ID
843*9c5db199SXin Li        smart_get(string) -> get object by name_field
844*9c5db199SXin Li        """
845*9c5db199SXin Li        if valid_only:
846*9c5db199SXin Li            manager = cls.get_valid_manager()
847*9c5db199SXin Li        else:
848*9c5db199SXin Li            manager = cls.objects
849*9c5db199SXin Li
850*9c5db199SXin Li        if isinstance(id_or_name, six.integer_types):
851*9c5db199SXin Li            return manager.get(pk=id_or_name)
852*9c5db199SXin Li        if isinstance(id_or_name, six.string_types) and hasattr(
853*9c5db199SXin Li                cls, 'name_field'):
854*9c5db199SXin Li            return manager.get(**{cls.name_field : id_or_name})
855*9c5db199SXin Li        raise ValueError(
856*9c5db199SXin Li            'Invalid positional argument: %s (%s)' % (id_or_name,
857*9c5db199SXin Li                                                      type(id_or_name)))
858*9c5db199SXin Li
859*9c5db199SXin Li
860*9c5db199SXin Li    @classmethod
861*9c5db199SXin Li    def smart_get_bulk(cls, id_or_name_list):
862*9c5db199SXin Li        """Like smart_get, but for a list of ids or names"""
863*9c5db199SXin Li        invalid_inputs = []
864*9c5db199SXin Li        result_objects = []
865*9c5db199SXin Li        for id_or_name in id_or_name_list:
866*9c5db199SXin Li            try:
867*9c5db199SXin Li                result_objects.append(cls.smart_get(id_or_name))
868*9c5db199SXin Li            except cls.DoesNotExist:
869*9c5db199SXin Li                invalid_inputs.append(id_or_name)
870*9c5db199SXin Li        if invalid_inputs:
871*9c5db199SXin Li            raise cls.DoesNotExist('The following %ss do not exist: %s'
872*9c5db199SXin Li                                   % (cls.__name__.lower(),
873*9c5db199SXin Li                                      ', '.join(invalid_inputs)))
874*9c5db199SXin Li        return result_objects
875*9c5db199SXin Li
876*9c5db199SXin Li
877*9c5db199SXin Li    def get_object_dict(self, extra_fields=None):
878*9c5db199SXin Li        """\
879*9c5db199SXin Li        Return a dictionary mapping fields to this object's values.  @param
880*9c5db199SXin Li        extra_fields: list of extra attribute names to include, in addition to
881*9c5db199SXin Li        the fields defined on this object.
882*9c5db199SXin Li        """
883*9c5db199SXin Li        fields = self.get_field_dict().keys()
884*9c5db199SXin Li        if extra_fields:
885*9c5db199SXin Li            fields += extra_fields
886*9c5db199SXin Li        object_dict = dict((field_name, getattr(self, field_name))
887*9c5db199SXin Li                           for field_name in fields)
888*9c5db199SXin Li        self.clean_object_dicts([object_dict])
889*9c5db199SXin Li        self._postprocess_object_dict(object_dict)
890*9c5db199SXin Li        return object_dict
891*9c5db199SXin Li
892*9c5db199SXin Li
893*9c5db199SXin Li    def _postprocess_object_dict(self, object_dict):
894*9c5db199SXin Li        """For subclasses to override."""
895*9c5db199SXin Li        pass
896*9c5db199SXin Li
897*9c5db199SXin Li
898*9c5db199SXin Li    @classmethod
899*9c5db199SXin Li    def get_valid_manager(cls):
900*9c5db199SXin Li        return cls.objects
901*9c5db199SXin Li
902*9c5db199SXin Li
903*9c5db199SXin Li    def _record_attributes(self, attributes):
904*9c5db199SXin Li        """
905*9c5db199SXin Li        See on_attribute_changed.
906*9c5db199SXin Li        """
907*9c5db199SXin Li        assert not isinstance(attributes, six.string_types)
908*9c5db199SXin Li        self._recorded_attributes = dict((attribute, getattr(self, attribute))
909*9c5db199SXin Li                                         for attribute in attributes)
910*9c5db199SXin Li
911*9c5db199SXin Li
912*9c5db199SXin Li    def _check_for_updated_attributes(self):
913*9c5db199SXin Li        """
914*9c5db199SXin Li        See on_attribute_changed.
915*9c5db199SXin Li        """
916*9c5db199SXin Li        for attribute, original_value in six.iteritems(
917*9c5db199SXin Li                self._recorded_attributes):
918*9c5db199SXin Li            new_value = getattr(self, attribute)
919*9c5db199SXin Li            if original_value != new_value:
920*9c5db199SXin Li                self.on_attribute_changed(attribute, original_value)
921*9c5db199SXin Li        self._record_attributes(self._recorded_attributes.keys())
922*9c5db199SXin Li
923*9c5db199SXin Li
924*9c5db199SXin Li    def on_attribute_changed(self, attribute, old_value):
925*9c5db199SXin Li        """
926*9c5db199SXin Li        Called whenever an attribute is updated.  To be overridden.
927*9c5db199SXin Li
928*9c5db199SXin Li        To use this method, you must:
929*9c5db199SXin Li        * call _record_attributes() from __init__() (after making the super
930*9c5db199SXin Li        call) with a list of attributes for which you want to be notified upon
931*9c5db199SXin Li        change.
932*9c5db199SXin Li        * call _check_for_updated_attributes() from save().
933*9c5db199SXin Li        """
934*9c5db199SXin Li        pass
935*9c5db199SXin Li
936*9c5db199SXin Li
937*9c5db199SXin Li    def serialize(self, include_dependencies=True):
938*9c5db199SXin Li        """Serializes the object with dependencies.
939*9c5db199SXin Li
940*9c5db199SXin Li        The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies
941*9c5db199SXin Li        this function will serialize with the object.
942*9c5db199SXin Li
943*9c5db199SXin Li        @param include_dependencies: Whether or not to follow relations to
944*9c5db199SXin Li                                     objects this object depends on.
945*9c5db199SXin Li                                     This parameter is used when uploading
946*9c5db199SXin Li                                     jobs from a shard to the main, as the
947*9c5db199SXin Li                                     main already has all the dependent
948*9c5db199SXin Li                                     objects.
949*9c5db199SXin Li
950*9c5db199SXin Li        @returns: Dictionary representation of the object.
951*9c5db199SXin Li        """
952*9c5db199SXin Li        serialized = {}
953*9c5db199SXin Li        for field in self._meta.concrete_model._meta.local_fields:
954*9c5db199SXin Li            if field.rel is None:
955*9c5db199SXin Li                serialized[field.name] = field._get_val_from_obj(self)
956*9c5db199SXin Li            elif field.name in self.SERIALIZATION_LINKS_TO_KEEP:
957*9c5db199SXin Li                # attname will contain "_id" suffix for foreign keys,
958*9c5db199SXin Li                # e.g. HostAttribute.host will be serialized as 'host_id'.
959*9c5db199SXin Li                # Use it for easy deserialization.
960*9c5db199SXin Li                serialized[field.attname] = field._get_val_from_obj(self)
961*9c5db199SXin Li
962*9c5db199SXin Li        if include_dependencies:
963*9c5db199SXin Li            for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
964*9c5db199SXin Li                serialized[link] = self._serialize_relation(link)
965*9c5db199SXin Li
966*9c5db199SXin Li        return serialized
967*9c5db199SXin Li
968*9c5db199SXin Li
969*9c5db199SXin Li    def _serialize_relation(self, link):
970*9c5db199SXin Li        """Serializes dependent objects given the name of the relation.
971*9c5db199SXin Li
972*9c5db199SXin Li        @param link: Name of the relation to take objects from.
973*9c5db199SXin Li
974*9c5db199SXin Li        @returns For To-Many relationships a list of the serialized related
975*9c5db199SXin Li            objects, for To-One relationships the serialized related object.
976*9c5db199SXin Li        """
977*9c5db199SXin Li        try:
978*9c5db199SXin Li            attr = getattr(self, link)
979*9c5db199SXin Li        except AttributeError:
980*9c5db199SXin Li            # One-To-One relationships that point to None may raise this
981*9c5db199SXin Li            return None
982*9c5db199SXin Li
983*9c5db199SXin Li        if attr is None:
984*9c5db199SXin Li            return None
985*9c5db199SXin Li        if hasattr(attr, 'all'):
986*9c5db199SXin Li            return [obj.serialize() for obj in attr.all()]
987*9c5db199SXin Li        return attr.serialize()
988*9c5db199SXin Li
989*9c5db199SXin Li
990*9c5db199SXin Li    @classmethod
991*9c5db199SXin Li    def _split_local_from_foreign_values(cls, data):
992*9c5db199SXin Li        """This splits local from foreign values in a serialized object.
993*9c5db199SXin Li
994*9c5db199SXin Li        @param data: The serialized object.
995*9c5db199SXin Li
996*9c5db199SXin Li        @returns A tuple of two lists, both containing tuples in the form
997*9c5db199SXin Li                 (link_name, link_value). The first list contains all links
998*9c5db199SXin Li                 for local fields, the second one contains those for foreign
999*9c5db199SXin Li                 fields/objects.
1000*9c5db199SXin Li        """
1001*9c5db199SXin Li        links_to_local_values, links_to_related_values = [], []
1002*9c5db199SXin Li        for link, value in six.iteritems(data):
1003*9c5db199SXin Li            if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
1004*9c5db199SXin Li                # It's a foreign key
1005*9c5db199SXin Li                links_to_related_values.append((link, value))
1006*9c5db199SXin Li            else:
1007*9c5db199SXin Li                # It's a local attribute or a foreign key
1008*9c5db199SXin Li                # we don't want to follow.
1009*9c5db199SXin Li                links_to_local_values.append((link, value))
1010*9c5db199SXin Li        return links_to_local_values, links_to_related_values
1011*9c5db199SXin Li
1012*9c5db199SXin Li
1013*9c5db199SXin Li    @classmethod
1014*9c5db199SXin Li    def _filter_update_allowed_fields(cls, data):
1015*9c5db199SXin Li        """Filters data and returns only files that updates are allowed on.
1016*9c5db199SXin Li
1017*9c5db199SXin Li        This is i.e. needed for syncing aborted bits from the main to shards.
1018*9c5db199SXin Li
1019*9c5db199SXin Li        Local links are only allowed to be updated, if they are in
1020*9c5db199SXin Li        SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1021*9c5db199SXin Li        Overwriting existing values is allowed in order to be able to sync i.e.
1022*9c5db199SXin Li        the aborted bit from the main to a shard.
1023*9c5db199SXin Li
1024*9c5db199SXin Li        The allowlisting mechanism is in place to prevent overwriting local
1025*9c5db199SXin Li        status: If all fields were overwritten, jobs would be completely be
1026*9c5db199SXin Li        set back to their original (unstarted) state.
1027*9c5db199SXin Li
1028*9c5db199SXin Li        @param data: List with tuples of the form (link_name, link_value), as
1029*9c5db199SXin Li                     returned by _split_local_from_foreign_values.
1030*9c5db199SXin Li
1031*9c5db199SXin Li        @returns List of the same format as data, but only containing data for
1032*9c5db199SXin Li                 fields that updates are allowed on.
1033*9c5db199SXin Li        """
1034*9c5db199SXin Li        return [pair for pair in data
1035*9c5db199SXin Li                if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1036*9c5db199SXin Li
1037*9c5db199SXin Li
1038*9c5db199SXin Li    @classmethod
1039*9c5db199SXin Li    def delete_matching_record(cls, **filter_args):
1040*9c5db199SXin Li        """Delete records matching the filter.
1041*9c5db199SXin Li
1042*9c5db199SXin Li        @param filter_args: Arguments for the django filter
1043*9c5db199SXin Li                used to locate the record to delete.
1044*9c5db199SXin Li        """
1045*9c5db199SXin Li        try:
1046*9c5db199SXin Li            existing_record = cls.objects.get(**filter_args)
1047*9c5db199SXin Li        except cls.DoesNotExist:
1048*9c5db199SXin Li            return
1049*9c5db199SXin Li        existing_record.delete()
1050*9c5db199SXin Li
1051*9c5db199SXin Li
1052*9c5db199SXin Li    def _deserialize_local(self, data):
1053*9c5db199SXin Li        """Set local attributes from a list of tuples.
1054*9c5db199SXin Li
1055*9c5db199SXin Li        @param data: List of tuples like returned by
1056*9c5db199SXin Li                     _split_local_from_foreign_values.
1057*9c5db199SXin Li        """
1058*9c5db199SXin Li        if not data:
1059*9c5db199SXin Li            return
1060*9c5db199SXin Li
1061*9c5db199SXin Li        for link, value in data:
1062*9c5db199SXin Li            setattr(self, link, value)
1063*9c5db199SXin Li        # Overwridden save() methods are prone to errors, so don't execute them.
1064*9c5db199SXin Li        # This is because:
1065*9c5db199SXin Li        # - the overwritten methods depend on ACL groups that don't yet exist
1066*9c5db199SXin Li        #   and don't handle errors
1067*9c5db199SXin Li        # - the overwritten methods think this object already exists in the db
1068*9c5db199SXin Li        #   because the id is already set
1069*9c5db199SXin Li        super(type(self), self).save()
1070*9c5db199SXin Li
1071*9c5db199SXin Li
1072*9c5db199SXin Li    def _deserialize_relations(self, data):
1073*9c5db199SXin Li        """Set foreign attributes from a list of tuples.
1074*9c5db199SXin Li
1075*9c5db199SXin Li        This deserialized the related objects using their own deserialize()
1076*9c5db199SXin Li        function and then sets the relation.
1077*9c5db199SXin Li
1078*9c5db199SXin Li        @param data: List of tuples like returned by
1079*9c5db199SXin Li                     _split_local_from_foreign_values.
1080*9c5db199SXin Li        """
1081*9c5db199SXin Li        for link, value in data:
1082*9c5db199SXin Li            self._deserialize_relation(link, value)
1083*9c5db199SXin Li        # See comment in _deserialize_local
1084*9c5db199SXin Li        super(type(self), self).save()
1085*9c5db199SXin Li
1086*9c5db199SXin Li
1087*9c5db199SXin Li    @classmethod
1088*9c5db199SXin Li    def get_record(cls, data):
1089*9c5db199SXin Li        """Retrieve a record with the data in the given input arg.
1090*9c5db199SXin Li
1091*9c5db199SXin Li        @param data: A dictionary containing the information to use in a query
1092*9c5db199SXin Li                for data. If child models have different constraints of
1093*9c5db199SXin Li                uniqueness they should override this model.
1094*9c5db199SXin Li
1095*9c5db199SXin Li        @return: An object with matching data.
1096*9c5db199SXin Li
1097*9c5db199SXin Li        @raises DoesNotExist: If a record with the given data doesn't exist.
1098*9c5db199SXin Li        """
1099*9c5db199SXin Li        return cls.objects.get(id=data['id'])
1100*9c5db199SXin Li
1101*9c5db199SXin Li
1102*9c5db199SXin Li    @classmethod
1103*9c5db199SXin Li    def deserialize(cls, data):
1104*9c5db199SXin Li        """Recursively deserializes and saves an object with it's dependencies.
1105*9c5db199SXin Li
1106*9c5db199SXin Li        This takes the result of the serialize method and creates objects
1107*9c5db199SXin Li        in the database that are just like the original.
1108*9c5db199SXin Li
1109*9c5db199SXin Li        If an object of the same type with the same id already exists, it's
1110*9c5db199SXin Li        local values will be left untouched, unless they are explicitly
1111*9c5db199SXin Li        allowlisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1112*9c5db199SXin Li
1113*9c5db199SXin Li        Deserialize will always recursively propagate to all related objects
1114*9c5db199SXin Li        present in data though.
1115*9c5db199SXin Li        I.e. this is necessary to add users to an already existing acl-group.
1116*9c5db199SXin Li
1117*9c5db199SXin Li        @param data: Representation of an object and its dependencies, as
1118*9c5db199SXin Li                     returned by serialize.
1119*9c5db199SXin Li
1120*9c5db199SXin Li        @returns: The object represented by data if it didn't exist before,
1121*9c5db199SXin Li                  otherwise the object that existed before and has the same type
1122*9c5db199SXin Li                  and id as the one described by data.
1123*9c5db199SXin Li        """
1124*9c5db199SXin Li        if data is None:
1125*9c5db199SXin Li            return None
1126*9c5db199SXin Li
1127*9c5db199SXin Li        local, related = cls._split_local_from_foreign_values(data)
1128*9c5db199SXin Li        try:
1129*9c5db199SXin Li            instance = cls.get_record(data)
1130*9c5db199SXin Li            local = cls._filter_update_allowed_fields(local)
1131*9c5db199SXin Li        except cls.DoesNotExist:
1132*9c5db199SXin Li            instance = cls()
1133*9c5db199SXin Li
1134*9c5db199SXin Li        instance._deserialize_local(local)
1135*9c5db199SXin Li        instance._deserialize_relations(related)
1136*9c5db199SXin Li
1137*9c5db199SXin Li        return instance
1138*9c5db199SXin Li
1139*9c5db199SXin Li
1140*9c5db199SXin Li    def _check_update_from_shard(self, shard, updated_serialized,
1141*9c5db199SXin Li                                       *args, **kwargs):
1142*9c5db199SXin Li        """Check if an update sent from a shard is legitimate.
1143*9c5db199SXin Li
1144*9c5db199SXin Li        @raises error.UnallowedRecordsSentToMain if an update is not
1145*9c5db199SXin Li                legitimate.
1146*9c5db199SXin Li        """
1147*9c5db199SXin Li        raise NotImplementedError(
1148*9c5db199SXin Li            '_check_update_from_shard must be implemented by subclass %s '
1149*9c5db199SXin Li            'for type %s' % type(self))
1150*9c5db199SXin Li
1151*9c5db199SXin Li
1152*9c5db199SXin Li    @transaction.commit_on_success
1153*9c5db199SXin Li    def update_from_serialized(self, serialized):
1154*9c5db199SXin Li        """Updates local fields of an existing object from a serialized form.
1155*9c5db199SXin Li
1156*9c5db199SXin Li        This is different than the normal deserialize() in the way that it
1157*9c5db199SXin Li        does update local values, which deserialize doesn't, but doesn't
1158*9c5db199SXin Li        recursively propagate to related objects, which deserialize() does.
1159*9c5db199SXin Li
1160*9c5db199SXin Li        The use case of this function is to update job records on the main
1161*9c5db199SXin Li        after the jobs have been executed on a shard, as the main is not
1162*9c5db199SXin Li        interested in updates for users, labels, specialtasks, etc.
1163*9c5db199SXin Li
1164*9c5db199SXin Li        @param serialized: Representation of an object and its dependencies, as
1165*9c5db199SXin Li                           returned by serialize.
1166*9c5db199SXin Li
1167*9c5db199SXin Li        @raises ValueError: if serialized contains related objects, i.e. not
1168*9c5db199SXin Li                            only local fields.
1169*9c5db199SXin Li        """
1170*9c5db199SXin Li        local, related = (
1171*9c5db199SXin Li            self._split_local_from_foreign_values(serialized))
1172*9c5db199SXin Li        if related:
1173*9c5db199SXin Li            raise ValueError('Serialized must not contain foreign '
1174*9c5db199SXin Li                             'objects: %s' % related)
1175*9c5db199SXin Li
1176*9c5db199SXin Li        self._deserialize_local(local)
1177*9c5db199SXin Li
1178*9c5db199SXin Li
1179*9c5db199SXin Li    def custom_deserialize_relation(self, link, data):
1180*9c5db199SXin Li        """Allows overriding the deserialization behaviour by subclasses."""
1181*9c5db199SXin Li        raise NotImplementedError(
1182*9c5db199SXin Li            'custom_deserialize_relation must be implemented by subclass %s '
1183*9c5db199SXin Li            'for relation %s' % (type(self), link))
1184*9c5db199SXin Li
1185*9c5db199SXin Li
1186*9c5db199SXin Li    def _deserialize_relation(self, link, data):
1187*9c5db199SXin Li        """Deserializes related objects and sets references on this object.
1188*9c5db199SXin Li
1189*9c5db199SXin Li        Relations that point to a list of objects are handled automatically.
1190*9c5db199SXin Li        For many-to-one or one-to-one relations custom_deserialize_relation
1191*9c5db199SXin Li        must be overridden by the subclass.
1192*9c5db199SXin Li
1193*9c5db199SXin Li        Related objects are deserialized using their deserialize() method.
1194*9c5db199SXin Li        Thereby they and their dependencies are created if they don't exist
1195*9c5db199SXin Li        and saved to the database.
1196*9c5db199SXin Li
1197*9c5db199SXin Li        @param link: Name of the relation.
1198*9c5db199SXin Li        @param data: Serialized representation of the related object(s).
1199*9c5db199SXin Li                     This means a list of dictionaries for to-many relations,
1200*9c5db199SXin Li                     just a dictionary for to-one relations.
1201*9c5db199SXin Li        """
1202*9c5db199SXin Li        field = getattr(self, link)
1203*9c5db199SXin Li
1204*9c5db199SXin Li        if field and hasattr(field, 'all'):
1205*9c5db199SXin Li            self._deserialize_2m_relation(link, data, field.model)
1206*9c5db199SXin Li        else:
1207*9c5db199SXin Li            self.custom_deserialize_relation(link, data)
1208*9c5db199SXin Li
1209*9c5db199SXin Li
1210*9c5db199SXin Li    def _deserialize_2m_relation(self, link, data, related_class):
1211*9c5db199SXin Li        """Deserialize related objects for one to-many relationship.
1212*9c5db199SXin Li
1213*9c5db199SXin Li        @param link: Name of the relation.
1214*9c5db199SXin Li        @param data: Serialized representation of the related objects.
1215*9c5db199SXin Li                     This is a list with of dictionaries.
1216*9c5db199SXin Li        @param related_class: A class representing a django model, with which
1217*9c5db199SXin Li                              this class has a one-to-many relationship.
1218*9c5db199SXin Li        """
1219*9c5db199SXin Li        relation_set = getattr(self, link)
1220*9c5db199SXin Li        if related_class == self.get_attribute_model():
1221*9c5db199SXin Li            # When deserializing a model together with
1222*9c5db199SXin Li            # its attributes, clear all the exising attributes to ensure
1223*9c5db199SXin Li            # db consistency. Note 'update' won't be sufficient, as we also
1224*9c5db199SXin Li            # want to remove any attributes that no longer exist in |data|.
1225*9c5db199SXin Li            #
1226*9c5db199SXin Li            # core_filters is a dictionary of filters, defines how
1227*9c5db199SXin Li            # RelatedMangager would query for the 1-to-many relationship. E.g.
1228*9c5db199SXin Li            # Host.objects.get(
1229*9c5db199SXin Li            #     id=20).hostattribute_set.core_filters = {host_id:20}
1230*9c5db199SXin Li            # We use it to delete objects related to the current object.
1231*9c5db199SXin Li            related_class.objects.filter(**relation_set.core_filters).delete()
1232*9c5db199SXin Li        for serialized in data:
1233*9c5db199SXin Li            relation_set.add(related_class.deserialize(serialized))
1234*9c5db199SXin Li
1235*9c5db199SXin Li
1236*9c5db199SXin Li    @classmethod
1237*9c5db199SXin Li    def get_attribute_model(cls):
1238*9c5db199SXin Li        """Return the attribute model.
1239*9c5db199SXin Li
1240*9c5db199SXin Li        Subclass with attribute-like model should override this to
1241*9c5db199SXin Li        return the attribute model class. This method will be
1242*9c5db199SXin Li        called by _deserialize_2m_relation to determine whether
1243*9c5db199SXin Li        to clear the one-to-many relations first on deserialization of object.
1244*9c5db199SXin Li        """
1245*9c5db199SXin Li        return None
1246*9c5db199SXin Li
1247*9c5db199SXin Li
1248*9c5db199SXin Liclass ModelWithInvalid(ModelExtensions):
1249*9c5db199SXin Li    """
1250*9c5db199SXin Li    Overrides model methods save() and delete() to support invalidation in
1251*9c5db199SXin Li    place of actual deletion.  Subclasses must have a boolean "invalid"
1252*9c5db199SXin Li    field.
1253*9c5db199SXin Li    """
1254*9c5db199SXin Li
1255*9c5db199SXin Li    def save(self, *args, **kwargs):
1256*9c5db199SXin Li        """Saves the model"""
1257*9c5db199SXin Li        first_time = (self.id is None)
1258*9c5db199SXin Li        if first_time:
1259*9c5db199SXin Li            # see if this object was previously added and invalidated
1260*9c5db199SXin Li            my_name = getattr(self, self.name_field)
1261*9c5db199SXin Li            filters = {self.name_field : my_name, 'invalid' : True}
1262*9c5db199SXin Li            try:
1263*9c5db199SXin Li                old_object = self.__class__.objects.get(**filters)
1264*9c5db199SXin Li                self.resurrect_object(old_object)
1265*9c5db199SXin Li            except self.DoesNotExist:
1266*9c5db199SXin Li                # no existing object
1267*9c5db199SXin Li                pass
1268*9c5db199SXin Li
1269*9c5db199SXin Li        super(ModelWithInvalid, self).save(*args, **kwargs)
1270*9c5db199SXin Li
1271*9c5db199SXin Li
1272*9c5db199SXin Li    def resurrect_object(self, old_object):
1273*9c5db199SXin Li        """
1274*9c5db199SXin Li        Called when self is about to be saved for the first time and is actually
1275*9c5db199SXin Li        "undeleting" a previously deleted object.  Can be overridden by
1276*9c5db199SXin Li        subclasses to copy data as desired from the deleted entry (but this
1277*9c5db199SXin Li        superclass implementation must normally be called).
1278*9c5db199SXin Li        """
1279*9c5db199SXin Li        self.id = old_object.id
1280*9c5db199SXin Li
1281*9c5db199SXin Li
1282*9c5db199SXin Li    def clean_object(self):
1283*9c5db199SXin Li        """
1284*9c5db199SXin Li        This method is called when an object is marked invalid.
1285*9c5db199SXin Li        Subclasses should override this to clean up relationships that
1286*9c5db199SXin Li        should no longer exist if the object were deleted.
1287*9c5db199SXin Li        """
1288*9c5db199SXin Li        pass
1289*9c5db199SXin Li
1290*9c5db199SXin Li
1291*9c5db199SXin Li    def delete(self):
1292*9c5db199SXin Li        """Deletes the model"""
1293*9c5db199SXin Li        self.invalid = self.invalid
1294*9c5db199SXin Li        assert not self.invalid
1295*9c5db199SXin Li        self.invalid = True
1296*9c5db199SXin Li        self.save()
1297*9c5db199SXin Li        self.clean_object()
1298*9c5db199SXin Li
1299*9c5db199SXin Li
1300*9c5db199SXin Li    @classmethod
1301*9c5db199SXin Li    def get_valid_manager(cls):
1302*9c5db199SXin Li        return cls.valid_objects
1303*9c5db199SXin Li
1304*9c5db199SXin Li
1305*9c5db199SXin Li    class Manipulator(object):
1306*9c5db199SXin Li        """
1307*9c5db199SXin Li        Force default manipulators to look only at valid objects -
1308*9c5db199SXin Li        otherwise they will match against invalid objects when checking
1309*9c5db199SXin Li        uniqueness.
1310*9c5db199SXin Li        """
1311*9c5db199SXin Li        @classmethod
1312*9c5db199SXin Li        def _prepare(cls, model):
1313*9c5db199SXin Li            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
1314*9c5db199SXin Li            cls.manager = model.valid_objects
1315*9c5db199SXin Li
1316*9c5db199SXin Li
1317*9c5db199SXin Liclass ModelWithAttributes(object):
1318*9c5db199SXin Li    """
1319*9c5db199SXin Li    Mixin class for models that have an attribute model associated with them.
1320*9c5db199SXin Li    The attribute model is assumed to have its value field named "value".
1321*9c5db199SXin Li    """
1322*9c5db199SXin Li
1323*9c5db199SXin Li    def _get_attribute_model_and_args(self, attribute):
1324*9c5db199SXin Li        """
1325*9c5db199SXin Li        Subclasses should override this to return a tuple (attribute_model,
1326*9c5db199SXin Li        keyword_args), where attribute_model is a model class and keyword_args
1327*9c5db199SXin Li        is a dict of args to pass to attribute_model.objects.get() to get an
1328*9c5db199SXin Li        instance of the given attribute on this object.
1329*9c5db199SXin Li        """
1330*9c5db199SXin Li        raise NotImplementedError
1331*9c5db199SXin Li
1332*9c5db199SXin Li
1333*9c5db199SXin Li    def _is_replaced_by_static_attribute(self, attribute):
1334*9c5db199SXin Li        """
1335*9c5db199SXin Li        Subclasses could override this to indicate whether it has static
1336*9c5db199SXin Li        attributes.
1337*9c5db199SXin Li        """
1338*9c5db199SXin Li        return False
1339*9c5db199SXin Li
1340*9c5db199SXin Li
1341*9c5db199SXin Li    def set_attribute(self, attribute, value):
1342*9c5db199SXin Li        if self._is_replaced_by_static_attribute(attribute):
1343*9c5db199SXin Li            raise error.UnmodifiableAttributeException(
1344*9c5db199SXin Li                    'Failed to set attribute "%s" for host "%s" since it '
1345*9c5db199SXin Li                    'is static. Use go/chromeos-skylab-inventory-tools to '
1346*9c5db199SXin Li                    'modify this attribute.' % (attribute, self.hostname))
1347*9c5db199SXin Li
1348*9c5db199SXin Li        attribute_model, get_args = self._get_attribute_model_and_args(
1349*9c5db199SXin Li            attribute)
1350*9c5db199SXin Li        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
1351*9c5db199SXin Li        attribute_object.value = value
1352*9c5db199SXin Li        attribute_object.save()
1353*9c5db199SXin Li
1354*9c5db199SXin Li
1355*9c5db199SXin Li    def delete_attribute(self, attribute):
1356*9c5db199SXin Li        """Deletes an attribute"""
1357*9c5db199SXin Li        if self._is_replaced_by_static_attribute(attribute):
1358*9c5db199SXin Li            raise error.UnmodifiableAttributeException(
1359*9c5db199SXin Li                    'Failed to delete attribute "%s" for host "%s" since it '
1360*9c5db199SXin Li                    'is static. Use go/chromeos-skylab-inventory-tools to '
1361*9c5db199SXin Li                    'modify this attribute.' % (attribute, self.hostname))
1362*9c5db199SXin Li
1363*9c5db199SXin Li        attribute_model, get_args = self._get_attribute_model_and_args(
1364*9c5db199SXin Li            attribute)
1365*9c5db199SXin Li        try:
1366*9c5db199SXin Li            attribute_model.objects.get(**get_args).delete()
1367*9c5db199SXin Li        except attribute_model.DoesNotExist:
1368*9c5db199SXin Li            pass
1369*9c5db199SXin Li
1370*9c5db199SXin Li
1371*9c5db199SXin Li    def set_or_delete_attribute(self, attribute, value):
1372*9c5db199SXin Li        if value is None:
1373*9c5db199SXin Li            self.delete_attribute(attribute)
1374*9c5db199SXin Li        else:
1375*9c5db199SXin Li            self.set_attribute(attribute, value)
1376*9c5db199SXin Li
1377*9c5db199SXin Li
1378*9c5db199SXin Liclass ModelWithHashManager(dbmodels.Manager):
1379*9c5db199SXin Li    """Manager for use with the ModelWithHash abstract model class"""
1380*9c5db199SXin Li
1381*9c5db199SXin Li    def create(self, **kwargs):
1382*9c5db199SXin Li        """Always raises exception."""
1383*9c5db199SXin Li        raise Exception('ModelWithHash manager should use get_or_create() '
1384*9c5db199SXin Li                        'instead of create()')
1385*9c5db199SXin Li
1386*9c5db199SXin Li
1387*9c5db199SXin Li    def get_or_create(self, **kwargs):
1388*9c5db199SXin Li        kwargs['the_hash'] = self.model._compute_hash(**kwargs)
1389*9c5db199SXin Li        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1390*9c5db199SXin Li
1391*9c5db199SXin Li
1392*9c5db199SXin Liclass ModelWithHash(dbmodels.Model):
1393*9c5db199SXin Li    """Superclass with methods for dealing with a hash column"""
1394*9c5db199SXin Li
1395*9c5db199SXin Li    the_hash = dbmodels.CharField(max_length=40, unique=True)
1396*9c5db199SXin Li
1397*9c5db199SXin Li    objects = ModelWithHashManager()
1398*9c5db199SXin Li
1399*9c5db199SXin Li    class Meta:
1400*9c5db199SXin Li        """Overrides dbmodels.Model.Meta."""
1401*9c5db199SXin Li        abstract = True
1402*9c5db199SXin Li
1403*9c5db199SXin Li
1404*9c5db199SXin Li    @classmethod
1405*9c5db199SXin Li    def _compute_hash(cls, **kwargs):
1406*9c5db199SXin Li        raise NotImplementedError('Subclasses must override _compute_hash()')
1407*9c5db199SXin Li
1408*9c5db199SXin Li
1409*9c5db199SXin Li    def save(self, force_insert=False, **kwargs):
1410*9c5db199SXin Li        """Prevents saving the model in most cases
1411*9c5db199SXin Li
1412*9c5db199SXin Li        We want these models to be immutable, so the generic save() operation
1413*9c5db199SXin Li        will not work. These models should be instantiated through their the
1414*9c5db199SXin Li        model.objects.get_or_create() method instead.
1415*9c5db199SXin Li
1416*9c5db199SXin Li        The exception is that save(force_insert=True) will be allowed, since
1417*9c5db199SXin Li        that creates a new row. However, the preferred way to make instances of
1418*9c5db199SXin Li        these models is through the get_or_create() method.
1419*9c5db199SXin Li        """
1420*9c5db199SXin Li        if not force_insert:
1421*9c5db199SXin Li            # Allow a forced insert to happen; if it's a duplicate, the unique
1422*9c5db199SXin Li            # constraint will catch it later anyways
1423*9c5db199SXin Li            raise Exception('ModelWithHash is immutable')
1424*9c5db199SXin Li        super(ModelWithHash, self).save(force_insert=force_insert, **kwargs)
1425