xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/r600/sfn/sfn_valuefactory.h (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /* -*- mesa-c++  -*-
2  * Copyright 2021 Collabora LTD
3  * Author: Gert Wollny <[email protected]>
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #ifndef VALUEFACTORY_H
8 #define VALUEFACTORY_H
9 
#include "nir.h"
#include "sfn_alu_defines.h"
#include "sfn_virtualvalues.h"

#include <array>
#include <bitset>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <list>
#include <ostream>
#include <string>
#include <unordered_map>
#include <vector>
18 
19 struct r600_shader;
20 
21 namespace r600 {
22 
23 struct LiveRangeEntry {
24    enum EUse {
25       use_export,
26       use_unspecified
27    };
28 
LiveRangeEntryLiveRangeEntry29    LiveRangeEntry(Register *reg):
30        m_register(reg)
31    {
32    }
33    int m_start{-1};
34    int m_end{-1};
35    int m_index{-1};
36    int m_color{-1};
37    bool m_alu_clause_local{false};
38    std::bitset<use_unspecified> m_use;
39    Register *m_register;
40 
printLiveRangeEntry41    void print(std::ostream& os) const
42    {
43       os << *m_register << "(" << m_index << ", " << m_color << ") [" << m_start << ":"
44          << m_end << "]";
45    }
46 };
47 
48 inline std::ostream&
49 operator<<(std::ostream& os, const LiveRangeEntry& lre)
50 {
51    lre.print(os);
52    return os;
53 }
54 
55 class LiveRangeMap {
56 public:
57    using ChannelLiveRange = std::vector<LiveRangeEntry>;
58 
operator()59    LiveRangeEntry& operator()(int index, int chan)
60    {
61       assert(chan < 4);
62       return m_life_ranges[chan].at(index);
63    }
64 
65    void append_register(Register *reg);
66 
set_life_range(const Register & reg,int start,int end)67    void set_life_range(const Register& reg, int start, int end)
68    {
69       auto& entry = m_life_ranges[reg.chan()].at(reg.index());
70       entry.m_start = start;
71       entry.m_end = end;
72    }
73 
74    std::array<size_t, 4> sizes() const;
75 
component(int i)76    ChannelLiveRange& component(int i) { return m_life_ranges[i]; }
77 
component(int i)78    const ChannelLiveRange& component(int i) const { return m_life_ranges[i]; }
79 
80 private:
81    std::array<ChannelLiveRange, 4> m_life_ranges;
82 };
83 
84 std::ostream&
85 operator<<(std::ostream& os, const LiveRangeMap& lrm);
86 
87 bool
88 operator==(const LiveRangeMap& lhs, const LiveRangeMap& rhs);
89 
90 inline bool
91 operator!=(const LiveRangeMap& lhs, const LiveRangeMap& rhs)
92 {
93    return !(lhs == rhs);
94 }
95 
/* Value pools a RegisterKey can refer to. */
enum EValuePool {
   vp_ssa,
   vp_register,
   vp_temp,
   vp_array,
   vp_ignore
};

/* RegisterKey::value.pool is a 3-bit field; every enumerator must fit. */
static_assert(vp_ignore < (1 << 3), "EValuePool no longer fits in RegisterKey's 3-bit pool field");

/* Lookup key made of (index, chan, pool).
 *
 * The fields are written through `value` and read back as the 64-bit
 * `hash`, which operator== and register_key_hash use directly.  That is only
 * correct when the bit-field struct covers exactly the 64 bits of `hash`
 * (32-bit index + 29-bit chan + 3-bit pool); bit-field packing is
 * implementation-defined, so the layout is verified by static_assert below.
 * chan values wider than 29 bits are truncated by the bit-field. */
union RegisterKey {
   struct {
      uint32_t index;
      uint32_t chan : 29;
      EValuePool pool : 3;
   } value;
   uint64_t hash;

   RegisterKey(uint32_t index, uint32_t chan, EValuePool pool)
   {
      value.index = index;
      value.chan = chan;
      value.pool = pool;
   }

   /* Writes "(index, chan, pool)"; vp_ignore prints an empty pool label. */
   void print(std::ostream& os) const
   {
      os << "(" << value.index << ", " << value.chan << ", ";
      switch (value.pool) {
      case vp_ssa:
         os << "ssa";
         break;
      case vp_register:
         os << "reg";
         break;
      case vp_temp:
         os << "temp";
         break;
      case vp_array:
         os << "array";
         break;
      case vp_ignore:
         break;
      }
      os << ")";
   }
};

/* If this fails, `hash` no longer mirrors every field of `value` and key
 * comparison/hashing would silently ignore (or read garbage for) some bits. */
static_assert(sizeof(RegisterKey::value) == sizeof(RegisterKey::hash),
              "RegisterKey::value must exactly overlay RegisterKey::hash");

/* Two keys are equal when their packed 64-bit representations match. */
inline bool
operator==(const RegisterKey& lhs, const RegisterKey& rhs)
{
   return lhs.hash == rhs.hash;
}

inline std::ostream&
operator<<(std::ostream& os, const RegisterKey& key)
{
   key.print(os);
   return os;
}

/* Hash functor for unordered containers keyed by RegisterKey: uses the
 * packed 64-bit representation as the hash value. */
struct register_key_hash {
   std::size_t operator()(const RegisterKey& key) const { return key.hash; }
};
158 
/* Per-channel usage counters for the four channels 0..3. */
class ChannelCounts {
public:
   /* Increment the counter of channel chan (0..3) by one. */
   void inc_count(int chan)
   {
      assert(chan >= 0 && chan < 4);
      ++m_counts[chan];
   }

   /* Add n to the counter of channel chan (0..3). */
   void inc_count(int chan, int n)
   {
      assert(chan >= 0 && chan < 4);
      m_counts[chan] += n;
   }

   /* Return the channel with the smallest counter among the channels whose
    * bit is set in mask (bit i selects channel i); ties go to the lowest
    * channel index.  If mask selects none of the four channels, 0 is
    * returned.
    *
    * Only masked channels are candidates: channel 0 is not used as the
    * starting point unless bit 0 is set, so an excluded channel is never
    * returned for a non-empty mask. */
   int least_used(uint8_t mask) const
   {
      int least_used = -1;
      uint32_t count = 0;
      for (int i = 0; i < 4; ++i) {
         if (!((1 << i) & mask))
            continue;
         if (least_used < 0 || count > m_counts[i]) {
            count = m_counts[i];
            least_used = i;
         }
      }
      return least_used < 0 ? 0 : least_used;
   }

   /* Writes "CC:c0 c1 c2 c3". */
   void print(std::ostream& os) const
   {
      os << "CC:" << m_counts[0] << " " << m_counts[1] << " " << m_counts[2] << " "
         << m_counts[3];
   }

private:
   std::array<uint32_t, 4> m_counts{0, 0, 0, 0};
};

inline std::ostream&
operator<<(std::ostream& os, const ChannelCounts& cc)
{
   cc.print(os);
   return os;
}
193 
194 class ValueFactory : public Allocate {
195 public:
196    ValueFactory();
197 
198    void clear();
199 
200    ValueFactory(const ValueFactory& orig) = delete;
201    ValueFactory& operator=(const ValueFactory& orig) = delete;
202 
203    void set_virtual_register_base(int base);
204 
205    int new_register_index();
206 
207 
208    /* Allocate registers */
209    bool allocate_registers(const std::list<nir_intrinsic_instr *>& regs);
210    PRegister allocate_pinned_register(int sel, int chan);
211    RegisterVec4 allocate_pinned_vec4(int sel, bool is_ssa);
212 
213    /* Inject a predefined value for a given dest value
214     * (usually the result of a sysvalue load) */
215    void inject_value(const nir_def& def, int chan, PVirtualValue value);
216 
217    /* Get or create a destination value of vector of values */
218    PRegister
219    dest(const nir_def& def, int chan, Pin pin_channel, uint8_t chan_mask = 0xf);
220 
221    RegisterVec4 dest_vec4(const nir_def& dest, Pin pin);
222 
223    std::vector<PRegister, Allocator<PRegister>> dest_vec(const nir_def& dest,
224                                                          int num_components);
225 
226    PRegister dummy_dest(unsigned chan);
227 
228 
229    /* Create and get a temporary value */
230    PRegister temp_register(int pinned_channel = -1, bool is_ssa = true);
231    RegisterVec4 temp_vec4(Pin pin, const RegisterVec4::Swizzle& swizzle = {0, 1, 2, 3});
232 
233 
234    RegisterVec4
235    src_vec4(const nir_src& src, Pin pin, const RegisterVec4::Swizzle& swz = {0, 1, 2, 3});
236 
237    PVirtualValue src(const nir_alu_src& alu_src, int chan);
238    PVirtualValue src64(const nir_alu_src& alu_src, int chan, int comp);
239    PVirtualValue src(const nir_src& src, int chan);
240    PVirtualValue src(const nir_tex_src& tex_src, int chan);
241    PVirtualValue literal(uint32_t value);
242    PVirtualValue uniform(nir_intrinsic_instr *load_uniform, int chan);
243    PVirtualValue uniform(uint32_t index, int chan, int kcache);
244    std::vector<PVirtualValue, Allocator<PVirtualValue>> src_vec(const nir_src& src,
245                                                                 int components);
246 
247 
248    void allocate_const(nir_load_const_instr *load_const);
249 
250    PRegister dest_from_string(const std::string& s);
251    RegisterVec4 dest_vec4_from_string(const std::string& s,
252                                       RegisterVec4::Swizzle& swz,
253                                       Pin pin = pin_none);
254    PVirtualValue src_from_string(const std::string& s);
255    RegisterVec4 src_vec4_from_string(const std::string& s);
256 
257    LocalArray *array_from_string(const std::string& s);
258 
259 
260    PInlineConstant inline_const(AluInlineConstants sel, int chan);
261 
262    void get_shader_info(r600_shader *sh_info);
263 
264    PRegister undef(int index, int chan);
265    PVirtualValue zero();
266    PVirtualValue one();
267    PVirtualValue one_i();
268 
269    LiveRangeMap prepare_live_range_map();
270 
271    void clear_pins();
272 
next_register_index()273    int next_register_index() const { return m_next_register_index; }
array_registers()274    uint32_t array_registers() const { return m_required_array_registers; }
275 
276    PRegister addr();
277    PRegister idx_reg(unsigned idx);
278 
279 private:
280    PVirtualValue ssa_src(const nir_def& dest, int chan);
281 
282    int m_next_register_index;
283    int m_next_temp_channel{0};
284 
285    template <typename Key, typename T>
286    using unordered_map_alloc = std::unordered_map<Key,
287                                                   T,
288                                                   std::hash<Key>,
289                                                   std::equal_to<Key>,
290                                                   Allocator<std::pair<const Key, T>>>;
291 
292    template <typename Key, typename T>
293    using unordered_reg_map_alloc = std::unordered_map<Key,
294                                                       T,
295                                                       register_key_hash,
296                                                       std::equal_to<Key>,
297                                                       Allocator<std::pair<const Key, T>>>;
298 
299    using RegisterMap = unordered_reg_map_alloc<RegisterKey, PRegister>;
300    using ROValueMap = unordered_reg_map_alloc<RegisterKey, PVirtualValue>;
301 
302    RegisterMap m_registers;
303    std::list<PRegister, Allocator<PRegister>> m_pinned_registers;
304    ROValueMap m_values;
305    unordered_map_alloc<uint32_t, PLiteralVirtualValue> m_literal_values;
306    unordered_map_alloc<uint32_t, InlineConstant::Pointer> m_inline_constants;
307    unordered_map_alloc<uint32_t, uint32_t> m_ssa_index_to_sel;
308 
309    uint32_t m_nowrite_idx;
310 
311    RegisterVec4 m_dummy_dest_pinned{
312       g_registers_end, pin_chan, {0, 1, 2, 3}
313    };
314    ChannelCounts m_channel_counts;
315    uint32_t m_required_array_registers{0};
316 
317    AddressRegister *m_ar{nullptr};
318    AddressRegister *m_idx0{nullptr};
319    AddressRegister *m_idx1{nullptr};
320 };
321 
322 } // namespace r600
323 
324 #endif // VALUEFACTORY_H
325