/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

#define COND_LOWER_OP(b, name, ...)                    \
   (b->shader->options->lower_int64_options &          \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_##name##64(b, __VA_ARGS__)               \
      : nir_##name(b, __VA_ARGS__)

#define COND_LOWER_CMP(b, name, ...)                       \
   (b->shader->options->lower_int64_options &              \
    nir_lower_int64_op_to_options_mask(nir_op_##name))     \
      ? lower_int64_compare(b, nir_op_##name, __VA_ARGS__) \
      : nir_##name(b, __VA_ARGS__)

#define COND_LOWER_CAST(b, name, ...)                  \
   (b->shader->options->lower_int64_options &          \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_##name(b, __VA_ARGS__)                   \
      : nir_##name(b, __VA_ARGS__)

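/* Booleans become 64-bit integers by converting to a 32-bit 0/1 in the low
 * dword and zeroing the high dword.
 */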
static nir_def *
lower_b2i64(nir_builder *b, nir_def *x)
{
   return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
}

static nir_def *
lower_i2i8(nir_builder *b, nir_def *x)
{
   return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_i2i16(nir_builder *b, nir_def *x)
{
   return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_i2i32(nir_builder *b, nir_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

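/* Sign-extend to 64 bits: the high dword is the low dword's sign bit
 * replicated by an arithmetic shift right of 31.
 */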
static nir_def *
lower_i2i64(nir_builder *b, nir_def *x)
{
   nir_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
}

static nir_def *
lower_u2u8(nir_builder *b, nir_def *x)
{
   return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_u2u16(nir_builder *b, nir_def *x)
{
   return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_u2u32(nir_builder *b, nir_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_def *
lower_u2u64(nir_builder *b, nir_def *x)
{
   nir_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
}

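/* A 64-bit bcsel lowers to two 32-bit bcsels, one per dword. */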
static nir_def *
lower_bcsel64(nir_builder *b, nir_def *cond, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
                                 nir_bcsel(b, cond, x_hi, y_hi));
}

static nir_def *
lower_inot64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
}

static nir_def *
lower_iand64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
                                 nir_iand(b, x_hi, y_hi));
}

static nir_def *
lower_ior64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
                                 nir_ior(b, x_hi, y_hi));
}

static nir_def *
lower_ixor64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
                                 nir_ixor(b, x_hi, y_hi));
}

static nir_def *
lower_ishl64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t lshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo << c;
    *       uint32_t hi_shifted = hi << c;
    *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
    *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
    *    } else {
    *       uint32_t lo_shifted_hi = lo << abs(32 - c);
    *       return pack_64(0, lo_shifted_hi);
    *    }
    * }
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   y = nir_iand_imm(b, y, 0x3f);

   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ishl(b, x_lo, y);
   nir_def *hi_shifted = nir_ishl(b, x_hi, y);
   nir_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, lo_shifted,
                             nir_ior(b, hi_shifted, lo_shifted_hi));
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
                             nir_ishl(b, x_lo, reverse_count));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}

static nir_def *
lower_ishr64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t arshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x);
    *    int32_t  hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted = hi >> 31;
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, hi_shifted);
    *    }
    * }
    *
    * where pack_64(lo, hi) takes the low dword first, matching
    * nir_pack_64_2x32_split() and the lshift comment above.
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   y = nir_iand_imm(b, y, 0x3f);

   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_def *hi_shifted = nir_ishr(b, x_hi, y);
   nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
                             nir_ishr_imm(b, x_hi, 31));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}

static nir_def *
lower_ushr64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t rshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, 0);
    *    }
    * }
    *
    * with the same pack_64(lo, hi) convention as above.
    */

   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   y = nir_iand_imm(b, y, 0x3f);

   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_def *hi_shifted = nir_ushr(b, x_hi, y);
   nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
                             nir_imm_int(b, 0));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}

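/* Add the low dwords first; the add carried iff the unsigned 32-bit result
 * is smaller than one of its operands, so fold that carry into the sum of
 * the high dwords.
 */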
static nir_def *
lower_iadd64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_def *res_lo = nir_iadd(b, x_lo, y_lo);
   nir_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
   nir_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

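/* Subtract the low dwords first; a borrow occurred iff x_lo < y_lo
 * (unsigned), so subtract that borrow from the difference of the high
 * dwords.
 */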
static nir_def *
lower_isub64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_def *res_lo = nir_isub(b, x_lo, y_lo);
   nir_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
   nir_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static nir_def *
lower_ineg64(nir_builder *b, nir_def *x)
{
   /* Since isub is the same number of instructions (with better dependencies)
    * as iadd, subtraction is actually more efficient for ineg than the usual
    * 2's complement "flip the bits and add one".
    */
   return lower_isub64(b, nir_imm_int64(b, 0), x);
}

static nir_def *
lower_iabs64(nir_builder *b, nir_def *x)
{
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *x_is_neg = nir_ilt_imm(b, x_hi, 0);
   return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
}

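/* 64-bit comparisons decompose into 32-bit comparisons on the dwords:
 * equality requires both halves to match, and ordering is decided by the
 * high dwords unless they are equal, in which case the low dwords (always
 * compared unsigned) break the tie.
 */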
static nir_def *
lower_int64_compare(nir_builder *b, nir_op op, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   switch (op) {
   case nir_op_ieq:
      return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
   case nir_op_ine:
      return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
   case nir_op_ult:
      return nir_ior(b, nir_ult(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_ilt:
      return nir_ior(b, nir_ilt(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_uge:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
   case nir_op_ige:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
   default:
      unreachable("Invalid comparison");
   }
}

static nir_def *
lower_umax64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, COND_LOWER_CMP(b, ult, x, y), y, x);
}

static nir_def *
lower_imax64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, y), y, x);
}

static nir_def *
lower_umin64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, COND_LOWER_CMP(b, ult, x, y), x, y);
}

static nir_def *
lower_imin64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, y), x, y);
}

static nir_def *
lower_mul_2x32_64(nir_builder *b, nir_def *x, nir_def *y,
                  bool sign_extend)
{
   nir_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
                                 : nir_umul_high(b, x, y);

   return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
}

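/* Schoolbook multiplication on 32-bit halves.  The low 64 bits of x * y are
 * the full product of the low dwords plus the two cross terms shifted into
 * the high dword; x_hi * y_hi only contributes to bits 64 and above and is
 * dropped.
 */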
static nir_def *
lower_imul64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
   nir_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
                              nir_iadd(b, nir_imul(b, x_lo, y_hi),
                                       nir_imul(b, x_hi, y_lo)));

   return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
                                 res_hi);
}

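/* Compute the full 128-bit product as eight 32-bit limbs using long
 * multiplication and return limbs 2 and 3, i.e. bits [64, 127] of the
 * product.
 */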
static nir_def *
lower_mul_high64(nir_builder *b, nir_def *x, nir_def *y,
                 bool sign_extend)
{
   nir_def *x32[4], *y32[4];
   x32[0] = nir_unpack_64_2x32_split_x(b, x);
   x32[1] = nir_unpack_64_2x32_split_y(b, x);
   if (sign_extend) {
      x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
   } else {
      x32[2] = x32[3] = nir_imm_int(b, 0);
   }

   y32[0] = nir_unpack_64_2x32_split_x(b, y);
   y32[1] = nir_unpack_64_2x32_split_y(b, y);
   if (sign_extend) {
      y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
   } else {
      y32[2] = y32[3] = nir_imm_int(b, 0);
   }

   nir_def *res[8] = {
      NULL,
   };

   /* Yes, the following generates a pile of code.  However, we throw res[0]
    * and res[1] away in the end and, if we're in the umul case, four of our
    * eight dword operands will be constant zero and opt_algebraic will clean
    * this up nicely.
    */
   for (unsigned i = 0; i < 4; i++) {
      nir_def *carry = NULL;
      for (unsigned j = 0; j < 4; j++) {
         /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
          * value that will fit in tmp is
          *
          *    UINT64_MAX = (UINT32_MAX << 32) + UINT32_MAX
          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
          *
          * so we're guaranteed that we can add in two more 32-bit values
          * without overflowing tmp.
          */
         nir_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);

         if (res[i + j])
            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
         if (carry)
            tmp = nir_iadd(b, tmp, carry);
         res[i + j] = nir_u2u32(b, tmp);
         carry = nir_ushr_imm(b, tmp, 32);
      }
      res[i + 4] = nir_u2u32(b, carry);
   }

   return nir_pack_64_2x32_split(b, res[2], res[3]);
}

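/* isign(x) is -1, 0, or 1.  res_hi = x_hi >> 31 is all ones for negative x
 * and zero otherwise; OR-ing in a 1 when any bit of x is set turns the low
 * dword into exactly -1, 1, or 0.
 */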
static nir_def *
lower_isign64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
   nir_def *res_hi = nir_ishr_imm(b, x_hi, 31);
   nir_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static void
lower_udiv64_mod64(nir_builder *b, nir_def *n, nir_def *d,
                   nir_def **q, nir_def **r)
{
   /* TODO: We should specially handle the case where the denominator is a
    * constant.  In that case, we should be able to reduce it to a multiply by
    * a constant, some shifts, and an add.
    */
   nir_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_def *q_lo = nir_imm_zero(b, n->num_components, 32);
   nir_def *q_hi = nir_imm_zero(b, n->num_components, 32);

   nir_def *n_hi_before_if = n_hi;
   nir_def *q_hi_before_if = q_hi;

   /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
    * greater than 32 bits to occur.  If the upper 32 bits of the numerator
    * are zero, it is impossible for (denom << [63, 32]) <= numer unless
    * denom == 0.
    */
   nir_def *need_high_div =
      nir_iand(b, nir_ieq_imm(b, d_hi, 0), nir_uge(b, n_hi, d_lo));
   nir_push_if(b, nir_bany(b, need_high_div));
   {
      /* If we only have one component, then the bany above goes away and
       * this is always true within the if statement.
       */
      if (n->num_components == 1)
         need_high_div = nir_imm_true(b);

      nir_def *log2_d_lo = nir_ufind_msb(b, d_lo);

      for (int i = 31; i >= 0; i--) {
         /* if ((d.x << i) <= n.y) {
          *    n.y -= d.x << i;
          *    quot.y |= 1U << i;
          * }
          */
         nir_def *d_shift = nir_ishl_imm(b, d_lo, i);
         nir_def *new_n_hi = nir_isub(b, n_hi, d_shift);
         nir_def *new_q_hi = nir_ior_imm(b, q_hi, 1ull << i);
         nir_def *cond = nir_iand(b, need_high_div,
                                  nir_uge(b, n_hi, d_shift));
         if (i != 0) {
            /* log2_d_lo is always <= 31, so we don't need to bother with it
             * in the last iteration.
             */
            cond = nir_iand(b, cond,
                            nir_ile_imm(b, log2_d_lo, 31 - i));
         }
         n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
         q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
      }
   }
   nir_pop_if(b, NULL);
   n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
   q_hi = nir_if_phi(b, q_hi, q_hi_before_if);

   nir_def *log2_denom = nir_ufind_msb(b, d_hi);

   n = nir_pack_64_2x32_split(b, n_lo, n_hi);
   d = nir_pack_64_2x32_split(b, d_lo, d_hi);
   for (int i = 31; i >= 0; i--) {
      /* if ((d64 << i) <= n64) {
       *    n64 -= d64 << i;
       *    quot.x |= 1U << i;
       * }
       */
      nir_def *d_shift = nir_ishl_imm(b, d, i);
      nir_def *new_n = nir_isub(b, n, d_shift);
      nir_def *new_q_lo = nir_ior_imm(b, q_lo, 1ull << i);
      nir_def *cond = nir_uge(b, n, d_shift);
      if (i != 0) {
         /* log2_denom is always <= 31, so we don't need to bother with it
          * in the last iteration.
          */
         cond = nir_iand(b, cond,
                         nir_ile_imm(b, log2_denom, 31 - i));
      }
      n = nir_bcsel(b, cond, new_n, n);
      q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
   }

   *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
   *r = n;
}

static nir_def *
lower_udiv64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return q;
}

static nir_def *
lower_idiv64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_def *negate = nir_ine(b, nir_ilt_imm(b, n_hi, 0),
                             nir_ilt_imm(b, d_hi, 0));
   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, negate, nir_ineg(b, q), q);
}

static nir_def *
lower_umod64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return r;
}

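/* imod follows the sign of the divisor: compute |n| % |d|, give the result
 * the sign of n, and when the operand signs differ and the remainder is
 * non-zero, add d to wrap it onto the divisor's side.
 */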
static nir_def *
lower_imod64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
   nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);
   nir_def *d_is_neg = nir_ilt_imm(b, d_hi, 0);

   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);

   nir_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);

   return nir_bcsel(b, nir_ieq_imm(b, r, 0), nir_imm_int64(b, 0),
                    nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
                              nir_iadd(b, rem, d)));
}

static nir_def *
lower_irem64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);

   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
}

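/* extract_[iu](8|16) pulls out the c'th 8- or 16-bit chunk of a 64-bit
 * value.  Pick the dword containing the chunk, do the 32-bit extract there,
 * then sign- or zero-extend back to 64 bits.
 */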
static nir_def *
lower_extract(nir_builder *b, nir_op op, nir_def *x, nir_def *c)
{
   assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
          op == nir_op_extract_u16 || op == nir_op_extract_i16);

   const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
   const int chunk_bits =
      (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
   const int num_chunks_in_32 = 32 / chunk_bits;

   nir_def *extract32;
   if (chunk < num_chunks_in_32) {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
                                nir_imm_int(b, chunk),
                                NULL, NULL);
   } else {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
                                nir_imm_int(b, chunk - num_chunks_in_32),
                                NULL, NULL);
   }

   if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
      return lower_i2i64(b, extract32);
   else
      return lower_u2u64(b, extract32);
}

static nir_def *
lower_ufind_msb64(nir_builder *b, nir_def *x)
{

   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_count = nir_ufind_msb(b, x_lo);
   nir_def *hi_count = nir_ufind_msb(b, x_hi);

   /* hi_count is either -1 or a value in the range [31, 0]. lo_count is
    * the same. The imax will pick lo_count only when hi_count is -1. In those
    * cases, lo_count is guaranteed to be the correct answer.
    * The ior 32 is always safe here as with -1 the value won't change,
    * otherwise it adds 32, which is what we want anyway.
    */
   return nir_imax(b, lo_count, nir_ior_imm(b, hi_count, 32));
}

static nir_def *
lower_find_lsb64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_lsb = nir_find_lsb(b, x_lo);
   nir_def *hi_lsb = nir_find_lsb(b, x_hi);

   /* Use umin so that -1 (no bits found) becomes larger (0xFFFFFFFF)
    * than any actual bit position, so we return a found bit instead.
    * This is similar to the ufind_msb lowering.
    */
   return nir_umin(b, lo_lsb, nir_ior_imm(b, hi_lsb, 32));
}

static nir_def *
lower_2f(nir_builder *b, nir_def *x, unsigned dest_bit_size,
         bool src_is_signed)
{
   nir_def *x_sign = NULL;

   if (src_is_signed) {
      x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
                         nir_imm_floatN_t(b, -1, dest_bit_size),
                         nir_imm_floatN_t(b, 1, dest_bit_size));
      x = COND_LOWER_OP(b, iabs, x);
   }

   nir_def *exp = COND_LOWER_OP(b, ufind_msb, x);
   unsigned significand_bits;

   switch (dest_bit_size) {
   case 64:
      significand_bits = 52;
      break;
   case 32:
      significand_bits = 23;
      break;
   case 16:
      significand_bits = 10;
      break;
   default:
      unreachable("Invalid dest_bit_size");
   }

   nir_def *discard =
      nir_imax(b, nir_iadd_imm(b, exp, -significand_bits),
               nir_imm_int(b, 0));
   nir_def *significand = COND_LOWER_OP(b, ushr, x, discard);
   if (significand_bits < 32)
      significand = COND_LOWER_CAST(b, u2u32, significand);

   /* Round-to-nearest-even implementation:
    * - if the non-representable part of the significand is higher than half
    *   the minimum representable significand, we round-up
    * - if the non-representable part of the significand is equal to half the
    *   minimum representable significand and the representable part of the
    *   significand is odd, we round-up
    * - in any other case, we round-down
    */
   nir_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
   nir_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
   nir_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
   nir_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
   nir_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
                               nir_ine_imm(b, discard, 0));
   nir_def *is_odd = COND_LOWER_CMP(b, ine, nir_imm_int64(b, 0),
                                    COND_LOWER_OP(b, iand, x, lsb_mask));
   nir_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
                               nir_iand(b, halfway, is_odd));
   if (!nir_is_rounding_mode_rtz(b->shader->info.float_controls_execution_mode,
                                 dest_bit_size)) {
      if (significand_bits >= 32)
         significand = COND_LOWER_OP(b, iadd, significand,
                                     COND_LOWER_CAST(b, b2i64, round_up));
      else
         significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
   }

   nir_def *res;

   if (dest_bit_size == 64) {
      /* Compute the left shift required to normalize the original
       * unrounded input manually.
       */
      nir_def *shift =
         nir_imax(b, nir_isub_imm(b, significand_bits, exp),
                  nir_imm_int(b, 0));
      significand = COND_LOWER_OP(b, ishl, significand, shift);

      /* Check whether normalization led to overflow of the available
       * significand bits, which can only happen if round_up was true
       * above, in which case we need to add carry to the exponent and
       * discard an extra bit from the significand.  Note that we
       * don't need to repeat the round-up logic again, since the LSB
       * of the significand is guaranteed to be zero if there was
       * overflow.
       */
      nir_def *carry = nir_b2i32(
         b, nir_uge_imm(b, nir_unpack_64_2x32_split_y(b, significand),
                        (uint64_t)(1 << (significand_bits - 31))));
      significand = COND_LOWER_OP(b, ishr, significand, carry);
      exp = nir_iadd(b, exp, carry);

      /* Compute the biased exponent, taking care to handle a zero
       * input correctly, which would have caused exp to be negative.
       */
      nir_def *biased_exp = nir_bcsel(b, nir_ilt_imm(b, exp, 0),
                                      nir_imm_int(b, 0),
                                      nir_iadd_imm(b, exp, 1023));

      /* Pack the significand and exponent manually. */
      nir_def *lo = nir_unpack_64_2x32_split_x(b, significand);
      nir_def *hi = nir_bitfield_insert(
         b, nir_unpack_64_2x32_split_y(b, significand),
         biased_exp, nir_imm_int(b, 20), nir_imm_int(b, 11));

      res = nir_pack_64_2x32_split(b, lo, hi);

   } else if (dest_bit_size == 32) {
      res = nir_fmul(b, nir_u2f32(b, significand),
                     nir_fexp2(b, nir_u2f32(b, discard)));
   } else {
      res = nir_fmul(b, nir_u2f16(b, significand),
                     nir_fexp2(b, nir_u2f16(b, discard)));
   }

   if (src_is_signed)
      res = nir_fmul(b, res, x_sign);

   return res;
}

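/* Convert a float to [iu]64 by truncating, splitting into high and low
 * dwords with a divide and remainder by 2^32, and re-applying the sign at
 * the end for signed destinations.
 */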
static nir_def *
lower_f2(nir_builder *b, nir_def *x, bool dst_is_signed)
{
   assert(x->bit_size == 16 || x->bit_size == 32 || x->bit_size == 64);
   nir_def *x_sign = NULL;

   if (dst_is_signed)
      x_sign = nir_fsign(b, x);

   x = nir_ftrunc(b, x);

   if (dst_is_signed)
      x = nir_fabs(b, x);

   nir_def *res;
   if (x->bit_size < 32) {
      res = nir_pack_64_2x32_split(b, nir_f2u32(b, x), nir_imm_int(b, 0));
   } else {
      nir_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
      nir_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
      nir_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
      res = nir_pack_64_2x32_split(b, res_lo, res_hi);
   }

   if (dst_is_signed)
      res = nir_bcsel(b, nir_flt_imm(b, x_sign, 0),
                      nir_ineg(b, res), res);

   return res;
}

static nir_def *
lower_bit_count64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_count = nir_bit_count(b, x_lo);
   nir_def *hi_count = nir_bit_count(b, x_hi);
   return nir_iadd(b, lo_count, hi_count);
}

nir_lower_int64_options
nir_lower_int64_op_to_options_mask(nir_op opcode)
{
   switch (opcode) {
   case nir_op_imul:
   case nir_op_amul:
      return nir_lower_imul64;
   case nir_op_imul_2x32_64:
   case nir_op_umul_2x32_64:
      return nir_lower_imul_2x32_64;
   case nir_op_imul_high:
   case nir_op_umul_high:
      return nir_lower_imul_high64;
   case nir_op_isign:
      return nir_lower_isign64;
   case nir_op_udiv:
   case nir_op_idiv:
   case nir_op_umod:
   case nir_op_imod:
   case nir_op_irem:
      return nir_lower_divmod64;
   case nir_op_b2i64:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_i2i64:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
   case nir_op_u2u64:
   case nir_op_i2f64:
   case nir_op_u2f64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
   case nir_op_f2i64:
   case nir_op_f2u64:
      return nir_lower_conv64;
   case nir_op_bcsel:
      return nir_lower_bcsel64;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return nir_lower_icmp64;
   case nir_op_iadd:
   case nir_op_isub:
      return nir_lower_iadd64;
   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
      return nir_lower_minmax64;
   case nir_op_iabs:
      return nir_lower_iabs64;
   case nir_op_ineg:
      return nir_lower_ineg64;
   case nir_op_iand:
   case nir_op_ior:
   case nir_op_ixor:
   case nir_op_inot:
      return nir_lower_logic64;
   case nir_op_ishl:
   case nir_op_ishr:
   case nir_op_ushr:
      return nir_lower_shift64;
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return nir_lower_extract64;
   case nir_op_ufind_msb:
      return nir_lower_ufind_msb64;
   case nir_op_find_lsb:
      return nir_lower_find_lsb64;
   case nir_op_bit_count:
      return nir_lower_bit_count64;
   default:
      return 0;
   }
}

static nir_def *
lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
{
   nir_def *src[4];
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      src[i] = nir_ssa_for_alu_src(b, alu, i);

   switch (alu->op) {
   case nir_op_imul:
   case nir_op_amul:
      return lower_imul64(b, src[0], src[1]);
   case nir_op_imul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], true);
   case nir_op_umul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], false);
   case nir_op_imul_high:
      return lower_mul_high64(b, src[0], src[1], true);
   case nir_op_umul_high:
      return lower_mul_high64(b, src[0], src[1], false);
   case nir_op_isign:
      return lower_isign64(b, src[0]);
   case nir_op_udiv:
      return lower_udiv64(b, src[0], src[1]);
   case nir_op_idiv:
      return lower_idiv64(b, src[0], src[1]);
   case nir_op_umod:
      return lower_umod64(b, src[0], src[1]);
   case nir_op_imod:
      return lower_imod64(b, src[0], src[1]);
   case nir_op_irem:
      return lower_irem64(b, src[0], src[1]);
   case nir_op_b2i64:
      return lower_b2i64(b, src[0]);
   case nir_op_i2i8:
      return lower_i2i8(b, src[0]);
   case nir_op_i2i16:
      return lower_i2i16(b, src[0]);
   case nir_op_i2i32:
      return lower_i2i32(b, src[0]);
   case nir_op_i2i64:
      return lower_i2i64(b, src[0]);
   case nir_op_u2u8:
      return lower_u2u8(b, src[0]);
   case nir_op_u2u16:
      return lower_u2u16(b, src[0]);
   case nir_op_u2u32:
      return lower_u2u32(b, src[0]);
   case nir_op_u2u64:
      return lower_u2u64(b, src[0]);
   case nir_op_bcsel:
      return lower_bcsel64(b, src[0], src[1], src[2]);
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return lower_int64_compare(b, alu->op, src[0], src[1]);
   case nir_op_iadd:
      return lower_iadd64(b, src[0], src[1]);
   case nir_op_isub:
      return lower_isub64(b, src[0], src[1]);
   case nir_op_imin:
      return lower_imin64(b, src[0], src[1]);
   case nir_op_imax:
      return lower_imax64(b, src[0], src[1]);
   case nir_op_umin:
      return lower_umin64(b, src[0], src[1]);
   case nir_op_umax:
      return lower_umax64(b, src[0], src[1]);
   case nir_op_iabs:
      return lower_iabs64(b, src[0]);
   case nir_op_ineg:
      return lower_ineg64(b, src[0]);
   case nir_op_iand:
      return lower_iand64(b, src[0], src[1]);
   case nir_op_ior:
      return lower_ior64(b, src[0], src[1]);
   case nir_op_ixor:
      return lower_ixor64(b, src[0], src[1]);
   case nir_op_inot:
      return lower_inot64(b, src[0]);
   case nir_op_ishl:
      return lower_ishl64(b, src[0], src[1]);
   case nir_op_ishr:
      return lower_ishr64(b, src[0], src[1]);
   case nir_op_ushr:
      return lower_ushr64(b, src[0], src[1]);
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return lower_extract(b, alu->op, src[0], src[1]);
   case nir_op_ufind_msb:
      return lower_ufind_msb64(b, src[0]);
   case nir_op_find_lsb:
      return lower_find_lsb64(b, src[0]);
   case nir_op_bit_count:
      return lower_bit_count64(b, src[0]);
   case nir_op_i2f64:
   case nir_op_i2f32:
   case nir_op_i2f16:
      return lower_2f(b, src[0], alu->def.bit_size, true);
   case nir_op_u2f64:
   case nir_op_u2f32:
   case nir_op_u2f16:
      return lower_2f(b, src[0], alu->def.bit_size, false);
   case nir_op_f2i64:
   case nir_op_f2u64:
      return lower_f2(b, src[0], alu->op == nir_op_f2i64);
   default:
      unreachable("Invalid ALU opcode to lower");
   }
}

static bool
should_lower_int64_alu_instr(const nir_alu_instr *alu,
                             const nir_shader_compiler_options *options)
{
   switch (alu->op) {
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_bcsel:
      assert(alu->src[1].src.ssa->bit_size ==
             alu->src[2].src.ssa->bit_size);
      if (alu->src[1].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      assert(alu->src[0].src.ssa->bit_size ==
             alu->src[1].src.ssa->bit_size);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ufind_msb:
   case nir_op_find_lsb:
   case nir_op_bit_count:
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_amul:
      if (options->has_imul24)
         return false;
      if (alu->def.bit_size != 64)
         return false;
      break;
   case nir_op_i2f64:
   case nir_op_u2f64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_f2u64:
   case nir_op_f2i64:
      FALLTHROUGH;
   default:
      if (alu->def.bit_size != 64)
         return false;
      break;
   }

   unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
   return (options->lower_int64_options & mask) != 0;
}

static nir_def *
split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
{
   const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];

   /* This works on subgroup ops with a single 64-bit source which can be
    * trivially lowered by doing the exact same op on both halves.
    */
   assert(nir_src_bit_size(intrin->src[0]) == 64);
   nir_def *split_src0[2] = {
      nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
      nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
   };

   assert(info->has_dest && intrin->def.bit_size == 64);

   nir_def *res[2];
   for (unsigned i = 0; i < 2; i++) {
      nir_intrinsic_instr *split =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      split->num_components = intrin->num_components;
      split->src[0] = nir_src_for_ssa(split_src0[i]);

      /* Other sources must be less than 64 bits and get copied directly */
      for (unsigned j = 1; j < info->num_srcs; j++) {
         assert(nir_src_bit_size(intrin->src[j]) < 64);
         split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
      }

      /* Copy const indices, if any */
      memcpy(split->const_index, intrin->const_index,
             sizeof(intrin->const_index));

      nir_def_init(&split->instr, &split->def,
                   intrin->def.num_components, 32);
      nir_builder_instr_insert(b, &split->instr);

      res[i] = &split->def;
   }

   return nir_pack_64_2x32_split(b, res[0], res[1]);
}

static nir_def *
build_vote_ieq(nir_builder *b, nir_def *x)
{
   nir_intrinsic_instr *vote =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_vote_ieq);
   vote->src[0] = nir_src_for_ssa(x);
   vote->num_components = x->num_components;
   nir_def_init(&vote->instr, &vote->def, 1, 1);
   nir_builder_instr_insert(b, &vote->instr);
   return &vote->def;
}

static nir_def *
lower_vote_ieq(nir_builder *b, nir_def *x)
{
   return nir_iand(b, build_vote_ieq(b, nir_unpack_64_2x32_split_x(b, x)),
                   build_vote_ieq(b, nir_unpack_64_2x32_split_y(b, x)));
}

static nir_def *
build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op,
                     nir_op reduction_op, unsigned cluster_size,
                     nir_def *val)
{
   nir_intrinsic_instr *scan =
      nir_intrinsic_instr_create(b->shader, scan_op);
   scan->num_components = val->num_components;
   scan->src[0] = nir_src_for_ssa(val);
   nir_intrinsic_set_reduction_op(scan, reduction_op);
   if (scan_op == nir_intrinsic_reduce)
      nir_intrinsic_set_cluster_size(scan, cluster_size);
   nir_def_init(&scan->instr, &scan->def, val->num_components,
                val->bit_size);
   nir_builder_instr_insert(b, &scan->instr);
   return &scan->def;
}

static nir_def *
lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin)
{
   unsigned cluster_size =
      intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;

   /* Split it into three chunks of no more than 24 bits each.  With 8 bits
    * of headroom, we're guaranteed that there will never be overflow in the
    * individual subgroup operations.  (Assuming, of course, a subgroup size
    * no larger than 256 which seems reasonable.)  We can then scan on each of
    * the chunks and add them back together at the end.
    */
   nir_def *x = intrin->src[0].ssa;
   nir_def *x_low =
      nir_u2u32(b, nir_iand_imm(b, x, 0xffffff));
   nir_def *x_mid =
      nir_u2u32(b, nir_iand_imm(b, nir_ushr_imm(b, x, 24),
                                0xffffff));
   nir_def *x_hi =
      nir_u2u32(b, nir_ushr_imm(b, x, 48));

   nir_def *scan_low =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                           cluster_size, x_low);
   nir_def *scan_mid =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                           cluster_size, x_mid);
   nir_def *scan_hi =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                           cluster_size, x_hi);

   scan_low = nir_u2u64(b, scan_low);
   scan_mid = nir_ishl_imm(b, nir_u2u64(b, scan_mid), 24);
   scan_hi = nir_ishl_imm(b, nir_u2u64(b, scan_hi), 48);

   return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low));
}

static bool
should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
                             const nir_shader_compiler_options *options)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      return intrin->def.bit_size == 64 &&
             (options->lower_int64_options & nir_lower_subgroup_shuffle64);

   case nir_intrinsic_vote_ieq:
      return intrin->src[0].ssa->bit_size == 64 &&
             (options->lower_int64_options & nir_lower_vote_ieq64);

   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      if (intrin->def.bit_size != 64)
         return false;

      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_iadd:
         return options->lower_int64_options & nir_lower_scan_reduce_iadd64;
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
         return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
      default:
         return false;
      }
      break;

   default:
      return false;
   }
}

static nir_def *
lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      return split_64bit_subgroup_op(b, intrin);

   case nir_intrinsic_vote_ieq:
      return lower_vote_ieq(b, intrin->src[0].ssa);

   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_iadd:
         return lower_scan_iadd64(b, intrin);
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
         return split_64bit_subgroup_op(b, intrin);
      default:
         unreachable("Unsupported subgroup scan/reduce op");
      }
      break;

   default:
      unreachable("Unsupported intrinsic");
   }
   return NULL;
}

static bool
should_lower_int64_instr(const nir_instr *instr, const void *_options)
{
   switch (instr->type) {
   case nir_instr_type_alu:
      return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
   case nir_instr_type_intrinsic:
      return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
                                          _options);
   default:
      return false;
   }
}

static nir_def *
lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
{
   switch (instr->type) {
   case nir_instr_type_alu:
      return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
   case nir_instr_type_intrinsic:
      return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
   default:
      return NULL;
   }
}

bool
nir_lower_int64(nir_shader *shader)
{
   return nir_shader_lower_instructions(shader, should_lower_int64_instr,
                                        lower_int64_instr,
                                        (void *)shader->options);
}

static bool
should_lower_int64_float_conv(const nir_instr *instr, const void *_options)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

   switch (alu->op) {
   case nir_op_i2f64:
   case nir_op_i2f32:
   case nir_op_i2f16:
   case nir_op_u2f64:
   case nir_op_u2f32:
   case nir_op_u2f16:
   case nir_op_f2i64:
   case nir_op_f2u64:
      return should_lower_int64_alu_instr(alu, _options);
   default:
      return false;
   }
}

/**
 * Like nir_lower_int64(), but only lowers conversions to/from float.
 *
 * These operations in particular may affect double-precision lowering,
 * so it can be useful to run them in tandem with nir_lower_doubles().
 */
bool
nir_lower_int64_float_conversions(nir_shader *shader)
{
   return nir_shader_lower_instructions(shader, should_lower_int64_float_conv,
                                        lower_int64_instr,
                                        (void *)shader->options);
}