xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_lower_idiv.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2015 Red Hat
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * Authors:
24  *    Rob Clark <[email protected]>
25  */
26 
27 #include "nir.h"
28 #include "nir_builder.h"
29 
30 /* ported from LLVM's AMDGPUTargetLowering::LowerUDIVREM */
31 static nir_def *
emit_udiv(nir_builder * bld,nir_def * numer,nir_def * denom,bool modulo)32 emit_udiv(nir_builder *bld, nir_def *numer, nir_def *denom, bool modulo)
33 {
34    nir_def *rcp = nir_frcp(bld, nir_u2f32(bld, denom));
35    rcp = nir_f2u32(bld, nir_fmul_imm(bld, rcp, 4294966784.0));
36 
37    nir_def *neg_rcp_times_denom =
38       nir_imul(bld, rcp, nir_ineg(bld, denom));
39    rcp = nir_iadd(bld, rcp, nir_umul_high(bld, rcp, neg_rcp_times_denom));
40 
41    /* Get initial estimate for quotient/remainder, then refine the estimate
42     * in two iterations after */
43    nir_def *quotient = nir_umul_high(bld, numer, rcp);
44    nir_def *num_s_remainder = nir_imul(bld, quotient, denom);
45    nir_def *remainder = nir_isub(bld, numer, num_s_remainder);
46 
47    /* First refinement step */
48    nir_def *remainder_ge_den = nir_uge(bld, remainder, denom);
49    if (!modulo) {
50       quotient = nir_bcsel(bld, remainder_ge_den,
51                            nir_iadd_imm(bld, quotient, 1), quotient);
52    }
53    remainder = nir_bcsel(bld, remainder_ge_den,
54                          nir_isub(bld, remainder, denom), remainder);
55 
56    /* Second refinement step */
57    remainder_ge_den = nir_uge(bld, remainder, denom);
58    if (modulo) {
59       return nir_bcsel(bld, remainder_ge_den, nir_isub(bld, remainder, denom),
60                        remainder);
61    } else {
62       return nir_bcsel(bld, remainder_ge_den, nir_iadd_imm(bld, quotient, 1),
63                        quotient);
64    }
65 }
66 
67 /* ported from LLVM's AMDGPUTargetLowering::LowerSDIVREM */
68 static nir_def *
emit_idiv(nir_builder * bld,nir_def * numer,nir_def * denom,nir_op op)69 emit_idiv(nir_builder *bld, nir_def *numer, nir_def *denom, nir_op op)
70 {
71    nir_def *lhs = nir_iabs(bld, numer);
72    nir_def *rhs = nir_iabs(bld, denom);
73 
74    if (op == nir_op_idiv) {
75       /* We want (numer < 0) ^ (denom < 0). This is the XOR of the sign bits,
76        * and since XOR is bitwise, that's the sign bit of the XOR.
77        */
78       nir_def *d_sign = nir_ilt_imm(bld, nir_ixor(bld, numer, denom), 0);
79       nir_def *res = emit_udiv(bld, lhs, rhs, false);
80       return nir_bcsel(bld, d_sign, nir_ineg(bld, res), res);
81    } else {
82       nir_def *lh_sign = nir_ilt_imm(bld, numer, 0);
83       nir_def *rh_sign = nir_ilt_imm(bld, denom, 0);
84 
85       nir_def *res = emit_udiv(bld, lhs, rhs, true);
86       res = nir_bcsel(bld, lh_sign, nir_ineg(bld, res), res);
87       if (op == nir_op_imod) {
88          nir_def *cond = nir_ieq_imm(bld, res, 0);
89          cond = nir_ior(bld, nir_ieq(bld, lh_sign, rh_sign), cond);
90          res = nir_bcsel(bld, cond, res, nir_iadd(bld, res, denom));
91       }
92       return res;
93    }
94 }
95 
96 static nir_def *
convert_instr_small(nir_builder * b,nir_op op,nir_def * numer,nir_def * denom,const nir_lower_idiv_options * options)97 convert_instr_small(nir_builder *b, nir_op op,
98                     nir_def *numer, nir_def *denom,
99                     const nir_lower_idiv_options *options)
100 {
101    unsigned sz = numer->bit_size;
102    nir_alu_type int_type = nir_op_infos[op].output_type | sz;
103    nir_alu_type float_type = nir_type_float | (options->allow_fp16 ? sz * 2 : 32);
104 
105    nir_def *p = nir_type_convert(b, numer, int_type, float_type, nir_rounding_mode_undef);
106    nir_def *q = nir_type_convert(b, denom, int_type, float_type, nir_rounding_mode_undef);
107 
108    /* Take 1/q but offset mantissa by 1 to correct for rounding. This is
109     * needed for correct results and has been checked exhaustively for
110     * all pairs of 16-bit integers */
111    nir_def *rcp = nir_iadd_imm(b, nir_frcp(b, q), 1);
112 
113    /* Divide by multiplying by adjusted reciprocal */
114    nir_def *res = nir_fmul(b, p, rcp);
115 
116    /* Convert back to integer space with rounding inferred by type */
117    res = nir_type_convert(b, res, float_type, int_type, nir_rounding_mode_undef);
118 
119    /* Get remainder given the quotient */
120    if (op == nir_op_umod || op == nir_op_imod || op == nir_op_irem)
121       res = nir_isub(b, numer, nir_imul(b, denom, res));
122 
123    /* Adjust for sign, see constant folding definition */
124    if (op == nir_op_imod) {
125       nir_def *zero = nir_imm_zero(b, 1, sz);
126       nir_def *diff_sign =
127          nir_ine(b, nir_ige(b, numer, zero), nir_ige(b, denom, zero));
128 
129       nir_def *adjust = nir_iand(b, diff_sign, nir_ine(b, res, zero));
130       res = nir_iadd(b, res, nir_bcsel(b, adjust, denom, zero));
131    }
132 
133    return res;
134 }
135 
136 static nir_def *
lower_idiv(nir_builder * b,nir_instr * instr,void * _data)137 lower_idiv(nir_builder *b, nir_instr *instr, void *_data)
138 {
139    const nir_lower_idiv_options *options = _data;
140    nir_alu_instr *alu = nir_instr_as_alu(instr);
141 
142    nir_def *numer = nir_ssa_for_alu_src(b, alu, 0);
143    nir_def *denom = nir_ssa_for_alu_src(b, alu, 1);
144 
145    b->exact = true;
146 
147    if (numer->bit_size < 32)
148       return convert_instr_small(b, alu->op, numer, denom, options);
149    else if (alu->op == nir_op_udiv || alu->op == nir_op_umod)
150       return emit_udiv(b, numer, denom, alu->op == nir_op_umod);
151    else
152       return emit_idiv(b, numer, denom, alu->op);
153 }
154 
155 static bool
inst_is_idiv(const nir_instr * instr,UNUSED const void * _state)156 inst_is_idiv(const nir_instr *instr, UNUSED const void *_state)
157 {
158    if (instr->type != nir_instr_type_alu)
159       return false;
160 
161    nir_alu_instr *alu = nir_instr_as_alu(instr);
162 
163    if (alu->def.bit_size > 32)
164       return false;
165 
166    switch (alu->op) {
167    case nir_op_idiv:
168    case nir_op_udiv:
169    case nir_op_imod:
170    case nir_op_umod:
171    case nir_op_irem:
172       return true;
173    default:
174       return false;
175    }
176 }
177 
178 bool
nir_lower_idiv(nir_shader * shader,const nir_lower_idiv_options * options)179 nir_lower_idiv(nir_shader *shader, const nir_lower_idiv_options *options)
180 {
181    return nir_shader_lower_instructions(shader,
182                                         inst_is_idiv,
183                                         lower_idiv,
184                                         (void *)options);
185 }
186