1 /*
2  * Copyright 2016, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <stdbool.h>
18 #include <stdint.h>
19 #include <stdio.h>
20 #include <stdarg.h>
21 
22 #include "v7/apf_defs.h"
23 #include "v7/apf.h"
24 #include "disassembler.h"
25 
// Compile-time signedness guard. When "c" has a signed type, comparing it with
// its uint32_t conversion is a mixed-signedness comparison; the resulting
// warning is promoted to an error by the build. For a uint32_t operand the
// expression is always true. This lets bounds checks drop a ">= 0" test, which
// would itself trigger a warning when applied to an unsigned expression.
#define ENFORCE_UNSIGNED(c) ((uint32_t)(c) == (c))

// Text buffer that apf_disassemble() fills and returns; every call overwrites it.
char print_buf[1024];
// Current write position inside print_buf (advanced by bprintf()).
char* buf_ptr;
// Number of bytes still available in print_buf starting at buf_ptr.
int buf_remain;
// Set to true once a v6 "data" instruction has been decoded; it switches how
// lddw/stdw operands are printed. Not reset anywhere in the visible code.
bool v6_mode = false;
35 
36 __attribute__ ((format (printf, 1, 2) ))
bprintf(const char * format,...)37 static void bprintf(const char* format, ...) {
38     va_list args;
39     va_start(args, format);
40     int ret = vsnprintf(buf_ptr, buf_remain, format, args);
41     va_end(args);
42     if (ret < 0) return;
43     if (ret >= buf_remain) ret = buf_remain;
44     buf_ptr += ret;
45     buf_remain -= ret;
46 }
47 
// Prints an instruction mnemonic left-justified and space-padded to a
// 12-character column so that operands of successive lines align.
static void print_opcode(const char* opcode) {
    const int column_width = 12;
    bprintf("%-*s", column_width, opcode);
}
51 
52 // Mapping from opcode number to opcode name.
53 static const char* opcode_names [] = {
54     [LDB_OPCODE] = "ldb",
55     [LDH_OPCODE] = "ldh",
56     [LDW_OPCODE] = "ldw",
57     [LDBX_OPCODE] = "ldbx",
58     [LDHX_OPCODE] = "ldhx",
59     [LDWX_OPCODE] = "ldwx",
60     [ADD_OPCODE] = "add",
61     [MUL_OPCODE] = "mul",
62     [DIV_OPCODE] = "div",
63     [AND_OPCODE] = "and",
64     [OR_OPCODE] = "or",
65     [SH_OPCODE] = "sh",
66     [LI_OPCODE] = "li",
67     [JMP_OPCODE] = "jmp",
68     [JEQ_OPCODE] = "jeq",
69     [JNE_OPCODE] = "jne",
70     [JGT_OPCODE] = "jgt",
71     [JLT_OPCODE] = "jlt",
72     [JSET_OPCODE] = "jset",
73     [JBSMATCH_OPCODE] = NULL,
74     [LDDW_OPCODE] = "lddw",
75     [STDW_OPCODE] = "stdw",
76     [WRITE_OPCODE] = "write",
77     [JNSET_OPCODE] = "jnset",
78 };
79 
// Prints a jump destination. The two offsets just past the bytecode are the
// implicit terminal targets: program_len prints as PASS and program_len + 1
// prints as DROP; any other target prints as an unsigned number.
static void print_jump_target(uint32_t target, uint32_t program_len) {
    if (target == program_len) {
        bprintf("PASS");
        return;
    }
    if (target == program_len + 1) {
        bprintf("DROP");
        return;
    }
    bprintf("%u", target);
}
89 
apf_disassemble(const uint8_t * program,uint32_t program_len,uint32_t * const ptr2pc)90 const char* apf_disassemble(const uint8_t* program, uint32_t program_len, uint32_t* const ptr2pc) {
91     buf_ptr = print_buf;
92     buf_remain = sizeof(print_buf);
93     if (*ptr2pc > program_len + 1) {
94         bprintf("pc is overflow: pc %d, program_len: %d", *ptr2pc, program_len);
95         return print_buf;
96     }
97 
98     bprintf("%8u: ", *ptr2pc);
99 
100     if (*ptr2pc == program_len) {
101         bprintf("PASS");
102         ++(*ptr2pc);
103         return print_buf;
104     }
105 
106     if (*ptr2pc == program_len + 1) {
107         bprintf("DROP");
108         ++(*ptr2pc);
109         return print_buf;
110     }
111 
112     const uint8_t bytecode = program[(*ptr2pc)++];
113     const uint32_t opcode = EXTRACT_OPCODE(bytecode);
114 
115 #define PRINT_OPCODE() print_opcode(opcode_names[opcode])
116 #define DECODE_IMM(length)  ({                                        \
117     uint32_t value = 0;                                               \
118     for (uint32_t i = 0; i < (length) && *ptr2pc < program_len; i++)  \
119         value = (value << 8) | program[(*ptr2pc)++];                  \
120     value;})
121 
122     const uint32_t reg_num = EXTRACT_REGISTER(bytecode);
123     // All instructions have immediate fields, so load them now.
124     const uint32_t len_field = EXTRACT_IMM_LENGTH(bytecode);
125     uint32_t imm = 0;
126     int32_t signed_imm = 0;
127     if (len_field != 0) {
128         const uint32_t imm_len = 1 << (len_field - 1);
129         imm = DECODE_IMM(imm_len);
130         // Sign extend imm into signed_imm.
131         signed_imm = imm << ((4 - imm_len) * 8);
132         signed_imm >>= (4 - imm_len) * 8;
133     }
134     switch (opcode) {
135         case PASSDROP_OPCODE:
136             if (reg_num == 0) {
137                 print_opcode("pass");
138             } else {
139                 print_opcode("drop");
140             }
141             if (imm > 0) {
142                 bprintf("counter=%d", imm);
143             }
144             break;
145         case LDB_OPCODE:
146         case LDH_OPCODE:
147         case LDW_OPCODE:
148             PRINT_OPCODE();
149             bprintf("r%d, [%u]", reg_num, imm);
150             break;
151         case LDBX_OPCODE:
152         case LDHX_OPCODE:
153         case LDWX_OPCODE:
154             PRINT_OPCODE();
155             if (imm) {
156                 bprintf("r%d, [r1+%u]", reg_num, imm);
157             } else {
158                 bprintf("r%d, [r1]", reg_num);
159             }
160             break;
161         case JMP_OPCODE:
162             if (reg_num == 0) {
163                 PRINT_OPCODE();
164                 print_jump_target(*ptr2pc + imm, program_len);
165             } else {
166                 v6_mode = true;
167                 print_opcode("data");
168                 bprintf("%d, ", imm);
169                 uint32_t len = imm;
170                 while (len--) bprintf("%02x", program[(*ptr2pc)++]);
171             }
172             break;
173         case JEQ_OPCODE:
174         case JNE_OPCODE:
175         case JGT_OPCODE:
176         case JLT_OPCODE:
177         case JSET_OPCODE:
178         case JNSET_OPCODE: {
179             PRINT_OPCODE();
180             bprintf("r0, ");
181             // Load second immediate field.
182             if (reg_num == 1) {
183                 bprintf("r1, ");
184             } else if (len_field == 0) {
185                 bprintf("0, ");
186             } else {
187                 uint32_t cmp_imm = DECODE_IMM(1 << (len_field - 1));
188                 bprintf("0x%x, ", cmp_imm);
189             }
190             print_jump_target(*ptr2pc + imm, program_len);
191             break;
192         }
193         case JBSMATCH_OPCODE: {
194             if (reg_num == 0) {
195                 print_opcode("jbsne");
196             } else {
197                 print_opcode("jbseq");
198             }
199             bprintf("r0, ");
200             const uint32_t cmp_imm = DECODE_IMM(1 << (len_field - 1));
201             const uint32_t cnt = (cmp_imm >> 11) + 1; // 1+, up to 32 fits in u16
202             const uint32_t len = cmp_imm & 2047; // 0..2047
203             bprintf("0x%x, ", len);
204             print_jump_target(*ptr2pc + imm + cnt * len, program_len);
205             bprintf(", ");
206             if (cnt > 1) {
207                 bprintf("{ ");
208             }
209             for (uint32_t i = 0; i < cnt; ++i) {
210                 for (uint32_t j = 0; j < len; ++j) {
211                     uint8_t byte = program[(*ptr2pc)++];
212                     bprintf("%02x", byte);
213                 }
214                 if (i != cnt - 1) {
215                     bprintf(", ");
216                 }
217             }
218             if (cnt > 1) {
219                 bprintf(" }");
220             }
221             break;
222         }
223         case SH_OPCODE:
224             PRINT_OPCODE();
225             if (reg_num) {
226                 bprintf("r0, r1");
227             } else {
228                 bprintf("r0, %d", signed_imm);
229             }
230             break;
231         case ADD_OPCODE:
232         case MUL_OPCODE:
233         case DIV_OPCODE:
234         case AND_OPCODE:
235         case OR_OPCODE:
236             PRINT_OPCODE();
237             if (reg_num) {
238                 bprintf("r0, r1");
239             } else if (!imm && opcode == DIV_OPCODE) {
240                 bprintf("pass (div 0)");
241             } else {
242                 bprintf("r0, %u", imm);
243             }
244             break;
245         case LI_OPCODE:
246             PRINT_OPCODE();
247             bprintf("r%d, %d", reg_num, signed_imm);
248             break;
249         case EXT_OPCODE:
250             if (
251 // If LDM_EXT_OPCODE is 0 and imm is compared with it, a compiler error will result,
252 // instead just enforce that imm is unsigned (so it's always greater or equal to 0).
253 #if LDM_EXT_OPCODE == 0
254                 ENFORCE_UNSIGNED(imm) &&
255 #else
256                 imm >= LDM_EXT_OPCODE &&
257 #endif
258                 imm < (LDM_EXT_OPCODE + MEMORY_ITEMS)) {
259                 print_opcode("ldm");
260                 bprintf("r%d, m[%u]", reg_num, imm - LDM_EXT_OPCODE);
261             } else if (imm >= STM_EXT_OPCODE && imm < (STM_EXT_OPCODE + MEMORY_ITEMS)) {
262                 print_opcode("stm");
263                 bprintf("r%d, m[%u]", reg_num, imm - STM_EXT_OPCODE);
264             } else switch (imm) {
265                 case NOT_EXT_OPCODE:
266                     print_opcode("not");
267                     bprintf("r%d", reg_num);
268                     break;
269                 case NEG_EXT_OPCODE:
270                     print_opcode("neg");
271                     bprintf("r%d", reg_num);
272                     break;
273                 case SWAP_EXT_OPCODE:
274                     print_opcode("swap");
275                     break;
276                 case MOV_EXT_OPCODE:
277                     print_opcode("mov");
278                     bprintf("r%d, r%d", reg_num, reg_num ^ 1);
279                     break;
280                 case ALLOCATE_EXT_OPCODE:
281                     print_opcode("allocate");
282                     if (reg_num == 0) {
283                         bprintf("r%d", reg_num);
284                     } else {
285                         uint32_t alloc_len = DECODE_IMM(2);
286                         bprintf("%d", alloc_len);
287                     }
288                     break;
289                 case TRANSMIT_EXT_OPCODE:
290                     print_opcode(reg_num ? "transmitudp" : "transmit");
291                     u8 ip_ofs = DECODE_IMM(1);
292                     u8 csum_ofs = DECODE_IMM(1);
293                     if (csum_ofs < 255) {
294                         u8 csum_start = DECODE_IMM(1);
295                         u16 partial_csum = DECODE_IMM(2);
296                         bprintf("ip_ofs=%d, csum_ofs=%d, csum_start=%d, partial_csum=0x%04x",
297                                 ip_ofs, csum_ofs, csum_start, partial_csum);
298                     } else {
299                         bprintf("ip_ofs=%d", ip_ofs);
300                     }
301                     break;
302                 case EWRITE1_EXT_OPCODE: print_opcode("ewrite1"); bprintf("r%d", reg_num); break;
303                 case EWRITE2_EXT_OPCODE: print_opcode("ewrite2"); bprintf("r%d", reg_num); break;
304                 case EWRITE4_EXT_OPCODE: print_opcode("ewrite4"); bprintf("r%d", reg_num); break;
305                 case EPKTDATACOPYIMM_EXT_OPCODE:
306                 case EPKTDATACOPYR1_EXT_OPCODE: {
307                     if (reg_num == 0) {
308                         print_opcode("epktcopy");
309                     } else {
310                         print_opcode("edatacopy");
311                     }
312                     if (imm == EPKTDATACOPYIMM_EXT_OPCODE) {
313                         uint32_t len = DECODE_IMM(1);
314                         bprintf(" src=r0, len=%d", len);
315                     } else {
316                         bprintf(" src=r0, len=r1");
317                     }
318 
319                     break;
320                 }
321                 case JDNSQMATCH_EXT_OPCODE:       // 43
322                 case JDNSAMATCH_EXT_OPCODE:       // 44
323                 case JDNSQMATCHSAFE_EXT_OPCODE:   // 45
324                 case JDNSAMATCHSAFE_EXT_OPCODE: { // 46
325                     uint32_t offs = DECODE_IMM(1 << (len_field - 1));
326                     int qtype = -1;
327                     switch(imm) {
328                         case JDNSQMATCH_EXT_OPCODE:
329                             print_opcode(reg_num ? "jdnsqeq" : "jdnsqne");
330                             qtype = DECODE_IMM(1);
331                             break;
332                         case JDNSQMATCHSAFE_EXT_OPCODE:
333                             print_opcode(reg_num ? "jdnsqeqsafe" : "jdnsqnesafe");
334                             qtype = DECODE_IMM(1);
335                             break;
336                         case JDNSAMATCH_EXT_OPCODE:
337                             print_opcode(reg_num ? "jdnsaeq" : "jdnsane"); break;
338                         case JDNSAMATCHSAFE_EXT_OPCODE:
339                             print_opcode(reg_num ? "jdnsaeqsafe" : "jdnsanesafe"); break;
340                         default:
341                             bprintf("unknown_ext %u", imm); break;
342                     }
343                     bprintf("r0, ");
344                     uint32_t end = *ptr2pc;
345                     while (end + 1 < program_len && !(program[end] == 0 && program[end + 1] == 0)) {
346                         end++;
347                     }
348                     end += 2;
349                     print_jump_target(end + offs, program_len);
350                     bprintf(", ");
351                     if (imm == JDNSQMATCH_EXT_OPCODE || imm == JDNSQMATCHSAFE_EXT_OPCODE) {
352                         bprintf("%d, ", qtype);
353                     }
354                     while (*ptr2pc < end) {
355                         uint8_t byte = program[(*ptr2pc)++];
356                         // values < 0x40 could be lengths, but - and 0..9 are in practice usually
357                         // too long to be lengths so print them as characters. All other chars < 0x40
358                         // are not valid in dns character.
359                         if (byte == '-' || (byte >= '0' && byte <= '9') || byte >= 0x40) {
360                             bprintf("%c", byte);
361                         } else {
362                             bprintf("(%d)", byte);
363                         }
364                     }
365                     break;
366                 }
367                 case JONEOF_EXT_OPCODE: {
368                     const uint32_t imm_len = 1 << (len_field - 1);
369                     uint32_t jump_offs = DECODE_IMM(imm_len);
370                     uint8_t imm3 = DECODE_IMM(1);
371                     bool jmp = imm3 & 1;
372                     uint8_t len = ((imm3 >> 1) & 3) + 1;
373                     uint8_t cnt = (imm3 >> 3) + 2;
374                     if (jmp) {
375                         print_opcode("jnoneof");
376                     } else {
377                         print_opcode("joneof");
378                     }
379                     bprintf("r%d, ", reg_num);
380                     print_jump_target(*ptr2pc + jump_offs + cnt * len, program_len);
381                     bprintf(", { ");
382                     while (cnt--) {
383                         uint32_t v = DECODE_IMM(len);
384                         if (cnt) {
385                             bprintf("%d, ", v);
386                         } else {
387                             bprintf("%d ", v);
388                         }
389                     }
390                     bprintf("}");
391                     break;
392                 }
393                 case EXCEPTIONBUFFER_EXT_OPCODE: {
394                     uint32_t buf_size = DECODE_IMM(2);
395                     print_opcode("debugbuf");
396                     bprintf("size=%d", buf_size);
397                     break;
398                 }
399                 default:
400                     bprintf("unknown_ext %u", imm);
401                     break;
402             }
403             break;
404         case LDDW_OPCODE:
405         case STDW_OPCODE:
406             PRINT_OPCODE();
407             if (v6_mode) {
408                 if (opcode == LDDW_OPCODE) {
409                     bprintf("r%u, counter=%d", reg_num, imm);
410                 } else {
411                     bprintf("counter=%d, r%u", imm, reg_num);
412                 }
413             } else {
414                 if (signed_imm > 0) {
415                     bprintf("r%u, [r%u+%d]", reg_num, reg_num ^ 1, signed_imm);
416                 } else if (signed_imm < 0) {
417                     bprintf("r%u, [r%u-%d]", reg_num, reg_num ^ 1, -signed_imm);
418                 } else {
419                     bprintf("r%u, [r%u]", reg_num, reg_num ^ 1);
420                 }
421             }
422             break;
423         case WRITE_OPCODE: {
424             PRINT_OPCODE();
425             uint32_t write_len = 1 << (len_field - 1);
426             if (write_len > 0) {
427                 bprintf("0x");
428             }
429             for (uint32_t i = 0; i < write_len; ++i) {
430                 uint8_t byte =
431                     (uint8_t) ((imm >> (write_len - 1 - i) * 8) & 0xff);
432                 bprintf("%02x", byte);
433 
434             }
435             break;
436         }
437         case PKTDATACOPY_OPCODE: {
438             if (reg_num == 0) {
439                 print_opcode("pktcopy");
440             } else {
441                 print_opcode("datacopy");
442             }
443             uint32_t src_offs = imm;
444             uint32_t copy_len = DECODE_IMM(1);
445             bprintf("src=%d, len=%d", src_offs, copy_len);
446             break;
447         }
448         // Unknown opcode
449         default:
450             bprintf("unknown %u", opcode);
451             break;
452     }
453     return print_buf;
454 }
455