xref: /aosp_15_r20/external/autotest/client/cros/cellular/mbim_compliance/mbim_message.py (revision 9c5db1993ded3edbeafc8092d69fe5de2ee02df7)
1# Lint as: python2, python3
2# Copyright 2015 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5"""
6All of the MBIM messages are created using the MBIMControlMessageMeta metaclass.
7The metaclass supports a hierarchy of message definitions so that each message
8definition extends the structure of the base class it inherits.
9
10(mbim_message.py)
11MBIMControlMessage|         (mbim_message_request.py)
12                  |>MBIMControlMessageRequest |
13                  |                           |>MBIMOpen
14                  |                           |>MBIMClose
15                  |                           |>MBIMCommand    |
16                  |                           |                |>MBIMSetConnect
17                  |                           |                |>...
18                  |                           |
19                  |                           |>MBIMHostError
20                  |
21                  |         (mbim_message_response.py)
22                  |>MBIMControlMessageResponse|
23                                              |>MBIMOpenDone
24                                              |>MBIMCloseDone
25                                              |>MBIMCommandDone|
26                                              |                |>MBIMConnectInfo
27                                              |                |>...
28                                              |
29                                              |>MBIMHostError
30"""
31from __future__ import absolute_import
32from __future__ import division
33from __future__ import print_function
34
35import array
36import logging
37import struct
38from collections import namedtuple
39
40import six
41
42from six.moves import map
43from six.moves import zip
44
45from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
46
47
# Kind of a message class. The value of every field of a message is stored as
# an attribute of the created object.
#   MESSAGE_TYPE_REQUEST (1):  built from field name/value pairs supplied by
#                              the caller.
#   MESSAGE_TYPE_RESPONSE (2): built from |raw_data|, which is parsed into the
#                              attributes of the object.
MESSAGE_TYPE_REQUEST, MESSAGE_TYPE_RESPONSE = 1, 2

# Message field types, numbered consecutively from 1.
(
    # 1: Just a normal field. No special properties.
    FIELD_TYPE_NORMAL,
    # 2: Identifies the payload ID of a message. Used while parsing response
    #    messages to help identify the child message class.
    FIELD_TYPE_PAYLOAD_ID,
    # 3: Total length of the message, including any payload_buffer it may
    #    contain.
    FIELD_TYPE_TOTAL_LEN,
    # 4: Length of the payload contained in the payload_buffer.
    FIELD_TYPE_PAYLOAD_LEN,
    # 5: Number of fragments of this message.
    FIELD_TYPE_NUM_FRAGMENTS,
    # 6: Transaction ID of this message.
    FIELD_TYPE_TRANSACTION_ID,
) = range(1, 7)
70
71
def _message_from_raw_data(cls, raw_data):
    """
    Builds a |cls| message by unpacking a packed binary buffer.

    The leading bytes of |raw_data| are unpacked according to the complete
    field format of |cls|. Any bytes beyond the fixed structure are kept
    as-is in the |payload_buffer| attribute of the new object.

    @param cls: Message class to instantiate.
    @param raw_data: Non-empty packed binary data.
    @returns New message object created.

    """
    data_format = cls.get_field_format_string(get_all=True)
    unpack_length = cls.get_struct_len(get_all=True)
    data_length = len(raw_data)
    if data_length < unpack_length:
        mbim_errors.log_and_raise(
                mbim_errors.MBIMComplianceControlMessageError,
                'Length of Data (%d) to be parsed less than message'
                ' structure length (%d)' %
                (data_length, unpack_length))
    obj = super(cls, cls).__new__(
            cls, *struct.unpack_from(data_format, raw_data))
    # Whatever does not fit the fixed structure is the variable sized payload.
    if data_length > unpack_length:
        obj.payload_buffer = raw_data[unpack_length:]
    else:
        obj.payload_buffer = None
    return obj


def _message_from_field_values(cls, kwargs):
    """
    Builds a |cls| message from explicit field name/value pairs.

    Fields absent from |kwargs| are filled in as follows: the total length
    field gets the structure length plus the payload length, the transaction
    id field gets the next generated id, and every other field gets its
    class default. A field left without any value is reported as an error.

    @param cls: Message class to instantiate.
    @param kwargs: Dictionary of (field_name, field_value) pairs, optionally
                    with a |payload_buffer| entry. Consumed by this call.
    @returns New message object created.

    """
    field_values = []
    defaults = cls.get_defaults(get_all=True)
    for _, field_name, field_type in cls.get_fields(get_all=True):
        if field_name in kwargs:
            field_values.append(kwargs.pop(field_name))
            continue
        if field_type == FIELD_TYPE_TOTAL_LEN:
            field_value = cls.get_struct_len(get_all=True)
            # |payload_buffer| may be passed explicitly as None to mean
            # "no payload"; that contributes nothing to the length.
            payload_buffer = kwargs.get('payload_buffer')
            if payload_buffer is not None:
                field_value += len(payload_buffer)
        elif field_type == FIELD_TYPE_TRANSACTION_ID:
            field_value = cls.get_next_transaction_id()
        else:
            field_value = defaults.get(field_name, None)
        if field_value is None:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    'Missing field value (%s) in %s' % (
                            field_name, cls.__name__))
        field_values.append(field_value)
    obj = super(cls, cls).__new__(cls, *field_values)
    # The optional variable sized payload_buffer is not listed in the
    # |cls._FIELDS| attribute, so it is attached separately.
    obj.payload_buffer = kwargs.pop('payload_buffer', None)
    # Anything still left in |kwargs| did not match a field of the message.
    if kwargs:
        mbim_errors.log_and_raise(
                mbim_errors.MBIMComplianceControlMessageError,
                'Unexpected fields (%s) in %s' % (
                        list(kwargs.keys()), cls.__name__))
    return obj


def message_class_new(cls, **kwargs):
    """
    Creates a message instance with either the given field name/value
    pairs or raw data buffer.

    A non-empty |raw_data| argument selects parsing; all other arguments are
    then ignored. Otherwise the message is assembled from the field values,
    and the total_length and transaction_id fields are automatically
    calculated if not explicitly provided in the message args.

    @param kwargs: Dictionary of (field_name, field_value) pairs or
                    raw_data=Packed binary array.
    @returns New message object created.

    """
    raw_data = kwargs.get('raw_data')
    if raw_data:
        return _message_from_raw_data(cls, raw_data)
    return _message_from_field_values(cls, kwargs)
146
147
class MBIMControlMessageMeta(type):
    """
    Metaclass driving the parsing/generation of all the control messages.

    Every class created through this metaclass gets the message fields of its
    base classes followed by its own, which yields a hierarchy of messages in
    which the payload class of a message is a subclass of that message.

    Message definition attributes->
    _FIELDS(optional): Structure elements contributed by the class itself.
                       The fields of a message are the fields of its parent
                       classes followed by this attribute.
    _DEFAULTS(optional): Field name/value pairs assigned to fields whose value
                         is fixed for a message type. Generally used to give
                         values to fields declared in a parent class.
    _IDENTIFIERS(optional): Field name/value pairs used to identify this
                            message when parsing from raw_data.
    _SECONDARY_FRAGMENT(optional): Marks the class as one that can be
                                   fragmented and names the secondary class
                                   definition.
    MESSAGE_TYPE: Used to tell request and response classes apart.

    Message internal attributes->
    _CONSOLIDATED_FIELDS: List of all the fields defining this message,
                          parents included.
    _CONSOLIDATED_DEFAULTS: Dictionary of all the default field name/value
                            pairs for this message, parents included.

    """
    def __new__(mcs, name, bases, attrs):
        make_class = super(MBIMControlMessageMeta, mcs).__new__

        # The MBIMControlMessage base class, the only one deriving straight
        # from 'object', just anchors the hierarchy. It carries no fields and
        # is never constructed on its own, so it is created untouched.
        if object in bases:
            return make_class(mcs, name, bases, attrs)

        # Start from what the parents consolidated; when several bases
        # provide it, the last one listed wins.
        all_fields = []
        all_defaults = {}
        for parent in bases:
            if hasattr(parent, '_CONSOLIDATED_FIELDS'):
                all_fields = parent._CONSOLIDATED_FIELDS
            if hasattr(parent, '_CONSOLIDATED_DEFAULTS'):
                all_defaults = parent._CONSOLIDATED_DEFAULTS.copy()

        # Then add what this class declares itself. Concatenation builds a
        # new list, leaving the parent's list unmodified.
        if '_FIELDS' in attrs:
            own_fields = [list(field) for field in attrs['_FIELDS']]
            all_fields = all_fields + own_fields
        if '_DEFAULTS' in attrs:
            all_defaults.update(attrs['_DEFAULTS'])
        attrs['_CONSOLIDATED_FIELDS'] = all_fields
        attrs['_CONSOLIDATED_DEFAULTS'] = all_defaults

        if not all_fields:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    '%s message must have some fields defined' % name)

        attrs['__new__'] = message_class_new
        # Each field is a (format, name, type) triple; only names are needed
        # to build the tuple type that stores the values.
        unused_formats, field_names, unused_types = zip(*all_fields)
        tuple_class = namedtuple(name, field_names)
        # The namedtuple class goes first in the bases so that __new__
        # resolves to it while the message hierarchy stays intact.
        return make_class(mcs, name, (tuple_class,) + bases, attrs)
218
219
class MBIMControlMessage(six.with_metaclass(MBIMControlMessageMeta, object)):
    """
    MBIMControlMessage base class.

    Root of the message hierarchy. It provides the field introspection,
    packing and copying helpers shared by all the messages.

    This class should not be instantiated or used directly.

    """
    # Last transaction id handed out; shared by every message class.
    _NEXT_TRANSACTION_ID = 0x00000000
    # Largest transaction id handed out before wrapping around. The MBIM
    # message header carries the transaction id as a 32-bit unsigned value
    # (per the MBIM specification; the concrete field format is declared by
    # the message classes outside this file).
    _MAX_TRANSACTION_ID = 0xFFFFFFFF


    @classmethod
    def _find_subclasses(cls):
        """
        Helper function to find all the derived payload classes of this
        class.

        @returns New list holding the direct subclasses of this class.

        """
        return list(cls.__subclasses__())


    @classmethod
    def get_fields(cls, get_all=False):
        """
        Helper function to find all the fields of this class.

        Returns either the total message fields or only the current
        substructure fields in the nested message.

        @param get_all: Whether to return the total struct fields or sub struct
                         fields.
        @returns Fields of the structure.

        """
        if get_all:
            return cls._CONSOLIDATED_FIELDS
        return cls._FIELDS


    @classmethod
    def get_defaults(cls, get_all=False):
        """
        Helper function to find all the default field values of this class.

        Returns either the total message default field name/value pairs or only
        the current substructure defaults in the nested message.

        @param get_all: Whether to return the total struct defaults or sub
                         struct defaults.
        @returns Defaults of the structure.

        """
        if get_all:
            return cls._CONSOLIDATED_DEFAULTS
        return cls._DEFAULTS


    @classmethod
    def _get_identifiers(cls):
        """
        Helper function to find all the identifier field name/value pairs of
        this class.

        @returns |_IDENTIFIERS| attribute of the class, or None if the class
                 does not have one.

        """
        return getattr(cls, '_IDENTIFIERS', None)


    @classmethod
    def _find_field_names_of_type(cls, find_type, get_all=False):
        """
        Helper function to find all the field names which match the field type
        specified.

        @param find_type: One of the FIELD_TYPE_* enum values specified above.
        @param get_all: Whether to search the total struct fields or sub struct
                         fields.
        @returns List of the matching field names; empty if none match.

        """
        return [field_name
                for _, field_name, field_type in cls.get_fields(get_all=get_all)
                if field_type == find_type]


    @classmethod
    def get_secondary_fragment(cls):
        """
        Helper function to retrieve the associated secondary fragment class.

        @returns |_SECONDARY_FRAGMENT| attribute of the class, or None if the
                 class does not have one.

        """
        return getattr(cls, '_SECONDARY_FRAGMENT', None)


    @classmethod
    def get_field_names(cls, get_all=True):
        """
        Helper function to return the field names of the message.

        @param get_all: Whether to use the total struct fields or sub struct
                         fields.
        @returns The field names of the message structure.

        """
        _, field_names, _ = list(zip(*cls.get_fields(get_all=get_all)))
        return field_names


    @classmethod
    def get_field_formats(cls, get_all=True):
        """
        Helper function to return the field formats of the message.

        @param get_all: Whether to use the total struct fields or sub struct
                         fields.
        @returns The format of fields of the message structure.

        """
        field_formats, _, _ = list(zip(*cls.get_fields(get_all=get_all)))
        return field_formats


    @classmethod
    def get_field_format_string(cls, get_all=True):
        """
        Helper function to return the field format string of the message.

        @param get_all: Whether to use the total struct fields or sub struct
                         fields.
        @returns The struct format string of the message structure; always
                 little-endian with standard sizes ('<' prefix).

        """
        return '<' + ''.join(cls.get_field_formats(get_all=get_all))


    @classmethod
    def get_struct_len(cls, get_all=False):
        """
        Returns the length of the structure representing the message.

        Returns the length of either the total message or only the current
        substructure in the nested message.

        @param get_all: Whether to return the total struct length or sub struct
                length.
        @returns Length of the structure.

        """
        return struct.calcsize(cls.get_field_format_string(get_all=get_all))


    @classmethod
    def find_primary_parent_fragment(cls):
        """
        Traverses up the message tree to find the primary fragment class
        at the same tree level as the secondary frag class associated with this
        message class. This should only be called on primary fragment derived
        classes!

        @returns Primary frag class associated with the message.

        """
        # The metaclass puts the generated namedtuple class at index 0 of the
        # bases, so the parent message class is found at index 1.
        secondary_frag_cls = cls.get_secondary_fragment()
        secondary_frag_parent_cls = secondary_frag_cls.__bases__[1]
        message_cls = cls
        message_parent_cls = message_cls.__bases__[1]
        # Climb until the class whose parent is also the parent of the
        # secondary fragment class is reached.
        while message_parent_cls != secondary_frag_parent_cls:
            message_cls = message_parent_cls
            message_parent_cls = message_cls.__bases__[1]
        return message_cls


    @classmethod
    def get_next_transaction_id(cls):
        """
        Returns incrementing transaction ids on successive calls.

        The ids run from 1 up to |_MAX_TRANSACTION_ID| and then start over
        from 1. The counter is kept on MBIMControlMessage itself, so it is
        shared by all the message classes.

        @returns The transaction id for control message delivery.

        """
        base = MBIMControlMessage
        if base._NEXT_TRANSACTION_ID >= base._MAX_TRANSACTION_ID:
            base._NEXT_TRANSACTION_ID = 0x00000000
        base._NEXT_TRANSACTION_ID += 1
        return base._NEXT_TRANSACTION_ID


    def _get_fields_of_type(self, field_type, get_all=False):
        """
        Helper function to find all the field name/value of the specified type
        in the given object.

        @param field_type: One of the FIELD_TYPE_* enum values specified above.
        @param get_all: Whether to search the total struct fields or sub struct
                         fields.
        @returns Corresponding map of field name/value pairs extracted from the
                object.

        """
        field_names = self.__class__._find_field_names_of_type(field_type,
                                                               get_all=get_all)
        return {f: getattr(self, f) for f in field_names}


    def _get_single_field_of_type(self, field_type, description,
                                  get_all=False):
        """
        Helper function to find the value of the one field of the specified
        type in the given object.

        Reports an error unless exactly one field of that type exists.

        @param field_type: One of the FIELD_TYPE_* enum values specified above.
        @param description: Name of the field as used in the error message.
        @param get_all: Whether to search the total struct fields or sub struct
                         fields.
        @returns Corresponding field value extracted from the object.

        """
        matching_fields = self._get_fields_of_type(field_type, get_all=get_all)
        if len(matching_fields) != 1:
            # The message text is kept verbatim, spelling included.
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceControlMessageError,
                    "Erorr in finding %s field in message: %s" %
                    (description, self.__class__.__name__))
        return list(matching_fields.values())[0]


    def _get_payload_id_fields(self):
        """
        Helper function to find all the payload id field name/value in the given
        object.

        @returns Corresponding field name/value pairs extracted from the object.

        """
        return self._get_fields_of_type(FIELD_TYPE_PAYLOAD_ID)


    def get_payload_len(self):
        """
        Helper function to find the payload len field value in the given
        object.

        @returns Corresponding field value extracted from the object.

        """
        return self._get_single_field_of_type(FIELD_TYPE_PAYLOAD_LEN,
                                              'payload len')


    def get_total_len(self):
        """
        Helper function to find the total len field value in the given
        object.

        @returns Corresponding field value extracted from the object.

        """
        return self._get_single_field_of_type(FIELD_TYPE_TOTAL_LEN,
                                              'total len', get_all=True)


    def get_num_fragments(self):
        """
        Helper function to find the fragment num field value in the given
        object.

        @returns Corresponding field value extracted from the object.

        """
        return self._get_single_field_of_type(FIELD_TYPE_NUM_FRAGMENTS,
                                              'num fragments')


    def find_payload_class(self):
        """
        Helper function to find the derived class whose |_IDENTIFIERS| match
        the payload id fields of the current message contents.

        @returns Corresponding class if found, else None.

        """
        # The ids of this message do not depend on the candidate class, so
        # they are extracted only once.
        message_ids = self._get_payload_id_fields()
        for payload_cls in self.__class__._find_subclasses():
            if message_ids == payload_cls._get_identifiers():
                return payload_cls
        return None


    def calculate_total_len(self):
        """
        Helper function to calculate the total len of a given message
        object.

        @returns Total length of the message: the complete structure length
                 plus the length of the payload buffer, if any.

        """
        total_len = self.__class__.get_struct_len(get_all=True)
        if self.payload_buffer:
            total_len += len(self.payload_buffer)
        return total_len


    def pack(self, format_string, field_names):
        """
        Packs a list of fields based on their formats.

        @param format_string: The concatenated formats for the fields given in
                |field_names|.
        @param field_names: The name of the fields to be packed.
        @returns The packet in binary array form.

        """
        field_values = [getattr(self, name) for name in field_names]
        return array.array('B', struct.pack(format_string, *field_values))


    def print_all_fields(self):
        """Logs all the field name, value pairs of this message at debug level."""
        logging.debug('Class Name: %s', self.__class__.__name__)
        for field_name in self.__class__.get_field_names(get_all=True):
            logging.debug('Field Name: %s, Field Value: %s',
                           field_name, str(getattr(self, field_name)))
        if self.payload_buffer:
            logging.debug('Payload: %s', str(self.payload_buffer))


    def create_raw_data(self):
        """
        Creates the raw binary data corresponding to the message struct.

        All the fields of the message are packed first; the variable sized
        |payload_buffer| of the message, if any, is attached at the end.

        @returns Packed byte array of the message.

        """
        message_class = self.__class__
        packet = self.pack(message_class.get_field_format_string(),
                           message_class.get_field_names())
        if self.payload_buffer:
            packet.extend(self.payload_buffer)
        return packet


    def copy(self, **fields_to_alter):
        """
        Creates a copy of the message tuple with updated field values.

        The message itself is left unchanged.

        @param fields_to_alter: Field name/value pairs to be changed.
        @returns New message with the field values updated.

        """
        message = self._replace(**fields_to_alter)
        # The payload_buffer is not a tuple field, so it has to be carried
        # over by hand. The new message refers to the same buffer object.
        message.payload_buffer = self.payload_buffer
        return message
569