xref: /aosp_15_r20/external/mesa3d/src/intel/compiler/intel_nir_opt_peephole_imul32x16.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
/*
 * Copyright © 2022 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
23 
24 #include "intel_nir.h"
25 #include "compiler/nir/nir_builder.h"
26 
/**
 * Implement a peephole pass that converts 32-bit integer multiplications to
 * imul_32x16 or umul_32x16 when one of the sources is known to fit in 16 bits.
 */
30 
/**
 * State handed to the per-instruction callback through its cb_data pointer.
 */
struct pass_data {
   /** Hash table passed to nir_unsigned_upper_bound() by the range analysis. */
   struct hash_table *range_ht;
};
34 
35 static void
replace_imul_instr(nir_builder * b,nir_alu_instr * imul,unsigned small_val,nir_op new_opcode)36 replace_imul_instr(nir_builder *b, nir_alu_instr *imul, unsigned small_val,
37                    nir_op new_opcode)
38 {
39    assert(small_val == 0 || small_val == 1);
40 
41    b->cursor = nir_before_instr(&imul->instr);
42 
43    nir_alu_instr *imul_32x16 = nir_alu_instr_create(b->shader, new_opcode);
44 
45    nir_alu_src_copy(&imul_32x16->src[0], &imul->src[1 - small_val]);
46    nir_alu_src_copy(&imul_32x16->src[1], &imul->src[small_val]);
47 
48    nir_def_init(&imul_32x16->instr, &imul_32x16->def,
49                 imul->def.num_components, 32);
50 
51    nir_def_rewrite_uses(&imul->def,
52                             &imul_32x16->def);
53 
54    nir_builder_instr_insert(b, &imul_32x16->instr);
55 
56    nir_instr_remove(&imul->instr);
57    nir_instr_free(&imul->instr);
58 }
59 
/**
 * Unary integer operations found at the root of an analyzed value.
 *
 * integer_neg and integer_abs are bit flags so that a negation can be toggled
 * with XOR while an absolute-value bit is preserved.  Callers compare these
 * values numerically and prefer the smaller one, so non_unary is the most
 * desirable result.
 */
enum root_operation {
   /** The value is not produced by a negation or absolute value. */
   non_unary = 0,
   integer_neg = 1 << 0,
   integer_abs = 1 << 1,
   integer_neg_abs = integer_neg | integer_abs,
   /** Sentinel that compares greater than every real root operation. */
   invalid_root = 255
};
67 
68 static enum root_operation
signed_integer_range_analysis(nir_shader * shader,struct hash_table * range_ht,nir_scalar scalar,int * lo,int * hi)69 signed_integer_range_analysis(nir_shader *shader, struct hash_table *range_ht,
70                               nir_scalar scalar, int *lo, int *hi)
71 {
72    if (nir_scalar_is_const(scalar)) {
73       *lo = nir_scalar_as_int(scalar);
74       *hi = *lo;
75       return non_unary;
76    }
77 
78    if (nir_scalar_is_alu(scalar)) {
79       switch (nir_scalar_alu_op(scalar)) {
80       case nir_op_iabs:
81          signed_integer_range_analysis(shader, range_ht,
82                                        nir_scalar_chase_alu_src(scalar, 0),
83                                        lo, hi);
84 
85          if (*lo == INT32_MIN) {
86             *hi = INT32_MAX;
87          } else {
88             const int32_t a = abs(*lo);
89             const int32_t b = abs(*hi);
90 
91             *lo = MIN2(a, b);
92             *hi = MAX2(a, b);
93          }
94 
95          /* Absolute value wipes out any inner negations, and it is redundant
96           * with any inner absolute values.
97           */
98          return integer_abs;
99 
100       case nir_op_ineg: {
101          const enum root_operation root =
102             signed_integer_range_analysis(shader, range_ht,
103                                           nir_scalar_chase_alu_src(scalar, 0),
104                                           lo, hi);
105 
106          if (*lo == INT32_MIN) {
107             *hi = INT32_MAX;
108          } else {
109             const int32_t a = -(*lo);
110             const int32_t b = -(*hi);
111 
112             *lo = MIN2(a, b);
113             *hi = MAX2(a, b);
114          }
115 
116          /* Negation of a negation cancels out, but negation of absolute value
117           * must preserve the integer_abs bit.
118           */
119          return root ^ integer_neg;
120       }
121 
122       case nir_op_imax: {
123          int src0_lo, src0_hi;
124          int src1_lo, src1_hi;
125 
126          signed_integer_range_analysis(shader, range_ht,
127                                        nir_scalar_chase_alu_src(scalar, 0),
128                                        &src0_lo, &src0_hi);
129          signed_integer_range_analysis(shader, range_ht,
130                                        nir_scalar_chase_alu_src(scalar, 1),
131                                        &src1_lo, &src1_hi);
132 
133          *lo = MAX2(src0_lo, src1_lo);
134          *hi = MAX2(src0_hi, src1_hi);
135 
136          return non_unary;
137       }
138 
139       case nir_op_imin: {
140          int src0_lo, src0_hi;
141          int src1_lo, src1_hi;
142 
143          signed_integer_range_analysis(shader, range_ht,
144                                        nir_scalar_chase_alu_src(scalar, 0),
145                                        &src0_lo, &src0_hi);
146          signed_integer_range_analysis(shader, range_ht,
147                                        nir_scalar_chase_alu_src(scalar, 1),
148                                        &src1_lo, &src1_hi);
149 
150          *lo = MIN2(src0_lo, src1_lo);
151          *hi = MIN2(src0_hi, src1_hi);
152 
153          return non_unary;
154       }
155 
156       default:
157          break;
158       }
159    }
160 
161    /* Any value with the sign-bit set is problematic. Consider the case when
162     * bound is 0x80000000. As an unsigned value, this means the value must be
163     * in the range [0, 0x80000000]. As a signed value, it means the value must
164     * be in the range [0, INT_MAX] or it must be INT_MIN.
165     *
166     * If bound is -2, it means the value is either in the range [INT_MIN, -2]
167     * or it is in the range [0, INT_MAX].
168     *
169     * This function only returns a single, contiguous range. The union of the
170     * two ranges for any value of bound with the sign-bit set is [INT_MIN,
171     * INT_MAX].
172     */
173    const int32_t bound = nir_unsigned_upper_bound(shader, range_ht,
174                                                      scalar, NULL);
175    if (bound < 0) {
176       *lo = INT32_MIN;
177       *hi = INT32_MAX;
178    } else {
179       *lo = 0;
180       *hi = bound;
181    }
182 
183    return non_unary;
184 }
185 
186 static bool
intel_nir_opt_peephole_imul32x16_instr(nir_builder * b,nir_instr * instr,void * cb_data)187 intel_nir_opt_peephole_imul32x16_instr(nir_builder *b,
188                                        nir_instr *instr,
189                                        void *cb_data)
190 {
191    struct pass_data *d = (struct pass_data *) cb_data;
192    struct hash_table *range_ht = d->range_ht;
193 
194    if (instr->type != nir_instr_type_alu)
195       return false;
196 
197    nir_alu_instr *imul = nir_instr_as_alu(instr);
198    if (imul->op != nir_op_imul)
199       return false;
200 
201    if (imul->def.bit_size != 32)
202       return false;
203 
204    nir_op new_opcode = nir_num_opcodes;
205 
206    unsigned i;
207    for (i = 0; i < 2; i++) {
208       if (!nir_src_is_const(imul->src[i].src))
209          continue;
210 
211       int64_t lo = INT64_MAX;
212       int64_t hi = INT64_MIN;
213 
214       for (unsigned comp = 0; comp < imul->def.num_components; comp++) {
215          int64_t v = nir_src_comp_as_int(imul->src[i].src, comp);
216 
217          if (v < lo)
218             lo = v;
219 
220          if (v > hi)
221             hi = v;
222       }
223 
224       if (lo >= INT16_MIN && hi <= INT16_MAX) {
225          new_opcode = nir_op_imul_32x16;
226          break;
227       } else if (lo >= 0 && hi <= UINT16_MAX) {
228          new_opcode = nir_op_umul_32x16;
229          break;
230       }
231    }
232 
233    if (new_opcode != nir_num_opcodes) {
234       replace_imul_instr(b, imul, i, new_opcode);
235       return true;
236    }
237 
238    if (imul->def.num_components > 1)
239       return false;
240 
241    const nir_scalar imul_scalar = { &imul->def, 0 };
242    int idx = -1;
243    enum root_operation prev_root = invalid_root;
244 
245    for (i = 0; i < 2; i++) {
246       /* All constants were previously processed.  There is nothing more to
247        * learn from a constant here.
248        */
249       if (imul->src[i].src.ssa->parent_instr->type == nir_instr_type_load_const)
250          continue;
251 
252       nir_scalar scalar = nir_scalar_chase_alu_src(imul_scalar, i);
253       int lo = INT32_MIN;
254       int hi = INT32_MAX;
255 
256       const enum root_operation root =
257          signed_integer_range_analysis(b->shader, range_ht, scalar, &lo, &hi);
258 
259       /* Copy propagation (in the backend) has trouble handling cases like
260        *
261        *    mov(8)          g60<1>D         -g59<8,8,1>D
262        *    mul(8)          g61<1>D         g63<8,8,1>D     g60<16,8,2>W
263        *
264        * If g59 had absolute value instead of negation, even improved copy
265        * propagation would not be able to make progress.
266        *
267        * In cases where both sources to the integer multiplication can fit in
268        * 16-bits, choose the source that does not have a source modifier.
269        */
270       if (root < prev_root) {
271          if (lo >= INT16_MIN && hi <= INT16_MAX) {
272             new_opcode = nir_op_imul_32x16;
273             idx = i;
274             prev_root = root;
275 
276             if (root == non_unary)
277                break;
278          } else if (lo >= 0 && hi <= UINT16_MAX) {
279             new_opcode = nir_op_umul_32x16;
280             idx = i;
281             prev_root = root;
282 
283             if (root == non_unary)
284                break;
285          }
286       }
287    }
288 
289    if (new_opcode == nir_num_opcodes) {
290       assert(idx == -1);
291       assert(prev_root == invalid_root);
292       return false;
293    }
294 
295    assert(idx != -1);
296    assert(prev_root != invalid_root);
297 
298    replace_imul_instr(b, imul, idx, new_opcode);
299    return true;
300 }
301 
302 bool
intel_nir_opt_peephole_imul32x16(nir_shader * shader)303 intel_nir_opt_peephole_imul32x16(nir_shader *shader)
304 {
305    struct pass_data cb_data;
306 
307    cb_data.range_ht = _mesa_pointer_hash_table_create(NULL);
308 
309    bool progress = nir_shader_instructions_pass(shader,
310                                                 intel_nir_opt_peephole_imul32x16_instr,
311                                                 nir_metadata_control_flow,
312                                                 &cb_data);
313 
314    _mesa_hash_table_destroy(cb_data.range_ht, NULL);
315 
316    return progress;
317 }
318 
319