xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/r600/sfn/sfn_peephole.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /* -*- mesa-c++  -*-
2  * Copyright 2022 Collabora LTD
3  * Author: Gert Wollny <[email protected]>
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "sfn_peephole.h"
8 #include "sfn_instr_alugroup.h"
9 
10 namespace r600 {
11 
12 class PeepholeVisitor : public InstrVisitor {
13 public:
14    void visit(AluInstr *instr) override;
15    void visit(AluGroup *instr) override;
visit(TexInstr * instr)16    void visit(TexInstr *instr) override { (void)instr; };
visit(ExportInstr * instr)17    void visit(ExportInstr *instr) override { (void)instr; }
visit(FetchInstr * instr)18    void visit(FetchInstr *instr) override { (void)instr; }
19    void visit(Block *instr) override;
visit(ControlFlowInstr * instr)20    void visit(ControlFlowInstr *instr) override { (void)instr; }
21    void visit(IfInstr *instr) override;
visit(ScratchIOInstr * instr)22    void visit(ScratchIOInstr *instr) override { (void)instr; }
visit(StreamOutInstr * instr)23    void visit(StreamOutInstr *instr) override { (void)instr; }
visit(MemRingOutInstr * instr)24    void visit(MemRingOutInstr *instr) override { (void)instr; }
visit(EmitVertexInstr * instr)25    void visit(EmitVertexInstr *instr) override { (void)instr; }
visit(GDSInstr * instr)26    void visit(GDSInstr *instr) override { (void)instr; };
visit(WriteTFInstr * instr)27    void visit(WriteTFInstr *instr) override { (void)instr; };
visit(LDSAtomicInstr * instr)28    void visit(LDSAtomicInstr *instr) override { (void)instr; };
visit(LDSReadInstr * instr)29    void visit(LDSReadInstr *instr) override { (void)instr; };
visit(RatInstr * instr)30    void visit(RatInstr *instr) override { (void)instr; };
31 
32    void convert_to_mov(AluInstr *alu, int src_idx);
33 
34    void apply_source_mods(AluInstr *alu);
35    void apply_dest_clamp(AluInstr *alu);
36    void try_fuse_with_prev(AluInstr *alu);
37 
38    bool progress{false};
39 };
40 
41 bool
peephole(Shader & sh)42 peephole(Shader& sh)
43 {
44    PeepholeVisitor peephole;
45    for (auto b : sh.func())
46       b->accept(peephole);
47    return peephole.progress;
48 }
49 
50 class ReplacePredicate : public AluInstrVisitor {
51 public:
ReplacePredicate(AluInstr * pred)52    ReplacePredicate(AluInstr *pred):
53        m_pred(pred)
54    {
55    }
56 
57    using AluInstrVisitor::visit;
58 
59    void visit(AluInstr *alu) override;
60 
61    AluInstr *m_pred;
62    bool success{false};
63 };
64 
65 void
visit(AluInstr * instr)66 PeepholeVisitor::visit(AluInstr *instr)
67 {
68    switch (instr->opcode()) {
69    case op1_mov:
70       if (instr->has_alu_flag(alu_dst_clamp))
71          apply_dest_clamp(instr);
72       else if (!instr->has_source_mod(0, AluInstr::mod_abs) &&
73                !instr->has_source_mod(0, AluInstr::mod_neg))
74          try_fuse_with_prev(instr);
75       break;
76    case op2_add:
77    case op2_add_int:
78       if (value_is_const_uint(instr->src(0), 0))
79          convert_to_mov(instr, 1);
80       else if (value_is_const_uint(instr->src(1), 0))
81          convert_to_mov(instr, 0);
82       break;
83    case op2_mul:
84    case op2_mul_ieee:
85       if (value_is_const_float(instr->src(0), 1.0f))
86          convert_to_mov(instr, 1);
87       else if (value_is_const_float(instr->src(1), 1.0f))
88          convert_to_mov(instr, 0);
89       break;
90    case op3_muladd:
91    case op3_muladd_ieee:
92       if (value_is_const_uint(instr->src(0), 0) || value_is_const_uint(instr->src(1), 0))
93          convert_to_mov(instr, 2);
94       break;
95    case op2_killne_int:
96       if (value_is_const_uint(instr->src(1), 0)) {
97          auto src0 = instr->psrc(0)->as_register();
98          if (src0 && src0->has_flag(Register::ssa)) {
99             auto parent = *src0->parents().begin();
100             ReplacePredicate visitor(instr);
101             parent->accept(visitor);
102             progress |= visitor.success;
103          }
104       }
105       break;
106    default:;
107    }
108 
109    auto opinfo = alu_ops.at(instr->opcode());
110    if (opinfo.can_srcmod)
111          apply_source_mods(instr);
112 }
113 
114 void
convert_to_mov(AluInstr * alu,int src_idx)115 PeepholeVisitor::convert_to_mov(AluInstr *alu, int src_idx)
116 {
117    AluInstr::SrcValues new_src{alu->psrc(src_idx)};
118    alu->set_sources(new_src);
119    alu->set_op(op1_mov);
120    progress = true;
121 }
122 
123 void
visit(UNUSED AluGroup * instr)124 PeepholeVisitor::visit(UNUSED AluGroup *instr)
125 {
126    for (auto alu : *instr) {
127       if (!alu)
128          continue;
129       visit(alu);
130    }
131 }
132 
133 void
visit(Block * instr)134 PeepholeVisitor::visit(Block *instr)
135 {
136    for (auto& i : *instr)
137       i->accept(*this);
138 }
139 
140 void
visit(IfInstr * instr)141 PeepholeVisitor::visit(IfInstr *instr)
142 {
143    auto pred = instr->predicate();
144 
145    auto& src1 = pred->src(1);
146    if (value_is_const_uint(src1, 0)) {
147       auto src0 = pred->src(0).as_register();
148       if (src0 && src0->has_flag(Register::ssa) && !src0->parents().empty()) {
149          assert(src0->parents().size() == 1);
150          auto parent = *src0->parents().begin();
151 
152          ReplacePredicate visitor(pred);
153          parent->accept(visitor);
154          progress |= visitor.success;
155       }
156    }
157 }
158 
apply_source_mods(AluInstr * alu)159 void PeepholeVisitor::apply_source_mods(AluInstr *alu)
160 {
161    bool has_abs = alu->n_sources() / alu->alu_slots() < 3;
162 
163    for (unsigned i = 0; i < alu->n_sources(); ++i) {
164 
165       auto reg = alu->psrc(i)->as_register();
166       if (!reg)
167          continue;
168       if (!reg->has_flag(Register::ssa))
169          continue;
170       if (reg->parents().size() != 1)
171          continue;
172 
173       auto p = (*reg->parents().begin())->as_alu();
174       if (!p)
175          continue;
176 
177       if (p->opcode() != op1_mov)
178          continue;
179 
180       if (!has_abs && p->has_source_mod(0, AluInstr::mod_abs))
181          continue;
182 
183       if (!p->has_source_mod(0, AluInstr::mod_abs) &&
184           !p->has_source_mod(0, AluInstr::mod_neg))
185          continue;
186 
187       if (p->has_alu_flag(alu_dst_clamp))
188          continue;
189 
190       auto new_src = p->psrc(0);
191       bool new_src_not_pinned = new_src->pin() == pin_free ||
192                                 new_src->pin() == pin_none;
193 
194       bool old_src_not_pinned = reg->pin() == pin_free ||
195                                 reg->pin() == pin_none;
196 
197       bool sources_equal_channel = reg->pin() == pin_chan &&
198                                    new_src->pin() == pin_chan &&
199                                    new_src->chan() == reg->chan();
200 
201       if (!new_src_not_pinned &&
202           !old_src_not_pinned &&
203           !sources_equal_channel)
204          continue;
205 
206       uint32_t to_set = 0;
207       AluInstr::SourceMod to_clear = AluInstr::mod_none;
208 
209       if (p->has_source_mod(0, AluInstr::mod_abs))
210          to_set |= AluInstr::mod_abs;
211       if (p->has_source_mod(0, AluInstr::mod_neg)) {
212          if (!alu->has_source_mod(i, AluInstr::mod_neg))
213             to_set |= AluInstr::mod_neg;
214          else
215             to_clear = AluInstr::mod_neg;
216       }
217 
218       progress |= alu->replace_src(i, new_src, to_set, to_clear);
219    }
220 }
221 
try_fuse_with_prev(AluInstr * alu)222 void PeepholeVisitor::try_fuse_with_prev(AluInstr *alu)
223 {
224    if (auto reg = alu->src(0).as_register()) {
225       if (!reg->has_flag(Register::ssa) ||
226           reg->uses().size() != 1 ||
227           reg->parents().size() != 1)
228          return;
229       auto p = *reg->parents().begin();
230       auto dest = alu->dest();
231       if (!dest->has_flag(Register::ssa) &&
232           alu->block_id() != p->block_id())
233          return;
234       if (p->replace_dest(dest, alu)) {
235          dest->del_parent(alu);
236          dest->add_parent(p);
237          for (auto d : alu->dependend_instr()) {
238             d->add_required_instr(p);
239          }
240          alu->set_dead();
241          progress = true;
242       }
243    }
244 }
245 
apply_dest_clamp(AluInstr * alu)246 void PeepholeVisitor::apply_dest_clamp(AluInstr *alu)
247 {
248    if (alu->has_source_mod(0, AluInstr::mod_abs) ||
249        alu->has_source_mod(0, AluInstr::mod_neg))
250        return;
251 
252    auto dest = alu->dest();
253 
254    assert(dest);
255 
256    if (!dest->has_flag(Register::ssa))
257       return;
258 
259    auto src = alu->psrc(0)->as_register();
260    if (!src)
261       return;
262 
263    if (src->parents().size() != 1)
264       return;
265 
266    if (src->uses().size() != 1)
267       return;
268 
269    auto new_parent = (*src->parents().begin())->as_alu();
270    if (!new_parent)
271       return;
272 
273    auto opinfo = alu_ops.at(new_parent->opcode());
274    if (!opinfo.can_clamp)
275       return;
276 
277    // Move clamp flag to the parent, and let copy propagation do the rest
278    new_parent->set_alu_flag(alu_dst_clamp);
279    alu->reset_alu_flag(alu_dst_clamp);
280 
281    progress = true;
282 }
283 
284 
285 static EAluOp
pred_from_op(EAluOp pred_op,EAluOp op)286 pred_from_op(EAluOp pred_op, EAluOp op)
287 {
288    switch (pred_op) {
289    case op2_pred_setne_int:
290       switch (op) {
291       case op2_setge_dx10:
292          return op2_pred_setge;
293       case op2_setgt_dx10:
294          return op2_pred_setgt;
295       case op2_sete_dx10:
296          return op2_pred_sete;
297       case op2_setne_dx10:
298          return op2_pred_setne;
299 
300       case op2_setge_int:
301          return op2_pred_setge_int;
302       case op2_setgt_int:
303          return op2_pred_setgt_int;
304       case op2_setge_uint:
305          return op2_pred_setge_uint;
306       case op2_setgt_uint:
307          return op2_pred_setgt_uint;
308       case op2_sete_int:
309          return op2_prede_int;
310       case op2_setne_int:
311          return op2_pred_setne_int;
312       default:
313          return op0_nop;
314       }
315    case op2_prede_int:
316       switch (op) {
317       case op2_sete_int:
318          return op2_pred_setne_int;
319       case op2_setne_int:
320          return op2_prede_int;
321       default:
322          return op0_nop;
323       }
324    case op2_pred_setne:
325       switch (op) {
326       case op2_setge:
327          return op2_pred_setge;
328       case op2_setgt:
329          return op2_pred_setgt;
330       case op2_sete:
331          return op2_pred_sete;
332       default:
333          return op0_nop;
334       }
335    case op2_killne_int:
336       switch (op) {
337       case op2_setge_dx10:
338          return op2_killge;
339       case op2_setgt_dx10:
340          return op2_killgt;
341       case op2_sete_dx10:
342          return op2_kille;
343       case op2_setne_dx10:
344          return op2_killne;
345       case op2_setge_int:
346          return op2_killge_int;
347       case op2_setgt_int:
348          return op2_killgt_int;
349       case op2_setge_uint:
350          return op2_killge_uint;
351       case op2_setgt_uint:
352          return op2_killgt_uint;
353       case op2_sete_int:
354          return op2_kille_int;
355       case op2_setne_int:
356          return op2_killne_int;
357       default:
358          return op0_nop;
359       }
360 
361    default:
362       return op0_nop;
363    }
364 }
365 
366 void
visit(AluInstr * alu)367 ReplacePredicate::visit(AluInstr *alu)
368 {
369    auto new_op = pred_from_op(m_pred->opcode(), alu->opcode());
370 
371    if (new_op == op0_nop)
372       return;
373 
374    for (auto& s : alu->sources()) {
375       auto reg = s->as_register();
376       /* Protect against propagating
377        *
378        *   V = COND(R, X)
379        *   R = SOME_OP
380        *   IF (V)
381        *
382        * to
383        *
384        *   R = SOME_OP
385        *   IF (COND(R, X))
386        */
387       if (reg && !reg->has_flag(Register::ssa))
388          return;
389    }
390 
391    m_pred->set_op(new_op);
392    m_pred->set_sources(alu->sources());
393 
394    std::array<AluInstr::SourceMod, 2> mods = { AluInstr::mod_abs, AluInstr::mod_neg };
395 
396    for (int i = 0; i < 2; ++i) {
397       for (auto m : mods) {
398          if (alu->has_source_mod(i, m))
399             m_pred->set_source_mod(i, m);
400       }
401    }
402 
403    success = true;
404 }
405 
406 } // namespace r600
407