# xref: /btstack/3rd-party/micro-ecc/scripts/square_arm.py (revision 6ccd8248590f666db07dd7add13fecb4f5664fb5)
#!/usr/bin/env python3
2
import sys
4
# Command-line handling: the single argument is the operand size in 32-bit
# words. Diagnostics are written to stderr so they cannot end up inside the
# generated assembly when stdout is redirected to a file.
if len(sys.argv) < 2:
    print("Provide the integer size in 32-bit words", file=sys.stderr)
    sys.exit(1)

try:
    size = int(sys.argv[1])
except ValueError:
    # A non-numeric argument is reported like a missing one instead of
    # aborting with a traceback.
    print("Provide the integer size in 32-bit words", file=sys.stderr)
    sys.exit(1)

if size > 8:
    print(
        "This script doesn't work with integer size %s due to laziness" % (size),
        file=sys.stderr,
    )
    sys.exit(1)

if size < 4:
    # The hand-unrolled code emitted for the first three and the last three
    # result columns needs at least four operand words: with three words the
    # column-2 products would be emitted twice, and with fewer the code
    # multiplies registers that were never loaded.
    print(
        "This script doesn't work with integer size %s (minimum is 4)" % (size),
        file=sys.stderr,
    )
    sys.exit(1)

# Number of operand words beyond the first six (0 for size <= 6, else 1 or 2).
init_size = max(size - 6, 0)
18
def emit(line, *args):
    """Print one line of generated assembly as a double-quoted fragment.

    `line` is a %-style format template and `args` are its substitution
    values. The text is wrapped in double quotes and followed, inside the
    quotes, by a space and the two-character escape sequences backslash-n
    and backslash-t, then printed to stdout.
    """
    template = '"' + line + r' \n\t"'
    print(template % args)
22
def mulacc(acc, r1, r2):
    """Emit code adding the product r<r1> * r<r2> to a 3-register accumulator.

    `acc` is a sequence of three register numbers, lowest word first; `r1`
    and `r2` are the register numbers of the two factors. Which instruction
    sequence is emitted depends on the module-level `size`.
    """
    if size <= 6:
        # Multiply into the scratch pair r1/r14, then add with carry through
        # all three accumulator words.
        emit("umull r1, r14, r%s, r%s", r1, r2)
        emit("adds r%s, r1", acc[0])
        emit("adcs r%s, r14", acc[1])
        emit("adc r%s, #0", acc[2])
    else:
        # This variant leaves register r1 untouched (for size > 6 the
        # generated code still issues "ldmia r1!" loads between mulacc
        # calls). umlal accumulates straight into acc[0]/acc[1]; a wrap of
        # that 64-bit sum is detected by comparing the saved old acc[1]
        # (in r14) with the new one and, if the old value is higher,
        # incrementing acc[2].
        emit("mov r14, r%s", acc[1])
        emit("umlal r%s, r%s, r%s, r%s", acc[0], acc[1], r1, r2)
        emit("cmp r14, r%s", acc[1])
        emit("it hi")
        emit("adchi r%s, #0", acc[2])
35
# Numbers of the six registers (r2-r7) used to hold operand words.
r = [2, 3, 4, 5, 6, 7]

# Number of operand words loaded into those registers at the start:
# all of them for size <= 6, otherwise six.
s = size - init_size
39
# Pre-pass for size > 6: multiply the lowest init_size operand words by the
# highest init_size operand words and park the result in the output buffer,
# starting at result word (size - init_size). The later passes read these
# words back with "ldr r14, [r0]". r1 is the input pointer, r0 the output
# pointer; both are rewound to the start of their buffers at the end.
if init_size == 1:
    # Load the lowest word, skip to the highest word and load it.
    emit("ldmia r1!, {r2}")
    emit("add r1, %s", (size - init_size * 2) * 4)
    emit("ldmia r1!, {r5}")

    # Store the 64-bit product lowest * highest.
    emit("add r0, %s", (size - init_size) * 4)
    emit("umull r8, r9, r2, r5")
    emit("stmia r0!, {r8, r9}")

    # Rewind both pointers.
    emit("sub r0, %s", (size + init_size) * 4)
    emit("sub r1, %s", (size) * 4)
    print("")
elif init_size == 2:
    # Load the two lowest words (r2, r3) and the two highest (r5, r6).
    emit("ldmia r1!, {r2, r3}")
    emit("add r1, %s", (size - init_size * 2) * 4)
    emit("ldmia r1!, {r5, r6}")

    emit("add r0, %s", (size - init_size) * 4)
    print("")

    # r2 * r5: low word is final, high word carries on in r9.
    emit("umull r8, r9, r2, r5")
    emit("stmia r0!, {r8}")
    print("")

    # r2 * r6 added on top of the carry.
    emit("umull r12, r10, r2, r6")
    emit("adds r9, r12")
    emit("adc r10, #0")
    emit("stmia r0!, {r9}")
    print("")

    # r3 * r6 added on top of the carry; both remaining words are stored.
    emit("umull r8, r9, r3, r6")
    emit("adds r10, r8")
    emit("adc r11, r9, #0")
    emit("stmia r0!, {r10, r11}")
    print("")

    # Rewind both pointers.
    emit("sub r0, %s", (size + init_size) * 4)
    emit("sub r1, %s", (size) * 4)
78
# load input words (the first s operand words go to r2, r3, ...)
emit("ldmia r1!, {%s}", ", ".join(["r%s" % (r[i]) for i in range(s)]))
print("")

# Result word 0: low half of r2 * r2; the high half stays in r12.
emit("umull r11, r12, r2, r2")
emit("stmia r0!, {r11}")
print("")

# Result word 1: r12 + 2 * (r2 * r3). The product is added twice instead of
# being doubled; the carry into the next column ends up in r8/r9.
emit("mov r9, #0")
emit("umull r10, r11, r2, r3")
emit("adds r12, r10")
emit("adcs r8, r11, #0")
emit("adc r9, #0")
emit("adds r12, r10")
emit("adcs r8, r11")
emit("adc r9, #0")
emit("stmia r0!, {r12}")
print("")

# Result word 2: carry + 2 * (r2 * r4) + r3 * r3, accumulated in r8/r9/r10.
emit("mov r10, #0")
emit("umull r11, r12, r2, r4")
emit("adds r11, r11")
emit("adcs r12, r12")
emit("adc r10, #0")
emit("adds r8, r11")
emit("adcs r9, r12")
emit("adc r10, #0")
emit("umull r11, r12, r3, r3")
emit("adds r8, r11")
emit("adcs r9, r12")
emit("adc r10, #0")
emit("stmia r0!, {r8}")
print("")
110
# Register bookkeeping for the column loops. `acc` names the three registers
# of the column accumulator (lowest word first). When an iteration starts,
# acc[1]/acc[2] still hold the carry out of the column just stored and
# `old_acc` names two free registers; the rotation at the top of each
# iteration swaps those roles so that `old_acc` then names the carry.
acc = [8, 9, 10]
old_acc = [11, 12]
for i in range(3, s):
    emit("mov r%s, #0", old_acc[1])
    tmp = [acc[1], acc[2]]
    acc = [acc[0], old_acc[0], old_acc[1]]
    old_acc = tmp

    # gather non-equal words: all products r[j] * r[i-j] with j < i-j
    emit("umull r%s, r%s, r%s, r%s", acc[0], acc[1], r[0], r[i])
    for j in range(1, (i+1)//2):
        mulacc(acc, r[j], r[i-j])
    # multiply by 2 (each of those products occurs twice in the square)
    emit("adds r%s, r%s", acc[0], acc[0])
    emit("adcs r%s, r%s", acc[1], acc[1])
    emit("adc r%s, r%s", acc[2], acc[2])

    # add equal word (if any): for even i the square of the middle word
    if ((i+1) % 2) != 0:
        mulacc(acc, r[i//2], r[i//2])

    # add old accumulator (carry from the previous column)
    emit("adds r%s, r%s", acc[0], old_acc[0])
    emit("adcs r%s, r%s", acc[1], old_acc[1])
    emit("adc r%s, #0", acc[2])

    # store the finished result word
    emit("stmia r0!, {r%s}", acc[0])
    print("")
140
# For size > 6 the six-register window is slid up by one operand word per
# outer iteration: the register list is rotated and the next word is loaded
# into the register that held the lowest word of the window. Each shift
# produces two more result words (limit 4 and limit 5).
regs = list(r)
for _ in range(init_size):
    regs = regs[1:] + regs[:1]
    emit("ldmia r1!, {r%s}", regs[5])

    for limit in [4, 5]:
        # Rotate accumulator registers: previous upper words become the carry.
        emit("mov r%s, #0", old_acc[1])
        tmp = [acc[1], acc[2]]
        acc = [acc[0], old_acc[0], old_acc[1]]
        old_acc = tmp

        # gather non-equal words
        emit("umull r%s, r%s, r%s, r%s", acc[0], acc[1], regs[0], regs[limit])
        for j in range(1, (limit+1)//2):
            mulacc(acc, regs[j], regs[limit-j])

        emit("ldr r14, [r0]") # load stored value from initial block, and add to accumulator
        emit("adds r%s, r14", acc[0])
        emit("adcs r%s, #0", acc[1])
        emit("adc r%s, #0", acc[2])

        # multiply by 2
        emit("adds r%s, r%s", acc[0], acc[0])
        emit("adcs r%s, r%s", acc[1], acc[1])
        emit("adc r%s, r%s", acc[2], acc[2])

        # add equal word (only the limit-4 column has a middle word)
        if limit == 4:
            mulacc(acc, regs[2], regs[2])

        # add old accumulator
        emit("adds r%s, r%s", acc[0], old_acc[0])
        emit("adcs r%s, r%s", acc[1], old_acc[1])
        emit("adc r%s, #0", acc[2])

        # store
        emit("stmia r0!, {r%s}", acc[0])
        print("")
179
# Columns whose products all involve the highest loaded word: column i pairs
# regs[i+j] with regs[s-1-j], so each pass has one product pair fewer. The
# last three columns are emitted separately after this loop.
for i in range(1, s-3):
    # Rotate accumulator registers: previous upper words become the carry.
    emit("mov r%s, #0", old_acc[1])
    tmp = [acc[1], acc[2]]
    acc = [acc[0], old_acc[0], old_acc[1]]
    old_acc = tmp

    # gather non-equal words
    emit("umull r%s, r%s, r%s, r%s", acc[0], acc[1], regs[i], regs[s - 1])
    for j in range(1, (s-i)//2):
        mulacc(acc, regs[i+j], regs[s - 1 - j])

    # multiply by 2
    emit("adds r%s, r%s", acc[0], acc[0])
    emit("adcs r%s, r%s", acc[1], acc[1])
    emit("adc r%s, r%s", acc[2], acc[2])

    # add equal word (if any)
    if ((s-i) % 2) != 0:
        mulacc(acc, regs[i + (s-i)//2], regs[i + (s-i)//2])

    # add old accumulator
    emit("adds r%s, r%s", acc[0], old_acc[0])
    emit("adcs r%s, r%s", acc[1], old_acc[1])
    emit("adc r%s, #0", acc[2])

    # store
    emit("stmia r0!, {r%s}", acc[0])
    print("")
208
# Final three columns, written out by hand. Each step rotates `acc` so the
# carry of the previous column becomes the two low accumulator words and the
# register just stored becomes the (cleared) top word. Register r1 and
# old_acc[1] serve as scratch for the products here.

# Third-to-last column: 2 * regs[s-3] * regs[s-1] + regs[s-2]^2.
acc = acc[1:] + acc[:1]
emit("mov r%s, #0", acc[2])
emit("umull r1, r%s, r%s, r%s", old_acc[1], regs[s - 3], regs[s - 1])
emit("adds r1, r1")
emit("adcs r%s, r%s", old_acc[1], old_acc[1])
emit("adc r%s, #0", acc[2])
emit("adds r%s, r1", acc[0])
emit("adcs r%s, r%s", acc[1], old_acc[1])
emit("adc r%s, #0", acc[2])
emit("umull r1, r%s, r%s, r%s", old_acc[1], regs[s - 2], regs[s - 2])
emit("adds r%s, r1", acc[0])
emit("adcs r%s, r%s", acc[1], old_acc[1])
emit("adc r%s, #0", acc[2])
emit("stmia r0!, {r%s}", acc[0])
print("")

# Second-to-last column: 2 * regs[s-2] * regs[s-1].
acc = acc[1:] + acc[:1]
emit("mov r%s, #0", acc[2])
emit("umull r1, r%s, r%s, r%s", old_acc[1], regs[s - 2], regs[s - 1])
emit("adds r1, r1")
emit("adcs r%s, r%s", old_acc[1], old_acc[1])
emit("adc r%s, #0", acc[2])
emit("adds r%s, r1", acc[0])
emit("adcs r%s, r%s", acc[1], old_acc[1])
emit("adc r%s, #0", acc[2])
emit("stmia r0!, {r%s}", acc[0])
print("")

# Last column: regs[s-1]^2 plus carry; both remaining words are stored.
acc = acc[1:] + acc[:1]
emit("umull r1, r%s, r%s, r%s", old_acc[1], regs[s - 1], regs[s - 1])
emit("adds r%s, r1", acc[0])
emit("adcs r%s, r%s", acc[1], old_acc[1])
emit("stmia r0!, {r%s}", acc[0])
emit("stmia r0!, {r%s}", acc[1])
243