1 /*
2 * Copyright © 2016 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #include "nir.h"
25 #include "nir_builder.h"
26
27 #define COND_LOWER_OP(b, name, ...) \
28 (b->shader->options->lower_int64_options & \
29 nir_lower_int64_op_to_options_mask(nir_op_##name)) \
30 ? lower_##name##64(b, __VA_ARGS__) \
31 : nir_##name(b, __VA_ARGS__)
32
33 #define COND_LOWER_CMP(b, name, ...) \
34 (b->shader->options->lower_int64_options & \
35 nir_lower_int64_op_to_options_mask(nir_op_##name)) \
36 ? lower_int64_compare(b, nir_op_##name, __VA_ARGS__) \
37 : nir_##name(b, __VA_ARGS__)
38
39 #define COND_LOWER_CAST(b, name, ...) \
40 (b->shader->options->lower_int64_options & \
41 nir_lower_int64_op_to_options_mask(nir_op_##name)) \
42 ? lower_##name(b, __VA_ARGS__) \
43 : nir_##name(b, __VA_ARGS__)
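
/* The COND_LOWER_* macros implement recursive lowering: several helpers
 * below are themselves built out of 64-bit ALU ops, so each such op is
 * emitted through one of these macros.  If the driver also asked for that
 * op to be lowered, the matching lower_*() helper is called directly
 * instead of relying on the freshly emitted instruction being picked up
 * later.  For example, COND_LOWER_OP(b, iadd, x, y) expands to
 * lower_iadd64(b, x, y) when nir_lower_iadd64 is set and to
 * nir_iadd(b, x, y) otherwise.
 */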
44
45 static nir_def *
46 lower_b2i64(nir_builder *b, nir_def *x)
47 {
48 return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
49 }
50
51 static nir_def *
52 lower_i2i8(nir_builder *b, nir_def *x)
53 {
54 return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
55 }
56
57 static nir_def *
58 lower_i2i16(nir_builder *b, nir_def *x)
59 {
60 return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
61 }
62
63 static nir_def *
64 lower_i2i32(nir_builder *b, nir_def *x)
65 {
66 return nir_unpack_64_2x32_split_x(b, x);
67 }
68
69 static nir_def *
70 lower_i2i64(nir_builder *b, nir_def *x)
71 {
72 nir_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
73 return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
74 }
75
76 static nir_def *
77 lower_u2u8(nir_builder *b, nir_def *x)
78 {
79 return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
80 }
81
82 static nir_def *
83 lower_u2u16(nir_builder *b, nir_def *x)
84 {
85 return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
86 }
87
88 static nir_def *
89 lower_u2u32(nir_builder *b, nir_def *x)
90 {
91 return nir_unpack_64_2x32_split_x(b, x);
92 }
93
94 static nir_def *
95 lower_u2u64(nir_builder *b, nir_def *x)
96 {
97 nir_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
98 return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
99 }
100
101 static nir_def *
102 lower_bcsel64(nir_builder *b, nir_def *cond, nir_def *x, nir_def *y)
103 {
104 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
105 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
106 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
107 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
108
109 return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
110 nir_bcsel(b, cond, x_hi, y_hi));
111 }
112
113 static nir_def *
114 lower_inot64(nir_builder *b, nir_def *x)
115 {
116 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
117 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
118
119 return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
120 }
121
122 static nir_def *
123 lower_iand64(nir_builder *b, nir_def *x, nir_def *y)
124 {
125 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
126 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
127 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
128 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
129
130 return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
131 nir_iand(b, x_hi, y_hi));
132 }
133
134 static nir_def *
135 lower_ior64(nir_builder *b, nir_def *x, nir_def *y)
136 {
137 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
138 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
139 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
140 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
141
142 return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
143 nir_ior(b, x_hi, y_hi));
144 }
145
146 static nir_def *
147 lower_ixor64(nir_builder *b, nir_def *x, nir_def *y)
148 {
149 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
150 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
151 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
152 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
153
154 return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
155 nir_ixor(b, x_hi, y_hi));
156 }
157
158 static nir_def *
159 lower_ishl64(nir_builder *b, nir_def *x, nir_def *y)
160 {
161 /* Implemented as
162 *
163 * uint64_t lshift(uint64_t x, int c)
164 * {
165 * c %= 64;
166 *
167 * if (c == 0) return x;
168 *
169 * uint32_t lo = LO(x), hi = HI(x);
170 *
171 * if (c < 32) {
172 * uint32_t lo_shifted = lo << c;
173 * uint32_t hi_shifted = hi << c;
174 * uint32_t lo_shifted_hi = lo >> abs(32 - c);
175 * return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
176 * } else {
177 * uint32_t lo_shifted_hi = lo << abs(32 - c);
178 * return pack_64(0, lo_shifted_hi);
179 * }
180 * }
181 */
182 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
183 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
184 y = nir_iand_imm(b, y, 0x3f);
185
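   /* reverse_count is |y - 32|: for y in [1, 31] it is the complementary
    * shift 32 - y used to move bits across the 32-bit boundary, and for
    * y in [32, 63] it is the shift y - 32 applied to the one word that
    * still contributes (e.g. y = 8 gives 24, y = 40 gives 8).
    */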
186 nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
187 nir_def *lo_shifted = nir_ishl(b, x_lo, y);
188 nir_def *hi_shifted = nir_ishl(b, x_hi, y);
189 nir_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);
190
191 nir_def *res_if_lt_32 =
192 nir_pack_64_2x32_split(b, lo_shifted,
193 nir_ior(b, hi_shifted, lo_shifted_hi));
194 nir_def *res_if_ge_32 =
195 nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
196 nir_ishl(b, x_lo, reverse_count));
197
198 return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
199 nir_bcsel(b, nir_uge_imm(b, y, 32),
200 res_if_ge_32, res_if_lt_32));
201 }
202
203 static nir_def *
204 lower_ishr64(nir_builder *b, nir_def *x, nir_def *y)
205 {
206 /* Implemented as
207 *
208 * uint64_t arshift(uint64_t x, int c)
209 * {
210 * c %= 64;
211 *
212 * if (c == 0) return x;
213 *
214 * uint32_t lo = LO(x);
215 * int32_t hi = HI(x);
216 *
217 * if (c < 32) {
218 * uint32_t lo_shifted = lo >> c;
219 * uint32_t hi_shifted = hi >> c;
220 * uint32_t hi_shifted_lo = hi << abs(32 - c);
221 * return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
222 * } else {
223 * uint32_t hi_shifted = hi >> 31;
224 * uint32_t hi_shifted_lo = hi >> abs(32 - c);
225 * return pack_64(hi_shifted, hi_shifted_lo);
226 * }
227 * }
228 */
229 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
230 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
231 y = nir_iand_imm(b, y, 0x3f);
232
233 nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
234 nir_def *lo_shifted = nir_ushr(b, x_lo, y);
235 nir_def *hi_shifted = nir_ishr(b, x_hi, y);
236 nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);
237
238 nir_def *res_if_lt_32 =
239 nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
240 hi_shifted);
241 nir_def *res_if_ge_32 =
242 nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
243 nir_ishr_imm(b, x_hi, 31));
244
245 return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
246 nir_bcsel(b, nir_uge_imm(b, y, 32),
247 res_if_ge_32, res_if_lt_32));
248 }
249
250 static nir_def *
251 lower_ushr64(nir_builder *b, nir_def *x, nir_def *y)
252 {
253 /* Implemented as
254 *
255 * uint64_t rshift(uint64_t x, int c)
256 * {
257 * c %= 64;
258 *
259 * if (c == 0) return x;
260 *
261 * uint32_t lo = LO(x), hi = HI(x);
262 *
263 * if (c < 32) {
264 * uint32_t lo_shifted = lo >> c;
265 * uint32_t hi_shifted = hi >> c;
266 * uint32_t hi_shifted_lo = hi << abs(32 - c);
267 * return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
268 * } else {
269 * uint32_t hi_shifted_lo = hi >> abs(32 - c);
270 * return pack_64(0, hi_shifted_lo);
271 * }
272 * }
273 */
274
275 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
276 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
277 y = nir_iand_imm(b, y, 0x3f);
278
279 nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
280 nir_def *lo_shifted = nir_ushr(b, x_lo, y);
281 nir_def *hi_shifted = nir_ushr(b, x_hi, y);
282 nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);
283
284 nir_def *res_if_lt_32 =
285 nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
286 hi_shifted);
287 nir_def *res_if_ge_32 =
288 nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
289 nir_imm_int(b, 0));
290
291 return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
292 nir_bcsel(b, nir_uge_imm(b, y, 32),
293 res_if_ge_32, res_if_lt_32));
294 }
295
296 static nir_def *
297 lower_iadd64(nir_builder *b, nir_def *x, nir_def *y)
298 {
299 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
300 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
301 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
302 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
303
304 nir_def *res_lo = nir_iadd(b, x_lo, y_lo);
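   /* A 32-bit unsigned add wraps exactly when the result is smaller than
    * either operand (e.g. 0xffffffff + 2 = 1), so ult(res_lo, x_lo) is
    * precisely the carry into the high word.
    */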
305 nir_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
306 nir_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));
307
308 return nir_pack_64_2x32_split(b, res_lo, res_hi);
309 }
310
311 static nir_def *
312 lower_isub64(nir_builder *b, nir_def *x, nir_def *y)
313 {
314 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
315 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
316 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
317 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
318
319 nir_def *res_lo = nir_isub(b, x_lo, y_lo);
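   /* A borrow out of the low word occurs exactly when x_lo < y_lo; ineg
    * turns that boolean (0 or 1) into 0 or -1 so it can simply be added to
    * the high-word difference.
    */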
320 nir_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
321 nir_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);
322
323 return nir_pack_64_2x32_split(b, res_lo, res_hi);
324 }
325
326 static nir_def *
327 lower_ineg64(nir_builder *b, nir_def *x)
328 {
329 /* Since isub is the same number of instructions (with better dependencies)
330 * as iadd, subtraction is actually more efficient for ineg than the usual
331 * 2's complement "flip the bits and add one".
332 */
333 return lower_isub64(b, nir_imm_int64(b, 0), x);
334 }
335
336 static nir_def *
337 lower_iabs64(nir_builder *b, nir_def *x)
338 {
339 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
340 nir_def *x_is_neg = nir_ilt_imm(b, x_hi, 0);
341 return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
342 }
343
344 static nir_def *
345 lower_int64_compare(nir_builder *b, nir_op op, nir_def *x, nir_def *y)
346 {
347 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
348 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
349 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
350 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
351
352 switch (op) {
353 case nir_op_ieq:
354 return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
355 case nir_op_ine:
356 return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
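   /* The ordered comparisons below are lexicographic: compare the high
    * words and fall back to the low words only when they are equal.  The
    * low-word comparison is unsigned even for ilt because the sign lives
    * entirely in the high word.
    */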
357 case nir_op_ult:
358 return nir_ior(b, nir_ult(b, x_hi, y_hi),
359 nir_iand(b, nir_ieq(b, x_hi, y_hi),
360 nir_ult(b, x_lo, y_lo)));
361 case nir_op_ilt:
362 return nir_ior(b, nir_ilt(b, x_hi, y_hi),
363 nir_iand(b, nir_ieq(b, x_hi, y_hi),
364 nir_ult(b, x_lo, y_lo)));
366 case nir_op_uge:
367 /* Lower as !(x < y) in the hopes of better CSE */
368 return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
369 case nir_op_ige:
370 /* Lower as !(x < y) in the hopes of better CSE */
371 return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
372 default:
373 unreachable("Invalid comparison");
374 }
375 }
376
377 static nir_def *
378 lower_umax64(nir_builder *b, nir_def *x, nir_def *y)
379 {
380 return nir_bcsel(b, COND_LOWER_CMP(b, ult, x, y), y, x);
381 }
382
383 static nir_def *
384 lower_imax64(nir_builder *b, nir_def *x, nir_def *y)
385 {
386 return nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, y), y, x);
387 }
388
389 static nir_def *
390 lower_umin64(nir_builder *b, nir_def *x, nir_def *y)
391 {
392 return nir_bcsel(b, COND_LOWER_CMP(b, ult, x, y), x, y);
393 }
394
395 static nir_def *
396 lower_imin64(nir_builder *b, nir_def *x, nir_def *y)
397 {
398 return nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, y), x, y);
399 }
400
401 static nir_def *
402 lower_mul_2x32_64(nir_builder *b, nir_def *x, nir_def *y,
403 bool sign_extend)
404 {
405 nir_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
406 : nir_umul_high(b, x, y);
407
408 return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
409 }
410
411 static nir_def *
412 lower_imul64(nir_builder *b, nir_def *x, nir_def *y)
413 {
414 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
415 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
416 nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
417 nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
418
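   /* (x_hi * 2^32 + x_lo) * (y_hi * 2^32 + y_lo) mod 2^64
    *    = x_lo * y_lo + ((x_lo * y_hi + x_hi * y_lo) << 32)
    *
    * The x_hi * y_hi term contributes only at bit 64 and above, so it is
    * dropped, and the two cross terms only need their low 32 bits.
    */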
419 nir_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
420 nir_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
421 nir_iadd(b, nir_imul(b, x_lo, y_hi),
422 nir_imul(b, x_hi, y_lo)));
423
424 return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
425 res_hi);
426 }
427
428 static nir_def *
429 lower_mul_high64(nir_builder *b, nir_def *x, nir_def *y,
430 bool sign_extend)
431 {
432 nir_def *x32[4], *y32[4];
433 x32[0] = nir_unpack_64_2x32_split_x(b, x);
434 x32[1] = nir_unpack_64_2x32_split_y(b, x);
435 if (sign_extend) {
436 x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
437 } else {
438 x32[2] = x32[3] = nir_imm_int(b, 0);
439 }
440
441 y32[0] = nir_unpack_64_2x32_split_x(b, y);
442 y32[1] = nir_unpack_64_2x32_split_y(b, y);
443 if (sign_extend) {
444 y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
445 } else {
446 y32[2] = y32[3] = nir_imm_int(b, 0);
447 }
448
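   /* x and y are now four base-2^32 digits each (digits 2 and 3 being the
    * sign or zero extension), so the nested loop below is schoolbook long
    * multiplication producing eight result digits.  Digits 2 and 3 are bits
    * [64, 127] of the 128-bit product, i.e. the high half of the 64x64
    * multiply that we want.
    */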
449 nir_def *res[8] = {
450 NULL,
451 };
452
453 /* Yes, the following generates a pile of code. However, we throw res[0]
454 * and res[1] away in the end and, if we're in the umul case, four of our
455 * eight dword operands will be constant zero and opt_algebraic will clean
456 * this up nicely.
457 */
458 for (unsigned i = 0; i < 4; i++) {
459 nir_def *carry = NULL;
460 for (unsigned j = 0; j < 4; j++) {
461 /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
462 * maximum value of tmp is UINT32_MAX * UINT32_MAX. The maximum
463 * value that will fit in tmp is
464 *
465 * UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
466 * = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
467 * = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
468 *
469 * so we're guaranteed that we can add in two more 32-bit values
470 * without overflowing tmp.
471 */
472 nir_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);
473
474 if (res[i + j])
475 tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
476 if (carry)
477 tmp = nir_iadd(b, tmp, carry);
478 res[i + j] = nir_u2u32(b, tmp);
479 carry = nir_ushr_imm(b, tmp, 32);
480 }
481 res[i + 4] = nir_u2u32(b, carry);
482 }
483
484 return nir_pack_64_2x32_split(b, res[2], res[3]);
485 }
486
487 static nir_def *
488 lower_isign64(nir_builder *b, nir_def *x)
489 {
490 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
491 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
492
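   /* res_hi = x_hi >> 31 is 0 for non-negative x and ~0 for negative x.
    * OR-ing the "x != 0" bit into the low word then yields 0, 1 or -1 as a
    * packed 64-bit value: e.g. x = 5 gives lo = 1, hi = 0 (i.e. 1) and
    * x = -5 gives lo = hi = ~0 (i.e. -1).
    */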
493 nir_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
494 nir_def *res_hi = nir_ishr_imm(b, x_hi, 31);
495 nir_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));
496
497 return nir_pack_64_2x32_split(b, res_lo, res_hi);
498 }
499
500 static void
501 lower_udiv64_mod64(nir_builder *b, nir_def *n, nir_def *d,
502 nir_def **q, nir_def **r)
503 {
504 /* TODO: We should specially handle the case where the denominator is a
505 * constant. In that case, we should be able to reduce it to a multiply by
506 * a constant, some shifts, and an add.
507 */
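   /* This is shift-and-subtract binary long division done in two passes:
    * when the divisor fits in 32 bits, the first pass computes the high 32
    * quotient bits against n_hi alone, and the second pass then computes
    * the low 32 quotient bits against the full 64-bit remainder.  On exit,
    * *q is the quotient and *r is whatever is left of n, i.e. the
    * remainder.
    */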
508 nir_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
509 nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
510 nir_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
511 nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
512
513 nir_def *q_lo = nir_imm_zero(b, n->num_components, 32);
514 nir_def *q_hi = nir_imm_zero(b, n->num_components, 32);
515
516 nir_def *n_hi_before_if = n_hi;
517 nir_def *q_hi_before_if = q_hi;
518
519 /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
520 * greater than 32 bits to occur. If the upper 32 bits of the numerator
521 * are zero, it is impossible for (denom << [63, 32]) <= numer unless
522 * denom == 0.
523 */
524 nir_def *need_high_div =
525 nir_iand(b, nir_ieq_imm(b, d_hi, 0), nir_uge(b, n_hi, d_lo));
526 nir_push_if(b, nir_bany(b, need_high_div));
527 {
528 /* If we only have one component, then the bany above goes away and
529 * this is always true within the if statement.
530 */
531 if (n->num_components == 1)
532 need_high_div = nir_imm_true(b);
533
534 nir_def *log2_d_lo = nir_ufind_msb(b, d_lo);
535
536 for (int i = 31; i >= 0; i--) {
537 /* if ((d.x << i) <= n.y) {
538 * n.y -= d.x << i;
539 * quot.y |= 1U << i;
540 * }
541 */
542 nir_def *d_shift = nir_ishl_imm(b, d_lo, i);
543 nir_def *new_n_hi = nir_isub(b, n_hi, d_shift);
544 nir_def *new_q_hi = nir_ior_imm(b, q_hi, 1ull << i);
545 nir_def *cond = nir_iand(b, need_high_div,
546 nir_uge(b, n_hi, d_shift));
547 if (i != 0) {
548 /* log2_d_lo is always <= 31, so we don't need to bother with it
549 * in the last iteration.
550 */
551 cond = nir_iand(b, cond,
552 nir_ile_imm(b, log2_d_lo, 31 - i));
553 }
554 n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
555 q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
556 }
557 }
558 nir_pop_if(b, NULL);
559 n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
560 q_hi = nir_if_phi(b, q_hi, q_hi_before_if);
561
562 nir_def *log2_denom = nir_ufind_msb(b, d_hi);
563
564 n = nir_pack_64_2x32_split(b, n_lo, n_hi);
565 d = nir_pack_64_2x32_split(b, d_lo, d_hi);
566 for (int i = 31; i >= 0; i--) {
567 /* if ((d64 << i) <= n64) {
568 * n64 -= d64 << i;
569 * quot.x |= 1U << i;
570 * }
571 */
572 nir_def *d_shift = nir_ishl_imm(b, d, i);
573 nir_def *new_n = nir_isub(b, n, d_shift);
574 nir_def *new_q_lo = nir_ior_imm(b, q_lo, 1ull << i);
575 nir_def *cond = nir_uge(b, n, d_shift);
576 if (i != 0) {
577 /* log2_denom is always <= 31, so we don't need to bother with it
578 * in the last iteration.
579 */
580 cond = nir_iand(b, cond,
581 nir_ile_imm(b, log2_denom, 31 - i));
582 }
583 n = nir_bcsel(b, cond, new_n, n);
584 q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
585 }
586
587 *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
588 *r = n;
589 }
590
591 static nir_def *
592 lower_udiv64(nir_builder *b, nir_def *n, nir_def *d)
593 {
594 nir_def *q, *r;
595 lower_udiv64_mod64(b, n, d, &q, &r);
596 return q;
597 }
598
599 static nir_def *
600 lower_idiv64(nir_builder *b, nir_def *n, nir_def *d)
601 {
602 nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
603 nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
604
605 nir_def *negate = nir_ine(b, nir_ilt_imm(b, n_hi, 0),
606 nir_ilt_imm(b, d_hi, 0));
607 nir_def *q, *r;
608 lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
609 return nir_bcsel(b, negate, nir_ineg(b, q), q);
610 }
611
612 static nir_def *
613 lower_umod64(nir_builder *b, nir_def *n, nir_def *d)
614 {
615 nir_def *q, *r;
616 lower_udiv64_mod64(b, n, d, &q, &r);
617 return r;
618 }
619
620 static nir_def *
621 lower_imod64(nir_builder *b, nir_def *n, nir_def *d)
622 {
623 nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
624 nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
625 nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);
626 nir_def *d_is_neg = nir_ilt_imm(b, d_hi, 0);
627
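   /* imod follows the sign of the divisor: when the signs differ and the
    * remainder is non-zero, the remainder is fixed up by adding d.  irem
    * below instead follows the sign of the numerator.
    */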
628 nir_def *q, *r;
629 lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
630
631 nir_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
632
633 return nir_bcsel(b, nir_ieq_imm(b, r, 0), nir_imm_int64(b, 0),
634 nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
635 nir_iadd(b, rem, d)));
636 }
637
638 static nir_def *
639 lower_irem64(nir_builder *b, nir_def *n, nir_def *d)
640 {
641 nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
642 nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);
643
644 nir_def *q, *r;
645 lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
646 return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
647 }
648
649 static nir_def *
650 lower_extract(nir_builder *b, nir_op op, nir_def *x, nir_def *c)
651 {
652 assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
653 op == nir_op_extract_u16 || op == nir_op_extract_i16);
654
655 const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
656 const int chunk_bits =
657 (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
658 const int num_chunks_in_32 = 32 / chunk_bits;
659
660 nir_def *extract32;
661 if (chunk < num_chunks_in_32) {
662 extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
663 nir_imm_int(b, chunk),
664 NULL, NULL);
665 } else {
666 extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
667 nir_imm_int(b, chunk - num_chunks_in_32),
668 NULL, NULL);
669 }
670
671 if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
672 return lower_i2i64(b, extract32);
673 else
674 return lower_u2u64(b, extract32);
675 }
676
677 static nir_def *
678 lower_ufind_msb64(nir_builder *b, nir_def *x)
679 {
681 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
682 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
683 nir_def *lo_count = nir_ufind_msb(b, x_lo);
684 nir_def *hi_count = nir_ufind_msb(b, x_hi);
685
686 /* hi_count is either -1 or a value in the range [31, 0]. lo_count is
687 * the same. The imax will pick lo_count only when hi_count is -1. In those
688 * cases, lo_count is guaranteed to be the correct answer.
689 * The ior 32 is always safe here as with -1 the value won't change,
690 * otherwise it adds 32, which is what we want anyway.
691 */
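   /* E.g. x = 0x0000000180000000: hi_count = 0 | 32 = 32, lo_count = 31,
    * and imax() returns 32, the index of the overall MSB.
    */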
692 return nir_imax(b, lo_count, nir_ior_imm(b, hi_count, 32));
693 }
694
695 static nir_def *
696 lower_find_lsb64(nir_builder *b, nir_def *x)
697 {
698 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
699 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
700 nir_def *lo_lsb = nir_find_lsb(b, x_lo);
701 nir_def *hi_lsb = nir_find_lsb(b, x_hi);
702
703 /* Use umin so that -1 (no bits found) becomes larger (0xFFFFFFFF)
704 * than any actual bit position, so we return a found bit instead.
705 * This is similar to the ufind_msb lowering.
706 */
707 return nir_umin(b, lo_lsb, nir_ior_imm(b, hi_lsb, 32));
708 }
709
710 static nir_def *
711 lower_2f(nir_builder *b, nir_def *x, unsigned dest_bit_size,
712 bool src_is_signed)
713 {
714 nir_def *x_sign = NULL;
715
716 if (src_is_signed) {
717 x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
718 nir_imm_floatN_t(b, -1, dest_bit_size),
719 nir_imm_floatN_t(b, 1, dest_bit_size));
720 x = COND_LOWER_OP(b, iabs, x);
721 }
722
723 nir_def *exp = COND_LOWER_OP(b, ufind_msb, x);
724 unsigned significand_bits;
725
726 switch (dest_bit_size) {
727 case 64:
728 significand_bits = 52;
729 break;
730 case 32:
731 significand_bits = 23;
732 break;
733 case 16:
734 significand_bits = 10;
735 break;
736 default:
737 unreachable("Invalid dest_bit_size");
738 }
739
740 nir_def *discard =
741 nir_imax(b, nir_iadd_imm(b, exp, -significand_bits),
742 nir_imm_int(b, 0));
743 nir_def *significand = COND_LOWER_OP(b, ushr, x, discard);
744 if (significand_bits < 32)
745 significand = COND_LOWER_CAST(b, u2u32, significand);
746
747 /* Round-to-nearest-even implementation:
748 * - if the non-representable part of the significand is higher than half
749 * the minimum representable significand, we round-up
750 * - if the non-representable part of the significand is equal to half the
751 * minimum representable significand and the representable part of the
752 * significand is odd, we round-up
753 * - in any other case, we round-down
754 */
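   /* Below, lsb_mask is the value of the lowest bit that is kept, rem holds
    * the discarded low bits and half is lsb_mask / 2, so round_up
    * implements the rule above.  When discard == 0 nothing is dropped: rem
    * and half are both zero and round_up stays false.
    */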
755 nir_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
756 nir_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
757 nir_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
758 nir_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
759 nir_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
760 nir_ine_imm(b, discard, 0));
761 nir_def *is_odd = COND_LOWER_CMP(b, ine, nir_imm_int64(b, 0),
762 COND_LOWER_OP(b, iand, x, lsb_mask));
763 nir_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
764 nir_iand(b, halfway, is_odd));
765 if (!nir_is_rounding_mode_rtz(b->shader->info.float_controls_execution_mode,
766 dest_bit_size)) {
767 if (significand_bits >= 32)
768 significand = COND_LOWER_OP(b, iadd, significand,
769 COND_LOWER_CAST(b, b2i64, round_up));
770 else
771 significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
772 }
773
774 nir_def *res;
775
776 if (dest_bit_size == 64) {
777 /* Compute the left shift required to normalize the original
778 * unrounded input manually.
779 */
780 nir_def *shift =
781 nir_imax(b, nir_isub_imm(b, significand_bits, exp),
782 nir_imm_int(b, 0));
783 significand = COND_LOWER_OP(b, ishl, significand, shift);
784
785 /* Check whether normalization led to overflow of the available
786 * significand bits, which can only happen if round_up was true
787 * above, in which case we need to add carry to the exponent and
788 * discard an extra bit from the significand. Note that we
789 * don't need to repeat the round-up logic again, since the LSB
790 * of the significand is guaranteed to be zero if there was
791 * overflow.
792 */
793 nir_def *carry = nir_b2i32(
794 b, nir_uge_imm(b, nir_unpack_64_2x32_split_y(b, significand),
795 (uint64_t)(1 << (significand_bits - 31))));
796 significand = COND_LOWER_OP(b, ishr, significand, carry);
797 exp = nir_iadd(b, exp, carry);
798
799 /* Compute the biased exponent, taking care to handle a zero
800 * input correctly, which would have caused exp to be negative.
801 */
802 nir_def *biased_exp = nir_bcsel(b, nir_ilt_imm(b, exp, 0),
803 nir_imm_int(b, 0),
804 nir_iadd_imm(b, exp, 1023));
805
806 /* Pack the significand and exponent manually. */
807 nir_def *lo = nir_unpack_64_2x32_split_x(b, significand);
808 nir_def *hi = nir_bitfield_insert(
809 b, nir_unpack_64_2x32_split_y(b, significand),
810 biased_exp, nir_imm_int(b, 20), nir_imm_int(b, 11));
811
812 res = nir_pack_64_2x32_split(b, lo, hi);
813
814 } else if (dest_bit_size == 32) {
815 res = nir_fmul(b, nir_u2f32(b, significand),
816 nir_fexp2(b, nir_u2f32(b, discard)));
817 } else {
818 res = nir_fmul(b, nir_u2f16(b, significand),
819 nir_fexp2(b, nir_u2f16(b, discard)));
820 }
821
822 if (src_is_signed)
823 res = nir_fmul(b, res, x_sign);
824
825 return res;
826 }
827
828 static nir_def *
829 lower_f2(nir_builder *b, nir_def *x, bool dst_is_signed)
830 {
831 assert(x->bit_size == 16 || x->bit_size == 32 || x->bit_size == 64);
832 nir_def *x_sign = NULL;
833
834 if (dst_is_signed)
835 x_sign = nir_fsign(b, x);
836
837 x = nir_ftrunc(b, x);
838
839 if (dst_is_signed)
840 x = nir_fabs(b, x);
841
842 nir_def *res;
843 if (x->bit_size < 32) {
844 res = nir_pack_64_2x32_split(b, nir_f2u32(b, x), nir_imm_int(b, 0));
845 } else {
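      /* Split the truncated value into 32-bit halves: the high word is
       * x / 2^32 (truncated) and the low word is frem(x, 2^32), i.e.
       * x mod 2^32.
       */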
846 nir_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
847 nir_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
848 nir_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
849 res = nir_pack_64_2x32_split(b, res_lo, res_hi);
850 }
851
852 if (dst_is_signed)
853 res = nir_bcsel(b, nir_flt_imm(b, x_sign, 0),
854 nir_ineg(b, res), res);
855
856 return res;
857 }
858
859 static nir_def *
860 lower_bit_count64(nir_builder *b, nir_def *x)
861 {
862 nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
863 nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
864 nir_def *lo_count = nir_bit_count(b, x_lo);
865 nir_def *hi_count = nir_bit_count(b, x_hi);
866 return nir_iadd(b, lo_count, hi_count);
867 }
868
869 nir_lower_int64_options
870 nir_lower_int64_op_to_options_mask(nir_op opcode)
871 {
872 switch (opcode) {
873 case nir_op_imul:
874 case nir_op_amul:
875 return nir_lower_imul64;
876 case nir_op_imul_2x32_64:
877 case nir_op_umul_2x32_64:
878 return nir_lower_imul_2x32_64;
879 case nir_op_imul_high:
880 case nir_op_umul_high:
881 return nir_lower_imul_high64;
882 case nir_op_isign:
883 return nir_lower_isign64;
884 case nir_op_udiv:
885 case nir_op_idiv:
886 case nir_op_umod:
887 case nir_op_imod:
888 case nir_op_irem:
889 return nir_lower_divmod64;
890 case nir_op_b2i64:
891 case nir_op_i2i8:
892 case nir_op_i2i16:
893 case nir_op_i2i32:
894 case nir_op_i2i64:
895 case nir_op_u2u8:
896 case nir_op_u2u16:
897 case nir_op_u2u32:
898 case nir_op_u2u64:
899 case nir_op_i2f64:
900 case nir_op_u2f64:
901 case nir_op_i2f32:
902 case nir_op_u2f32:
903 case nir_op_i2f16:
904 case nir_op_u2f16:
905 case nir_op_f2i64:
906 case nir_op_f2u64:
907 return nir_lower_conv64;
908 case nir_op_bcsel:
909 return nir_lower_bcsel64;
910 case nir_op_ieq:
911 case nir_op_ine:
912 case nir_op_ult:
913 case nir_op_ilt:
914 case nir_op_uge:
915 case nir_op_ige:
916 return nir_lower_icmp64;
917 case nir_op_iadd:
918 case nir_op_isub:
919 return nir_lower_iadd64;
920 case nir_op_imin:
921 case nir_op_imax:
922 case nir_op_umin:
923 case nir_op_umax:
924 return nir_lower_minmax64;
925 case nir_op_iabs:
926 return nir_lower_iabs64;
927 case nir_op_ineg:
928 return nir_lower_ineg64;
929 case nir_op_iand:
930 case nir_op_ior:
931 case nir_op_ixor:
932 case nir_op_inot:
933 return nir_lower_logic64;
934 case nir_op_ishl:
935 case nir_op_ishr:
936 case nir_op_ushr:
937 return nir_lower_shift64;
938 case nir_op_extract_u8:
939 case nir_op_extract_i8:
940 case nir_op_extract_u16:
941 case nir_op_extract_i16:
942 return nir_lower_extract64;
943 case nir_op_ufind_msb:
944 return nir_lower_ufind_msb64;
945 case nir_op_find_lsb:
946 return nir_lower_find_lsb64;
947 case nir_op_bit_count:
948 return nir_lower_bit_count64;
949 default:
950 return 0;
951 }
952 }
953
954 static nir_def *
955 lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
956 {
957 nir_def *src[4];
958 for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
959 src[i] = nir_ssa_for_alu_src(b, alu, i);
960
961 switch (alu->op) {
962 case nir_op_imul:
963 case nir_op_amul:
964 return lower_imul64(b, src[0], src[1]);
965 case nir_op_imul_2x32_64:
966 return lower_mul_2x32_64(b, src[0], src[1], true);
967 case nir_op_umul_2x32_64:
968 return lower_mul_2x32_64(b, src[0], src[1], false);
969 case nir_op_imul_high:
970 return lower_mul_high64(b, src[0], src[1], true);
971 case nir_op_umul_high:
972 return lower_mul_high64(b, src[0], src[1], false);
973 case nir_op_isign:
974 return lower_isign64(b, src[0]);
975 case nir_op_udiv:
976 return lower_udiv64(b, src[0], src[1]);
977 case nir_op_idiv:
978 return lower_idiv64(b, src[0], src[1]);
979 case nir_op_umod:
980 return lower_umod64(b, src[0], src[1]);
981 case nir_op_imod:
982 return lower_imod64(b, src[0], src[1]);
983 case nir_op_irem:
984 return lower_irem64(b, src[0], src[1]);
985 case nir_op_b2i64:
986 return lower_b2i64(b, src[0]);
987 case nir_op_i2i8:
988 return lower_i2i8(b, src[0]);
989 case nir_op_i2i16:
990 return lower_i2i16(b, src[0]);
991 case nir_op_i2i32:
992 return lower_i2i32(b, src[0]);
993 case nir_op_i2i64:
994 return lower_i2i64(b, src[0]);
995 case nir_op_u2u8:
996 return lower_u2u8(b, src[0]);
997 case nir_op_u2u16:
998 return lower_u2u16(b, src[0]);
999 case nir_op_u2u32:
1000 return lower_u2u32(b, src[0]);
1001 case nir_op_u2u64:
1002 return lower_u2u64(b, src[0]);
1003 case nir_op_bcsel:
1004 return lower_bcsel64(b, src[0], src[1], src[2]);
1005 case nir_op_ieq:
1006 case nir_op_ine:
1007 case nir_op_ult:
1008 case nir_op_ilt:
1009 case nir_op_uge:
1010 case nir_op_ige:
1011 return lower_int64_compare(b, alu->op, src[0], src[1]);
1012 case nir_op_iadd:
1013 return lower_iadd64(b, src[0], src[1]);
1014 case nir_op_isub:
1015 return lower_isub64(b, src[0], src[1]);
1016 case nir_op_imin:
1017 return lower_imin64(b, src[0], src[1]);
1018 case nir_op_imax:
1019 return lower_imax64(b, src[0], src[1]);
1020 case nir_op_umin:
1021 return lower_umin64(b, src[0], src[1]);
1022 case nir_op_umax:
1023 return lower_umax64(b, src[0], src[1]);
1024 case nir_op_iabs:
1025 return lower_iabs64(b, src[0]);
1026 case nir_op_ineg:
1027 return lower_ineg64(b, src[0]);
1028 case nir_op_iand:
1029 return lower_iand64(b, src[0], src[1]);
1030 case nir_op_ior:
1031 return lower_ior64(b, src[0], src[1]);
1032 case nir_op_ixor:
1033 return lower_ixor64(b, src[0], src[1]);
1034 case nir_op_inot:
1035 return lower_inot64(b, src[0]);
1036 case nir_op_ishl:
1037 return lower_ishl64(b, src[0], src[1]);
1038 case nir_op_ishr:
1039 return lower_ishr64(b, src[0], src[1]);
1040 case nir_op_ushr:
1041 return lower_ushr64(b, src[0], src[1]);
1042 case nir_op_extract_u8:
1043 case nir_op_extract_i8:
1044 case nir_op_extract_u16:
1045 case nir_op_extract_i16:
1046 return lower_extract(b, alu->op, src[0], src[1]);
1047 case nir_op_ufind_msb:
1048 return lower_ufind_msb64(b, src[0]);
1049 case nir_op_find_lsb:
1050 return lower_find_lsb64(b, src[0]);
1051 case nir_op_bit_count:
1052 return lower_bit_count64(b, src[0]);
1053 case nir_op_i2f64:
1054 case nir_op_i2f32:
1055 case nir_op_i2f16:
1056 return lower_2f(b, src[0], alu->def.bit_size, true);
1057 case nir_op_u2f64:
1058 case nir_op_u2f32:
1059 case nir_op_u2f16:
1060 return lower_2f(b, src[0], alu->def.bit_size, false);
1061 case nir_op_f2i64:
1062 case nir_op_f2u64:
1063 return lower_f2(b, src[0], alu->op == nir_op_f2i64);
1064 default:
1065 unreachable("Invalid ALU opcode to lower");
1066 }
1067 }
1068
1069 static bool
1070 should_lower_int64_alu_instr(const nir_alu_instr *alu,
1071 const nir_shader_compiler_options *options)
1072 {
1073 switch (alu->op) {
1074 case nir_op_i2i8:
1075 case nir_op_i2i16:
1076 case nir_op_i2i32:
1077 case nir_op_u2u8:
1078 case nir_op_u2u16:
1079 case nir_op_u2u32:
1080 if (alu->src[0].src.ssa->bit_size != 64)
1081 return false;
1082 break;
1083 case nir_op_bcsel:
1084 assert(alu->src[1].src.ssa->bit_size ==
1085 alu->src[2].src.ssa->bit_size);
1086 if (alu->src[1].src.ssa->bit_size != 64)
1087 return false;
1088 break;
1089 case nir_op_ieq:
1090 case nir_op_ine:
1091 case nir_op_ult:
1092 case nir_op_ilt:
1093 case nir_op_uge:
1094 case nir_op_ige:
1095 assert(alu->src[0].src.ssa->bit_size ==
1096 alu->src[1].src.ssa->bit_size);
1097 if (alu->src[0].src.ssa->bit_size != 64)
1098 return false;
1099 break;
1100 case nir_op_ufind_msb:
1101 case nir_op_find_lsb:
1102 case nir_op_bit_count:
1103 if (alu->src[0].src.ssa->bit_size != 64)
1104 return false;
1105 break;
1106 case nir_op_amul:
1107 if (options->has_imul24)
1108 return false;
1109 if (alu->def.bit_size != 64)
1110 return false;
1111 break;
1112 case nir_op_i2f64:
1113 case nir_op_u2f64:
1114 case nir_op_i2f32:
1115 case nir_op_u2f32:
1116 case nir_op_i2f16:
1117 case nir_op_u2f16:
1118 if (alu->src[0].src.ssa->bit_size != 64)
1119 return false;
1120 break;
1121 case nir_op_f2u64:
1122 case nir_op_f2i64:
1123 FALLTHROUGH;
1124 default:
1125 if (alu->def.bit_size != 64)
1126 return false;
1127 break;
1128 }
1129
1130 unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
1131 return (options->lower_int64_options & mask) != 0;
1132 }
1133
1134 static nir_def *
1135 split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
1136 {
1137 const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];
1138
1139 /* This works on subgroup ops with a single 64-bit source which can be
1140 * trivially lowered by doing the exact same op on both halves.
1141 */
1142 assert(nir_src_bit_size(intrin->src[0]) == 64);
1143 nir_def *split_src0[2] = {
1144 nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
1145 nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
1146 };
1147
1148 assert(info->has_dest && intrin->def.bit_size == 64);
1149
1150 nir_def *res[2];
1151 for (unsigned i = 0; i < 2; i++) {
1152 nir_intrinsic_instr *split =
1153 nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
1154 split->num_components = intrin->num_components;
1155 split->src[0] = nir_src_for_ssa(split_src0[i]);
1156
1157 /* Other sources must be less than 64 bits and get copied directly */
1158 for (unsigned j = 1; j < info->num_srcs; j++) {
1159 assert(nir_src_bit_size(intrin->src[j]) < 64);
1160 split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
1161 }
1162
1163 /* Copy const indices, if any */
1164 memcpy(split->const_index, intrin->const_index,
1165 sizeof(intrin->const_index));
1166
1167 nir_def_init(&split->instr, &split->def,
1168 intrin->def.num_components, 32);
1169 nir_builder_instr_insert(b, &split->instr);
1170
1171 res[i] = &split->def;
1172 }
1173
1174 return nir_pack_64_2x32_split(b, res[0], res[1]);
1175 }
1176
1177 static nir_def *
1178 build_vote_ieq(nir_builder *b, nir_def *x)
1179 {
1180 nir_intrinsic_instr *vote =
1181 nir_intrinsic_instr_create(b->shader, nir_intrinsic_vote_ieq);
1182 vote->src[0] = nir_src_for_ssa(x);
1183 vote->num_components = x->num_components;
1184 nir_def_init(&vote->instr, &vote->def, 1, 1);
1185 nir_builder_instr_insert(b, &vote->instr);
1186 return &vote->def;
1187 }
1188
1189 static nir_def *
1190 lower_vote_ieq(nir_builder *b, nir_def *x)
1191 {
1192 return nir_iand(b, build_vote_ieq(b, nir_unpack_64_2x32_split_x(b, x)),
1193 build_vote_ieq(b, nir_unpack_64_2x32_split_y(b, x)));
1194 }
1195
1196 static nir_def *
1197 build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op,
1198 nir_op reduction_op, unsigned cluster_size,
1199 nir_def *val)
1200 {
1201 nir_intrinsic_instr *scan =
1202 nir_intrinsic_instr_create(b->shader, scan_op);
1203 scan->num_components = val->num_components;
1204 scan->src[0] = nir_src_for_ssa(val);
1205 nir_intrinsic_set_reduction_op(scan, reduction_op);
1206 if (scan_op == nir_intrinsic_reduce)
1207 nir_intrinsic_set_cluster_size(scan, cluster_size);
1208 nir_def_init(&scan->instr, &scan->def, val->num_components,
1209 val->bit_size);
1210 nir_builder_instr_insert(b, &scan->instr);
1211 return &scan->def;
1212 }
1213
1214 static nir_def *
1215 lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin)
1216 {
1217 unsigned cluster_size =
1218 intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;
1219
1220 /* Split it into three chunks of no more than 24 bits each. With 8 bits
1221 * of headroom, we're guaranteed that there will never be overflow in the
1222 * individual subgroup operations. (Assuming, of course, a subgroup size
1223 * no larger than 256 which seems reasonable.) We can then scan on each of
1224 * the chunks and add them back together at the end.
1225 */
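   /* Concretely, each chunk is at most 2^24 - 1 and 256 * (2^24 - 1) is
    * still below 2^32, so the three 32-bit scans below cannot overflow.
    */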
1226 nir_def *x = intrin->src[0].ssa;
1227 nir_def *x_low =
1228 nir_u2u32(b, nir_iand_imm(b, x, 0xffffff));
1229 nir_def *x_mid =
1230 nir_u2u32(b, nir_iand_imm(b, nir_ushr_imm(b, x, 24),
1231 0xffffff));
1232 nir_def *x_hi =
1233 nir_u2u32(b, nir_ushr_imm(b, x, 48));
1234
1235 nir_def *scan_low =
1236 build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1237 cluster_size, x_low);
1238 nir_def *scan_mid =
1239 build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1240 cluster_size, x_mid);
1241 nir_def *scan_hi =
1242 build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1243 cluster_size, x_hi);
1244
1245 scan_low = nir_u2u64(b, scan_low);
1246 scan_mid = nir_ishl_imm(b, nir_u2u64(b, scan_mid), 24);
1247 scan_hi = nir_ishl_imm(b, nir_u2u64(b, scan_hi), 48);
1248
1249 return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low));
1250 }
1251
1252 static bool
1253 should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
1254 const nir_shader_compiler_options *options)
1255 {
1256 switch (intrin->intrinsic) {
1257 case nir_intrinsic_read_invocation:
1258 case nir_intrinsic_read_first_invocation:
1259 case nir_intrinsic_shuffle:
1260 case nir_intrinsic_shuffle_xor:
1261 case nir_intrinsic_shuffle_up:
1262 case nir_intrinsic_shuffle_down:
1263 case nir_intrinsic_quad_broadcast:
1264 case nir_intrinsic_quad_swap_horizontal:
1265 case nir_intrinsic_quad_swap_vertical:
1266 case nir_intrinsic_quad_swap_diagonal:
1267 return intrin->def.bit_size == 64 &&
1268 (options->lower_int64_options & nir_lower_subgroup_shuffle64);
1269
1270 case nir_intrinsic_vote_ieq:
1271 return intrin->src[0].ssa->bit_size == 64 &&
1272 (options->lower_int64_options & nir_lower_vote_ieq64);
1273
1274 case nir_intrinsic_reduce:
1275 case nir_intrinsic_inclusive_scan:
1276 case nir_intrinsic_exclusive_scan:
1277 if (intrin->def.bit_size != 64)
1278 return false;
1279
1280 switch (nir_intrinsic_reduction_op(intrin)) {
1281 case nir_op_iadd:
1282 return options->lower_int64_options & nir_lower_scan_reduce_iadd64;
1283 case nir_op_iand:
1284 case nir_op_ior:
1285 case nir_op_ixor:
1286 return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
1287 default:
1288 return false;
1289 }
1290 break;
1291
1292 default:
1293 return false;
1294 }
1295 }
1296
1297 static nir_def *
1298 lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
1299 {
1300 switch (intrin->intrinsic) {
1301 case nir_intrinsic_read_invocation:
1302 case nir_intrinsic_read_first_invocation:
1303 case nir_intrinsic_shuffle:
1304 case nir_intrinsic_shuffle_xor:
1305 case nir_intrinsic_shuffle_up:
1306 case nir_intrinsic_shuffle_down:
1307 case nir_intrinsic_quad_broadcast:
1308 case nir_intrinsic_quad_swap_horizontal:
1309 case nir_intrinsic_quad_swap_vertical:
1310 case nir_intrinsic_quad_swap_diagonal:
1311 return split_64bit_subgroup_op(b, intrin);
1312
1313 case nir_intrinsic_vote_ieq:
1314 return lower_vote_ieq(b, intrin->src[0].ssa);
1315
1316 case nir_intrinsic_reduce:
1317 case nir_intrinsic_inclusive_scan:
1318 case nir_intrinsic_exclusive_scan:
1319 switch (nir_intrinsic_reduction_op(intrin)) {
1320 case nir_op_iadd:
1321 return lower_scan_iadd64(b, intrin);
1322 case nir_op_iand:
1323 case nir_op_ior:
1324 case nir_op_ixor:
1325 return split_64bit_subgroup_op(b, intrin);
1326 default:
1327 unreachable("Unsupported subgroup scan/reduce op");
1328 }
1329 break;
1330
1331 default:
1332 unreachable("Unsupported intrinsic");
1333 }
1334 return NULL;
1335 }
1336
1337 static bool
1338 should_lower_int64_instr(const nir_instr *instr, const void *_options)
1339 {
1340 switch (instr->type) {
1341 case nir_instr_type_alu:
1342 return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
1343 case nir_instr_type_intrinsic:
1344 return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
1345 _options);
1346 default:
1347 return false;
1348 }
1349 }
1350
1351 static nir_def *
1352 lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
1353 {
1354 switch (instr->type) {
1355 case nir_instr_type_alu:
1356 return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
1357 case nir_instr_type_intrinsic:
1358 return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
1359 default:
1360 return NULL;
1361 }
1362 }
1363
1364 bool
1365 nir_lower_int64(nir_shader *shader)
1366 {
1367 return nir_shader_lower_instructions(shader, should_lower_int64_instr,
1368 lower_int64_instr,
1369 (void *)shader->options);
1370 }
1371
1372 static bool
1373 should_lower_int64_float_conv(const nir_instr *instr, const void *_options)
1374 {
1375 if (instr->type != nir_instr_type_alu)
1376 return false;
1377
1378 nir_alu_instr *alu = nir_instr_as_alu(instr);
1379
1380 switch (alu->op) {
1381 case nir_op_i2f64:
1382 case nir_op_i2f32:
1383 case nir_op_i2f16:
1384 case nir_op_u2f64:
1385 case nir_op_u2f32:
1386 case nir_op_u2f16:
1387 case nir_op_f2i64:
1388 case nir_op_f2u64:
1389 return should_lower_int64_alu_instr(alu, _options);
1390 default:
1391 return false;
1392 }
1393 }
1394
1395 /**
1396 * Like nir_lower_int64(), but only lowers conversions to/from float.
1397 *
1398 * These operations in particular may affect double-precision lowering,
1399 * so it can be useful to run them in tandem with nir_lower_doubles().
1400 */
1401 bool
1402 nir_lower_int64_float_conversions(nir_shader *shader)
1403 {
1404 return nir_shader_lower_instructions(shader, should_lower_int64_float_conv,
1405 lower_int64_instr,
1406 (void *)shader->options);
1407 }
1408