xref: /aosp_15_r20/external/mesa3d/src/compiler/nir/nir_mod_analysis.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2022 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 
26 static nir_alu_type
nir_alu_src_type(const nir_alu_instr * instr,unsigned src)27 nir_alu_src_type(const nir_alu_instr *instr, unsigned src)
28 {
29    return nir_alu_type_get_base_type(nir_op_infos[instr->op].input_types[src]) |
30           nir_src_bit_size(instr->src[src].src);
31 }
32 
33 static nir_scalar
nir_alu_arg(const nir_alu_instr * alu,unsigned arg,unsigned comp)34 nir_alu_arg(const nir_alu_instr *alu, unsigned arg, unsigned comp)
35 {
36    const nir_alu_src *src = &alu->src[arg];
37    return nir_get_scalar(src->src.ssa, src->swizzle[comp]);
38 }
39 
40 /* Tries to determine the value of expression "val % div", assuming that val
41  * is interpreted as value of type "val_type". "div" must be a power of two.
42  * Returns true if it can statically tell the value of "val % div", false if not.
43  * Value of *mod is undefined if this function returned false.
44  *
45  * Tests are in mod_analysis_tests.cpp.
46  */
47 bool
nir_mod_analysis(nir_scalar val,nir_alu_type val_type,unsigned div,unsigned * mod)48 nir_mod_analysis(nir_scalar val, nir_alu_type val_type, unsigned div, unsigned *mod)
49 {
50    if (div == 1) {
51       *mod = 0;
52       return true;
53    }
54 
55    assert(util_is_power_of_two_nonzero(div));
56 
57    switch (val.def->parent_instr->type) {
58    case nir_instr_type_load_const: {
59       nir_load_const_instr *load =
60          nir_instr_as_load_const(val.def->parent_instr);
61       nir_alu_type base_type = nir_alu_type_get_base_type(val_type);
62 
63       if (base_type == nir_type_uint) {
64          assert(val.comp < load->def.num_components);
65          uint64_t ival = nir_const_value_as_uint(load->value[val.comp],
66                                                  load->def.bit_size);
67          *mod = ival % div;
68          return true;
69       } else if (base_type == nir_type_int) {
70          assert(val.comp < load->def.num_components);
71          int64_t ival = nir_const_value_as_int(load->value[val.comp],
72                                                load->def.bit_size);
73 
74          /* whole analysis collapses the moment we allow negative values */
75          if (ival < 0)
76             return false;
77 
78          *mod = ((uint64_t)ival) % div;
79          return true;
80       }
81 
82       break;
83    }
84 
85    case nir_instr_type_alu: {
86       nir_alu_instr *alu = nir_instr_as_alu(val.def->parent_instr);
87 
88       if (alu->def.num_components != 1)
89          return false;
90 
91       switch (alu->op) {
92       case nir_op_ishr: {
93          if (nir_src_is_const(alu->src[1].src)) {
94             assert(alu->src[1].src.ssa->num_components == 1);
95             uint64_t shift = nir_src_as_uint(alu->src[1].src);
96 
97             if (util_last_bit(div) + shift > 32)
98                break;
99 
100             nir_alu_type type0 = nir_alu_src_type(alu, 0);
101             if (!nir_mod_analysis(nir_alu_arg(alu, 0, val.comp), type0, div << shift, mod))
102                return false;
103 
104             *mod >>= shift;
105             return true;
106          }
107          break;
108       }
109 
110       case nir_op_iadd: {
111          unsigned mod0;
112          nir_alu_type type0 = nir_alu_src_type(alu, 0);
113          if (!nir_mod_analysis(nir_alu_arg(alu, 0, val.comp), type0, div, &mod0))
114             return false;
115 
116          unsigned mod1;
117          nir_alu_type type1 = nir_alu_src_type(alu, 1);
118          if (!nir_mod_analysis(nir_alu_arg(alu, 1, val.comp), type1, div, &mod1))
119             return false;
120 
121          *mod = (mod0 + mod1) % div;
122          return true;
123       }
124 
125       case nir_op_ishl: {
126          if (nir_src_is_const(alu->src[1].src)) {
127             assert(alu->src[1].src.ssa->num_components == 1);
128             uint64_t shift = nir_src_as_uint(alu->src[1].src);
129 
130             if ((div >> shift) == 0) {
131                *mod = 0;
132                return true;
133             }
134             nir_alu_type type0 = nir_alu_src_type(alu, 0);
135             return nir_mod_analysis(nir_alu_arg(alu, 0, val.comp), type0, div >> shift, mod);
136          }
137          break;
138       }
139 
140       case nir_op_imul_32x16: /* multiply 32-bits with low 16-bits */
141       case nir_op_imul: {
142          unsigned mod0;
143          nir_alu_type type0 = nir_alu_src_type(alu, 0);
144          bool s1 = nir_mod_analysis(nir_alu_arg(alu, 0, val.comp), type0, div, &mod0);
145 
146          if (s1 && (mod0 == 0)) {
147             *mod = 0;
148             return true;
149          }
150 
151          /* if divider is larger than 2nd source max (interpreted) value
152           * then modulo of multiplication is unknown
153           */
154          if (alu->op == nir_op_imul_32x16 && div > (1u << 16))
155             return false;
156 
157          unsigned mod1;
158          nir_alu_type type1 = nir_alu_src_type(alu, 1);
159          bool s2 = nir_mod_analysis(nir_alu_arg(alu, 1, val.comp), type1, div, &mod1);
160 
161          if (s2 && (mod1 == 0)) {
162             *mod = 0;
163             return true;
164          }
165 
166          if (!s1 || !s2)
167             return false;
168 
169          *mod = (mod0 * mod1) % div;
170          return true;
171       }
172 
173       default:
174          break;
175       }
176       break;
177    }
178 
179    default:
180       break;
181    }
182 
183    return false;
184 }
185