xref: /btstack/3rd-party/micro-ecc/scripts/mult_arm.py (revision 6ccd8248590f666db07dd7add13fecb4f5664fb5)
1#!/usr/bin/env python3
2
3import sys
4
#### command line: operand size in 32-bit words
# Diagnostics go to stderr (sys.exit with a string prints it there and exits
# with status 1) because stdout carries the generated assembly.
if len(sys.argv) < 2:
    sys.exit("Provide the integer size in 32-bit words")

try:
    size = int(sys.argv[1])
except ValueError:
    sys.exit("Integer size must be a whole number of 32-bit words, got %r" % sys.argv[1])

if size < 1:
    # size <= 0 would produce negative pointer offsets below and meaningless
    # assembly, so reject it up front.
    sys.exit("Integer size must be at least 1 word, got %d" % size)

# The product is generated as one "initial" block of 1..3 words followed by
# full rows of 3 words each (the x and y register windows are 3 registers
# wide).
full_rows = size // 3
init_size = size % 3

if init_size == 0:
    # Exact multiple of 3: use a full 3-word initial block rather than an
    # empty one, leaving one fewer full row.
    full_rows = full_rows - 1
    init_size = 3
17
def emit(line, *args):
    """Print one assembly line formatted as a quoted inline-asm string piece.

    ``line`` is a %-style template filled from ``args``.  The filled text is
    wrapped in double quotes and followed by a literal ``\\n\\t`` escape
    sequence, so consecutive printed lines concatenate into a single asm
    string with one instruction per line.
    """
    instruction = line % args
    quoted = '"{} \\n\\t"'.format(instruction)
    print(quoted)
21
# Register numbers holding the current 3-word windows of the operands:
# r3-r5 are loaded through the x pointer (r1), r6-r8 through the y pointer (r2).
rx = list(range(3, 6))
ry = list(range(6, 9))
24
#### Initial block: multiply the first init_size words of x (via r1, which is
#### not advanced first) by the last init_size words of y, writing the
#### 2 * init_size-word partial product into z at word offset size - init_size.
#### Output words are produced column by column: each stored word is the sum
#### of the products x[i] * y[j] that fall in that column.

#### set up registers
emit("add r0, %s", (size - init_size) * 4) # move z
emit("add r2, %s", (size - init_size) * 4) # move y

# Load the init_size operand words into the first registers of each window.
emit("ldmia r1!, {%s}", ", ".join(["r%s" % (rx[i]) for i in range(init_size)]))
emit("ldmia r2!, {%s}", ", ".join(["r%s" % (ry[i]) for i in range(init_size)]))

print("")
if init_size == 1:
    # A single 32x32 -> 64-bit product is the whole (two-word) result.
    emit("umull r9, r10, r3, r6")
    emit("stmia r0!, {r9, r10}")
else:
    #### first two multiplications of initial block
    # Column 0: x0*y0; the low word is final, the high word (r12) starts column 1.
    emit("umull r11, r12, r3, r6")
    emit("stmia r0!, {r11}")
    print("")
    # Column 1: r12 + x0*y1 + x1*y0, with r9 as the high word and r10
    # collecting carries out of r9.
    emit("mov r10, #0")
    emit("umull r11, r9, r3, r7")
    emit("adds r12, r11")
    emit("adc r9, #0")
    # r14 (lr) is used here as scratch for the high half of the product.
    emit("umull r11, r14, r4, r6")
    emit("adds r12, r11")
    emit("adcs r9, r14")
    emit("adc r10, #0")
    emit("stmia r0!, {r12}")
    print("")

    #### rest of initial block, with moving accumulator registers
    # acc[0], acc[1], acc[2] = low / high / carry of the current column;
    # acc[3], acc[4] = scratch registers for each 64-bit umull product.
    acc = [9, 10, 11, 12, 14]
    if init_size == 3:
        # Column 2: x0*y2 + x1*y1 + x2*y0.
        emit("mov r%s, #0", acc[2])
        for i in range(0, 3):
            emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], rx[i], ry[2 - i])
            emit("adds r%s, r%s", acc[0], acc[3])
            emit("adcs r%s, r%s", acc[1], acc[4])
            emit("adc r%s, #0", acc[2])
        emit("stmia r0!, {r%s}", acc[0])
        print("")
        # Rotate: high becomes low, carry becomes high, and the just-stored
        # low register is recycled as scratch.
        acc = acc[1:] + acc[:1]

        # Column 3: x1*y2 + x2*y1.
        emit("mov r%s, #0", acc[2])
        for i in range(0, 2):
            emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], rx[i + 1], ry[2 - i])
            emit("adds r%s, r%s", acc[0], acc[3])
            emit("adcs r%s, r%s", acc[1], acc[4])
            emit("adc r%s, #0", acc[2])
        emit("stmia r0!, {r%s}", acc[0])
        print("")
        acc = acc[1:] + acc[:1]

    # Final column: the top product of the block.  The product of two
    # init_size-word numbers fits in 2 * init_size words, so no carry
    # register is needed; the last two words are stored directly.
    emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], rx[init_size-1], ry[init_size-1])
    emit("adds r%s, r%s", acc[0], acc[3])
    emit("adc r%s, r%s", acc[1], acc[4])
    emit("stmia r0!, {r%s}", acc[0])
    emit("stmia r0!, {r%s}", acc[1])
print("")
81
#### reset y and z pointers for the first full row
# Only emitted when at least one full row follows.  For size <= 3 the initial
# block already computed the whole product, and the loads below would read
# words outside the x and y operands.
if full_rows > 0:
    # z moves back to the first output word of row 0, and y moves back to the
    # 3-word window just below the words used by the initial block (both land
    # on word offset size - init_size - 3).
    emit("sub r0, %s", (2 * init_size + 3) * 4)
    emit("sub r2, %s", (init_size + 3) * 4)

    #### load y registers
    emit("ldmia r2!, {%s}", ", ".join(["r%s" % (ry[i]) for i in range(3)]))

    #### load additional x registers
    # The initial block loaded only x[0 .. init_size-1]; fill the rest of the
    # 3-word x window.
    if init_size != 3:
        emit("ldmia r1!, {%s}", ", ".join(["r%s" % (rx[i]) for i in range(init_size, 3)]))
    print("")
93
#### Full rows.  Each row uses the next 3-word window of y below the words
#### already processed (held in r6-r8) and multiplies it by x[0 .. prev_size+2].
#### It then slides the y window up through the words already covered.  The
#### 2 * prev_size z words written by earlier blocks are read back with "ldr"
#### and added into the running column sums.
prev_size = init_size
for row in range(full_rows):
    # Column 0: x0*y0.  Column 1: high word + x0*y1 + x1*y0 (same scheme as
    # the initial block; r14 is used as scratch).
    emit("umull r11, r12, r3, r6")
    emit("stmia r0!, {r11}")
    print("")
    emit("mov r10, #0")
    emit("umull r11, r9, r3, r7")
    emit("adds r12, r11")
    emit("adc r9, #0")
    emit("umull r11, r14, r4, r6")
    emit("adds r12, r11")
    emit("adcs r9, r14")
    emit("adc r10, #0")
    emit("stmia r0!, {r12}")
    print("")

    # Column 2: the full anti-diagonal of the 3x3 window.  From here on,
    # acc[0..2] = low/high/carry of the current column and acc[3..4] hold
    # each 64-bit umull product.
    acc = [9, 10, 11, 12, 14]
    emit("mov r%s, #0", acc[2])
    for i in range(0, 3):
        emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], rx[i], ry[2 - i])
        emit("adds r%s, r%s", acc[0], acc[3])
        emit("adcs r%s, r%s", acc[1], acc[4])
        emit("adc r%s, #0", acc[2])
    emit("stmia r0!, {r%s}", acc[0])
    print("")
    # Rotate: high -> low, carry -> high, stored low register -> scratch.
    acc = acc[1:] + acc[:1]

    #### now we need to start shifting x and loading from z
    # Each column loads the next x word into the register that held the
    # oldest x word, so the 3-register window slides up one word.
    x_regs = list(rx)
    for _ in range(0, prev_size):
        x_regs = x_regs[1:] + x_regs[:1]
        emit("ldmia r1!, {r%s}", x_regs[2])
        emit("mov r%s, #0", acc[2])
        for i in range(0, 3):
            emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], x_regs[i], ry[2 - i])
            emit("adds r%s, r%s", acc[0], acc[3])
            emit("adcs r%s, r%s", acc[1], acc[4])
            emit("adc r%s, #0", acc[2])
        emit("ldr r%s, [r0]", acc[3]) # load stored value from initial block, and add to accumulator
        emit("adds r%s, r%s", acc[0], acc[3])
        emit("adcs r%s, #0", acc[1])
        emit("adc r%s, #0", acc[2])
        emit("stmia r0!, {r%s}", acc[0])
        print("")
        acc = acc[1:] + acc[:1]

    # done shifting x, start shifting y
    # Same sliding scheme for the y window, with the x window now fixed at
    # its last three words.
    y_regs = list(ry)
    for _ in range(0, prev_size):
        y_regs = y_regs[1:] + y_regs[:1]
        emit("ldmia r2!, {r%s}", y_regs[2])
        emit("mov r%s, #0", acc[2])
        for i in range(0, 3):
            emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], x_regs[i], y_regs[2 - i])
            emit("adds r%s, r%s", acc[0], acc[3])
            emit("adcs r%s, r%s", acc[1], acc[4])
            emit("adc r%s, #0", acc[2])
        emit("ldr r%s, [r0]", acc[3]) # load stored value from initial block, and add to accumulator
        emit("adds r%s, r%s", acc[0], acc[3])
        emit("adcs r%s, #0", acc[1])
        emit("adc r%s, #0", acc[2])
        emit("stmia r0!, {r%s}", acc[0])
        print("")
        acc = acc[1:] + acc[:1]

    # done both shifts, do remaining corner
    # Second-to-last column of the row: two products of the top window words.
    emit("mov r%s, #0", acc[2])
    for i in range(0, 2):
        emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], x_regs[i + 1], y_regs[2 - i])
        emit("adds r%s, r%s", acc[0], acc[3])
        emit("adcs r%s, r%s", acc[1], acc[4])
        emit("adc r%s, #0", acc[2])
    emit("stmia r0!, {r%s}", acc[0])
    print("")
    acc = acc[1:] + acc[:1]

    # Last column: top product.  These are the two highest words of the
    # running product, so no carry register is needed.
    emit("umull r%s, r%s, r%s, r%s", acc[3], acc[4], x_regs[2], y_regs[2])
    emit("adds r%s, r%s", acc[0], acc[3])
    emit("adc r%s, r%s", acc[1], acc[4])
    emit("stmia r0!, {r%s}", acc[0])
    emit("stmia r0!, {r%s}", acc[1])
    print("")

    prev_size = prev_size + 3
    if row < full_rows - 1:
        #### reset x, y and z pointers
        # z and y move back to the 3-word window just below this row's
        # window; x moves back to x[0].
        emit("sub r0, %s", (2 * prev_size + 3) * 4)
        emit("sub r1, %s", prev_size * 4)
        emit("sub r2, %s", (prev_size + 3) * 4)

        #### load x and y registers
        # ", " separator matches the other register lists in this script.
        emit("ldmia r1!, {%s}", ", ".join(["r%s" % (rx[i]) for i in range(3)]))
        emit("ldmia r2!, {%s}", ", ".join(["r%s" % (ry[i]) for i in range(3)]))

        print("")
189