xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/r600/sfn/sfn_virtualvalues.h (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /* -*- mesa-c++  -*-
2  * Copyright 2021 Collabora LTD
3  * Author: Gert Wollny <[email protected]>
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #pragma once
8 
#include "sfn_alu_defines.h"
#include "sfn_memorypool.h"

#include <array>
#include <bitset>
#include <cassert>
#include <cstdint>
#include <iosfwd>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
19 
#if __cpp_exceptions >= 199711L
#include <exception>
#include <stdexcept>
/* Evaluate EXPR and throw std::invalid_argument(ERROR) when it is false.
 * The body is wrapped in do { } while (0) so the macro expands to exactly
 * one statement and cannot capture a following "else". */
#define ASSERT_OR_THROW(EXPR, ERROR)                                                     \
   do {                                                                                  \
      if (!(EXPR))                                                                       \
         throw std::invalid_argument(ERROR);                                             \
   } while (0)
#else
/* Exceptions are not available: a failed check is routed to unreachable(ERROR)
 * instead of throwing. */
#define ASSERT_OR_THROW(EXPR, ERROR)                                                     \
   do {                                                                                  \
      if (!(EXPR))                                                                       \
         unreachable(ERROR);                                                             \
   } while (0)
#endif
30 
31 namespace r600 {
32 
/* Pin state carried by every VirtualValue.  The enumerators are numbered
 * consecutively starting at pin_none == 0.
 * NOTE(review): the exact meaning of each pin kind is defined by the code
 * that consumes it, not by this header -- confirm before relying on names. */
enum Pin {
   pin_none,
   pin_chan,
   pin_array,
   pin_group,
   pin_chgr,
   pin_fully,
   pin_free
};

/* Stream output for a Pin value (defined out of line). */
std::ostream&
operator<<(std::ostream& os, Pin pin);
45 
46 class Register;
47 class RegisterVisitor;
48 class ConstRegisterVisitor;
49 class Instr;
50 class InlineConstant;
51 class LiteralConstant;
52 class UniformValue;
53 class ValueFactory;
54 
55 using InstructionSet = std::set<Instr *, std::less<Instr *>, Allocator<Instr *>>;
56 
57 class VirtualValue : public Allocate {
58 public:
59    static const uint32_t virtual_register_base = 1024;
60    static const uint32_t clause_temp_registers = 2;
61    static const uint32_t gpr_register_end = 128 - 2 * clause_temp_registers;
62    static const uint32_t clause_temp_register_begin = gpr_register_end;
63    static const uint32_t clause_temp_register_end = 128;
64 
65    static const uint32_t uniforms_begin = 512;
66    static const uint32_t uniforms_end = 640;
67 
68    using Pointer = R600_POINTER_TYPE(VirtualValue);
69 
70    VirtualValue(int sel, int chan, Pin pin);
71    VirtualValue(const VirtualValue& orig) = default;
72 
sel()73    int sel() const { return m_sel; }
chan()74    int chan() const { return m_chan; }
pin()75    Pin pin() const { return m_pins; };
76    bool is_virtual() const;
77 
set_pin(Pin p)78    void set_pin(Pin p) { m_pins = p; }
79 
80    virtual void accept(RegisterVisitor& vistor) = 0;
81    virtual void accept(ConstRegisterVisitor& vistor) const = 0;
82    virtual void print(std::ostream& os) const = 0;
83 
84    bool equal_to(const VirtualValue& other) const;
85    Pointer get_addr() const;
86 
87    static Pointer from_string(const std::string& s);
88 
as_register()89    virtual Register *as_register() { return nullptr; }
as_inline_const()90    virtual InlineConstant *as_inline_const() { return nullptr; }
as_literal()91    virtual LiteralConstant *as_literal() { return nullptr; }
as_uniform()92    virtual UniformValue *as_uniform() { return nullptr; }
93    virtual bool ready(int block, int index) const;
94 
95    static constexpr char chanchar[9] = "xyzw01?_";
96 
97 protected:
do_set_chan(int c)98    void do_set_chan(int c) { m_chan = c; }
set_sel_internal(int sel)99    void set_sel_internal(int sel) { m_sel = sel; }
100 
101 private:
102    uint32_t m_sel;
103    int m_chan;
104    Pin m_pins;
105 };
106 using PVirtualValue = VirtualValue::Pointer;
107 
108 inline std::ostream&
109 operator<<(std::ostream& os, const VirtualValue& val)
110 {
111    val.print(os);
112    return os;
113 }
114 
115 inline bool
116 operator==(const VirtualValue& lhs, const VirtualValue& rhs)
117 {
118    return lhs.equal_to(rhs);
119 }
120 
/* Start/end pair with a pinned marker.  A default-constructed range has
 * start == end == -1; is_pinned is always initialised to 0. */
struct LiveRange {
   LiveRange() = default;

   /* Range covering [s, e]; the pinned marker stays cleared. */
   LiveRange(int s, int e):
       start(s),
       end(e)
   {
   }

   int start{-1};
   int end{-1};
   /* Kept as int to leave the public field type untouched. */
   int is_pinned{0};
};
138 
139 class Register : public VirtualValue {
140 public:
141    using Pointer = R600_POINTER_TYPE(Register);
142 
143    enum Flags {
144       ssa,
145       pin_start,
146       pin_end,
147       addr_or_idx,
148       flag_count
149    };
150 
151    Register(int sel, int chan, Pin pin);
152    void accept(RegisterVisitor& vistor) override;
153    void accept(ConstRegisterVisitor& vistor) const override;
154    void print(std::ostream& os) const override;
155 
156    static Pointer from_string(const std::string& s);
157 
as_register()158    Register *as_register() override { return this; }
159 
160    void add_parent(Instr *instr);
161    void del_parent(Instr *instr);
parents()162    const InstructionSet& parents() const { return m_parents; }
163 
164    bool ready(int block, int index) const override;
165 
uses()166    const InstructionSet& uses() const { return m_uses; }
167    void add_use(Instr *instr);
168    void del_use(Instr *instr);
has_uses()169    bool has_uses() const { return !m_uses.empty() || pin() == pin_array; }
set_chan(int c)170    void set_chan(int c) { do_set_chan(c); }
171 
addr()172    virtual VirtualValue *addr() const { return nullptr; }
173 
index()174    int index() const { return m_index; }
set_index(int idx)175    void set_index(int idx) { m_index = idx; }
176 
set_sel(int new_sel)177    void set_sel(int new_sel)
178    {
179       set_sel_internal(new_sel);
180       m_flags.reset(ssa);
181    }
182 
set_flag(Flags f)183    void set_flag(Flags f) { m_flags.set(f); }
reset_flag(Flags f)184    void reset_flag(Flags f) { m_flags.reset(f); }
has_flag(Flags f)185    auto has_flag(Flags f) const { return m_flags.test(f); }
flags()186    auto flags() const { return m_flags; }
187 
188 private:
189    Register(const Register& orig) = delete;
190    Register(const Register&& orig) = delete;
191    Register& operator=(const Register& orig) = delete;
192    Register& operator=(Register&& orig) = delete;
193 
forward_del_use(Instr * instr)194    virtual void forward_del_use(Instr *instr) { (void)instr; }
forward_add_use(Instr * instr)195    virtual void forward_add_use(Instr *instr) { (void)instr; }
196    virtual void add_parent_to_array(Instr *instr);
197    virtual void del_parent_from_array(Instr *instr);
198 
199    InstructionSet m_parents;
200    InstructionSet m_uses;
201 
202    int m_index{-1};
203 
204    std::bitset<flag_count> m_flags{0};
205 };
206 using PRegister = Register::Pointer;
207 
208 class AddressRegister : public Register {
209 public:
210    enum Type {
211       addr,
212       idx0 = 1,
213       idx1 = 2
214    };
AddressRegister(Type type)215    AddressRegister(Type type) :  Register(type, 0, pin_fully) {
216       set_flag(addr_or_idx);
217    }
218 
219 protected:
do_set_chan(UNUSED int c)220    void do_set_chan(UNUSED int c) { unreachable("Address registers must have chan 0");}
set_sel_internal(UNUSED int sel)221    void set_sel_internal(UNUSED int sel) {unreachable("Address registers don't support sel override");}
222 };
223 
224 
225 inline std::ostream&
226 operator<<(std::ostream& os, const Register& val)
227 {
228    val.print(os);
229    return os;
230 }
231 
232 class InlineConstant : public VirtualValue {
233 public:
234    using Pointer = R600_POINTER_TYPE(InlineConstant);
235 
236    InlineConstant(int sel, int chan = 0);
237 
238    void accept(RegisterVisitor& vistor) override;
239    void accept(ConstRegisterVisitor& vistor) const override;
240    void print(std::ostream& os) const override;
241    static Pointer from_string(const std::string& s);
242    static Pointer param_from_string(const std::string& s);
243 
as_inline_const()244    InlineConstant *as_inline_const() override { return this; }
245 
246 private:
247    InlineConstant(const InlineConstant& orig) = default;
248    static std::map<std::string, std::pair<AluInlineConstants, bool>> s_opmap;
249 };
250 using PInlineConstant = InlineConstant::Pointer;
251 
252 inline std::ostream&
253 operator<<(std::ostream& os, const InlineConstant& val)
254 {
255    val.print(os);
256    return os;
257 }
258 
259 class RegisterVec4 {
260 public:
261    using Swizzle = std::array<uint8_t, 4>;
262    RegisterVec4();
263    RegisterVec4(int sel,
264                 bool is_ssa = false,
265                 const Swizzle& swz = {0, 1, 2, 3},
266                 Pin pin = pin_group);
267    RegisterVec4(PRegister x, PRegister y, PRegister z, PRegister w, Pin pin);
268 
269    RegisterVec4(const RegisterVec4& orig);
270 
271    RegisterVec4(RegisterVec4&& orig) = default;
272    RegisterVec4& operator=(RegisterVec4& orig) = default;
273    RegisterVec4& operator=(RegisterVec4&& orig) = default;
274 
275    void add_use(Instr *instr);
276    void del_use(Instr *instr);
277    bool has_uses() const;
278 
279    int sel() const;
280    void print(std::ostream& os) const;
281 
282    class Element : public Allocate {
283    public:
284       Element(const RegisterVec4& parent, int chan);
285       Element(const RegisterVec4& parent, PRegister value);
value()286       PRegister value() { return m_value; }
set_value(PRegister reg)287       void set_value(PRegister reg) { m_value = reg; }
288 
289    private:
290       const RegisterVec4& m_parent;
291       PRegister m_value;
292    };
293 
294    friend class Element;
295 
296    PRegister operator[](int i) const { return m_values[i]->value(); }
297 
298    PRegister operator[](int i) { return m_values[i]->value(); }
299 
set_value(int i,PRegister reg)300    void set_value(int i, PRegister reg)
301    {
302       if (reg->chan() < 4) {
303          m_sel = reg->sel();
304       }
305       m_swz[i] = reg->chan();
306       m_values[i]->set_value(reg);
307    }
308 
validate()309    void validate()
310    {
311       int sel = -1;
312       for (int k = 0; k < 4; ++k) {
313          if (sel < 0) {
314             if (m_values[k]->value()->chan() < 4)
315                sel = m_values[k]->value()->sel();
316          } else {
317             assert(m_values[k]->value()->chan() > 3 ||
318                    m_values[k]->value()->sel() == sel);
319          }
320       }
321    }
322 
free_chan_mask()323    uint8_t free_chan_mask() const
324    {
325       int mask = 0xf;
326       for (int i = 0; i < 4; ++i) {
327          int chan = m_values[i]->value()->chan();
328          if (chan <= 3) {
329             mask &= ~(1 << chan);
330          }
331       }
332       return mask;
333    }
334 
335    bool ready(int block_id, int index) const;
336 
337 private:
338    int m_sel;
339    Swizzle m_swz;
340    std::array<R600_POINTER_TYPE(Element), 4> m_values;
341 };
342 
343 bool
344 operator==(const RegisterVec4& lhs, const RegisterVec4& rhs);
345 
346 inline bool
347 operator!=(const RegisterVec4& lhs, const RegisterVec4& rhs)
348 {
349    return !(lhs == rhs);
350 }
351 
352 inline std::ostream&
353 operator<<(std::ostream& os, const RegisterVec4& val)
354 {
355    val.print(os);
356    return os;
357 }
358 
359 class LiteralConstant : public VirtualValue {
360 public:
361    using Pointer = R600_POINTER_TYPE(LiteralConstant);
362 
363    LiteralConstant(uint32_t value);
364    void accept(RegisterVisitor& vistor) override;
365    void accept(ConstRegisterVisitor& vistor) const override;
366    void print(std::ostream& os) const override;
value()367    uint32_t value() const { return m_value; }
368    static Pointer from_string(const std::string& s);
as_literal()369    LiteralConstant *as_literal() override { return this; }
370 
371 private:
372    LiteralConstant(const LiteralConstant& orig) = default;
373    uint32_t m_value;
374 };
375 using PLiteralVirtualValue = LiteralConstant::Pointer;
376 
377 class UniformValue : public VirtualValue {
378 public:
379    using Pointer = R600_POINTER_TYPE(UniformValue);
380 
381    UniformValue(int sel, int chan, int kcache_bank = 0);
382    UniformValue(int sel, int chan, PVirtualValue buf_addr, int kcache_bank);
383 
384    void accept(RegisterVisitor& vistor) override;
385    void accept(ConstRegisterVisitor& vistor) const override;
386    void print(std::ostream& os) const override;
kcache_bank()387    int kcache_bank() const { return m_kcache_bank; }
388    PVirtualValue buf_addr() const;
389    void set_buf_addr(PVirtualValue addr);
as_uniform()390    UniformValue *as_uniform() override { return this; }
391 
392    bool equal_buf_and_cache(const UniformValue& other) const;
393    static Pointer from_string(const std::string& s, ValueFactory *factory);
394 
395 private:
396    int m_kcache_bank;
397    PVirtualValue m_buf_addr;
398 };
399 using PUniformVirtualValue = UniformValue::Pointer;
400 
401 inline std::ostream&
402 operator<<(std::ostream& os, const UniformValue& val)
403 {
404    val.print(os);
405    return os;
406 }
407 
408 class LocalArrayValue;
409 class LocalArray : public Register {
410 public:
411    using Pointer = R600_POINTER_TYPE(LocalArray);
412    using Values = std::vector<LocalArrayValue *, Allocator<LocalArrayValue *>>;
413 
414    LocalArray(int base_sel, int nchannels, int size, int frac = 0);
415    void accept(RegisterVisitor& vistor) override;
416    void accept(ConstRegisterVisitor& vistor) const override;
417    void print(std::ostream& os) const override;
418    bool ready_for_direct(int block, int index, int chan) const;
419    bool ready_for_indirect(int block, int index, int chan) const;
420 
421    PRegister element(size_t offset, PVirtualValue indirect, uint32_t chan);
422 
423    size_t size() const;
424    uint32_t nchannels() const;
frac()425    uint32_t frac() const { return m_frac; }
426 
427    void add_parent_to_elements(int chan, Instr *instr);
428 
429    const Register& operator()(size_t idx, size_t chan) const;
430 
begin()431    Values::iterator begin() { return m_values.begin(); }
end()432    Values::iterator end() { return m_values.end(); }
begin()433    Values::const_iterator begin() const { return m_values.begin(); }
end()434    Values::const_iterator end() const { return m_values.end(); }
435 
base_sel()436    uint32_t base_sel() const { return m_base_sel;}
437 
438 private:
439    uint32_t m_base_sel;
440    uint32_t m_nchannels;
441    size_t m_size;
442    Values m_values;
443    Values m_values_indirect;
444    int m_frac;
445 };
446 
447 inline std::ostream&
448 operator<<(std::ostream& os, const LocalArray& val)
449 {
450    val.print(os);
451    return os;
452 }
453 
454 class LocalArrayValue : public Register {
455 public:
456    using Pointer = R600_POINTER_TYPE(LocalArrayValue);
457 
458    LocalArrayValue(PRegister reg, LocalArray& array);
459    LocalArrayValue(PRegister reg, PVirtualValue index, LocalArray& array);
460 
461    void accept(RegisterVisitor& vistor) override;
462    void accept(ConstRegisterVisitor& vistor) const override;
463    void print(std::ostream& os) const override;
464    bool ready(int block, int index) const override;
465 
466    VirtualValue *addr() const override;
467    void set_addr(PRegister addr);
468    const LocalArray& array() const;
469 
470 private:
471    void forward_del_use(Instr *instr) override;
472    void forward_add_use(Instr *instr) override;
473    void add_parent_to_array(Instr *instr) override;
474    void del_parent_from_array(Instr *instr) override;
475 
476    PVirtualValue m_addr;
477    LocalArray& m_array;
478 };
479 
480 inline std::ostream&
481 operator<<(std::ostream& os, const LocalArrayValue& val)
482 {
483    val.print(os);
484    return os;
485 }
486 
/* Null-aware comparison of two value pointers.
 *
 * Returns true when both pointers are null, false when exactly one of them
 * is null, and otherwise the result of lhs->equal_to(*rhs). */
template <typename T>
bool
sfn_value_equal(const T *lhs, const T *rhs)
{
   if (lhs && rhs)
      return lhs->equal_to(*rhs);

   /* At least one side is null: equal only when both are. */
   return !lhs && !rhs;
}
502 
503 bool
504 value_is_const_uint(const VirtualValue& val, uint32_t value);
505 bool
506 value_is_const_float(const VirtualValue& val, float value);
507 
508 class RegisterVisitor {
509 public:
510    virtual void visit(Register& value) = 0;
511    virtual void visit(LocalArray& value) = 0;
512    virtual void visit(LocalArrayValue& value) = 0;
513    virtual void visit(UniformValue& value) = 0;
514    virtual void visit(LiteralConstant& value) = 0;
515    virtual void visit(InlineConstant& value) = 0;
516 };
517 
518 class ConstRegisterVisitor {
519 public:
520    virtual void visit(const Register& value) = 0;
521    virtual void visit(const LocalArray& value) = 0;
522    virtual void visit(const LocalArrayValue& value) = 0;
523    virtual void visit(const UniformValue& value) = 0;
524    virtual void visit(const LiteralConstant& value) = 0;
525    virtual void visit(const InlineConstant& value) = 0;
526 };
527 
528 } // namespace r600
529