xref: /aosp_15_r20/external/mesa3d/src/amd/compiler/aco_optimizer.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2018 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "aco_builder.h"
8 #include "aco_ir.h"
9 
10 #include "util/half_float.h"
11 #include "util/memstream.h"
12 
13 #include <algorithm>
14 #include <array>
15 #include <vector>
16 
17 namespace aco {
18 
19 namespace {
20 /**
21  * The optimizer works in 4 phases:
22  * (1) The first pass collects information for each ssa-def,
23  *     propagates reg->reg operands of the same type, inline constants
24  *     and neg/abs input modifiers.
25  * (2) The second pass combines instructions like mad, omod, clamp and
26  *     propagates SGPRs on VALU instructions.
27  *     This pass depends on information collected in the first pass.
28  * (3) The third pass goes backwards, and selects instructions,
29  *     i.e. decides if a mad instruction is profitable and eliminates dead code.
30  * (4) The fourth pass cleans up the sequence: literals get applied and dead
31  *     instructions are removed from the sequence.
32  */
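/* For example: phase 1 may label the result of a v_mul_f32 with label_mul,
 * phase 2 then marks a dependent v_add_f32 as a mad/fma candidate (see
 * mad_info below), and phase 3 decides, based on use counts and literal
 * operands, whether emitting the fused instruction is actually profitable. */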
33 
34 struct mad_info {
35    aco_ptr<Instruction> add_instr;
36    uint32_t mul_temp_id;
37    uint16_t literal_mask;
38    uint16_t fp16_mask;
39 
40    mad_info(aco_ptr<Instruction> instr, uint32_t id)
41        : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0)
42    {}
43 };
44 
45 enum Label {
46    label_vec = 1 << 0,
47    label_constant_32bit = 1 << 1,
48    /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
49     * 32-bit operations but this shouldn't cause any issues because we don't
50     * look through any conversions */
51    label_abs = 1 << 2,
52    label_neg = 1 << 3,
53    label_mul = 1 << 4,
54    label_temp = 1 << 5,
55    label_literal = 1 << 6,
56    label_mad = 1 << 7,
57    label_omod2 = 1 << 8,
58    label_omod4 = 1 << 9,
59    label_omod5 = 1 << 10,
60    label_clamp = 1 << 12,
61    label_b2f = 1 << 16,
62    label_add_sub = 1 << 17,
63    label_bitwise = 1 << 18,
64    label_minmax = 1 << 19,
65    label_vopc = 1 << 20,
66    label_uniform_bool = 1 << 21,
67    label_constant_64bit = 1 << 22,
68    label_uniform_bitwise = 1 << 23,
69    label_scc_invert = 1 << 24,
70    label_scc_needed = 1 << 26,
71    label_b2i = 1 << 27,
72    label_fcanonicalize = 1 << 28,
73    label_constant_16bit = 1 << 29,
74    label_usedef = 1 << 30,   /* generic label */
75    label_vop3p = 1ull << 31, /* 1ull to prevent sign extension */
76    label_canonicalized = 1ull << 32,
77    label_extract = 1ull << 33,
78    label_insert = 1ull << 34,
79    label_dpp16 = 1ull << 35,
80    label_dpp8 = 1ull << 36,
81    label_f2f32 = 1ull << 37,
82    label_f2f16 = 1ull << 38,
83    label_split = 1ull << 39,
84 };
85 
86 static constexpr uint64_t instr_usedef_labels =
87    label_vec | label_mul | label_add_sub | label_vop3p | label_bitwise | label_uniform_bitwise |
88    label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 | label_dpp8 |
89    label_f2f32;
90 static constexpr uint64_t instr_mod_labels =
91    label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
92 
93 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels | label_split;
94 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_b2f |
95                                         label_uniform_bool | label_scc_invert | label_b2i |
96                                         label_fcanonicalize;
97 static constexpr uint32_t val_labels =
98    label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
99 
100 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
101 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
102 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
103 
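/* Per-SSA-definition information collected by the first pass. The label
 * bitfield determines which union member is valid: instr_labels use 'instr',
 * temp_labels use 'temp' and val_labels use 'val' (see the static_asserts
 * above); add_label() keeps the aliasing members consistent. */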
104 struct ssa_info {
105    uint64_t label;
106    union {
107       uint32_t val;
108       Temp temp;
109       Instruction* instr;
110    };
111 
112    ssa_info() : label(0) {}
113 
114    void add_label(Label new_label)
115    {
116       /* Since all the instr_usedef_labels use instr for the same thing
117        * (indicating the defining instruction), there is usually no need to
118        * clear any other instr labels. */
119       if (new_label & instr_usedef_labels)
120          label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
121 
122       if (new_label & instr_mod_labels) {
123          label &= ~instr_labels;
124          label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
125       }
126 
127       if (new_label & temp_labels) {
128          label &= ~temp_labels;
129          label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
130       }
131 
132       uint32_t const_labels =
133          label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
134       if (new_label & const_labels) {
135          label &= ~val_labels | const_labels;
136          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
137       } else if (new_label & val_labels) {
138          label &= ~val_labels;
139          label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
140       }
141 
142       label |= new_label;
143    }
144 
145    void set_vec(Instruction* vec)
146    {
147       add_label(label_vec);
148       instr = vec;
149    }
150 
151    bool is_vec() { return label & label_vec; }
152 
153    void set_constant(amd_gfx_level gfx_level, uint64_t constant)
154    {
155       Operand op16 = Operand::c16(constant);
156       Operand op32 = Operand::get_const(gfx_level, constant, 4);
157       add_label(label_literal);
158       val = constant;
159 
160       /* check that no upper bits are lost in case of packed 16bit constants */
161       if (gfx_level >= GFX8 && !op16.isLiteral() &&
162           op16.constantValue16(true) == ((constant >> 16) & 0xffff))
163          add_label(label_constant_16bit);
164 
165       if (!op32.isLiteral())
166          add_label(label_constant_32bit);
167 
168       if (Operand::is_constant_representable(constant, 8))
169          add_label(label_constant_64bit);
170 
171       if (label & label_constant_64bit) {
172          val = Operand::c64(constant).constantValue();
173          if (val != constant)
174             label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
175       }
176    }
177 
178    bool is_constant(unsigned bits)
179    {
180       switch (bits) {
181       case 8: return label & label_literal;
182       case 16: return label & label_constant_16bit;
183       case 32: return label & label_constant_32bit;
184       case 64: return label & label_constant_64bit;
185       }
186       return false;
187    }
188 
189    bool is_literal(unsigned bits)
190    {
191       bool is_lit = label & label_literal;
192       switch (bits) {
193       case 8: return false;
194       case 16: return is_lit && ~(label & label_constant_16bit);
195       case 32: return is_lit && ~(label & label_constant_32bit);
196       case 64: return false;
197       }
198       return false;
199    }
200 
201    bool is_constant_or_literal(unsigned bits)
202    {
203       if (bits == 64)
204          return label & label_constant_64bit;
205       else
206          return label & label_literal;
207    }
208 
209    void set_abs(Temp abs_temp)
210    {
211       add_label(label_abs);
212       temp = abs_temp;
213    }
214 
215    bool is_abs() { return label & label_abs; }
216 
217    void set_neg(Temp neg_temp)
218    {
219       add_label(label_neg);
220       temp = neg_temp;
221    }
222 
223    bool is_neg() { return label & label_neg; }
224 
225    void set_neg_abs(Temp neg_abs_temp)
226    {
227       add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
228       temp = neg_abs_temp;
229    }
230 
231    void set_mul(Instruction* mul)
232    {
233       add_label(label_mul);
234       instr = mul;
235    }
236 
237    bool is_mul() { return label & label_mul; }
238 
239    void set_temp(Temp tmp)
240    {
241       add_label(label_temp);
242       temp = tmp;
243    }
244 
245    bool is_temp() { return label & label_temp; }
246 
247    void set_mad(uint32_t mad_info_idx)
248    {
249       add_label(label_mad);
250       val = mad_info_idx;
251    }
252 
253    bool is_mad() { return label & label_mad; }
254 
255    void set_omod2(Instruction* mul)
256    {
257       if (label & temp_labels)
258          return;
259       add_label(label_omod2);
260       instr = mul;
261    }
262 
263    bool is_omod2() { return label & label_omod2; }
264 
265    void set_omod4(Instruction* mul)
266    {
267       if (label & temp_labels)
268          return;
269       add_label(label_omod4);
270       instr = mul;
271    }
272 
273    bool is_omod4() { return label & label_omod4; }
274 
275    void set_omod5(Instruction* mul)
276    {
277       if (label & temp_labels)
278          return;
279       add_label(label_omod5);
280       instr = mul;
281    }
282 
283    bool is_omod5() { return label & label_omod5; }
284 
285    void set_clamp(Instruction* med3)
286    {
287       if (label & temp_labels)
288          return;
289       add_label(label_clamp);
290       instr = med3;
291    }
292 
293    bool is_clamp() { return label & label_clamp; }
294 
295    void set_f2f16(Instruction* conv)
296    {
297       if (label & temp_labels)
298          return;
299       add_label(label_f2f16);
300       instr = conv;
301    }
302 
303    bool is_f2f16() { return label & label_f2f16; }
304 
305    void set_b2f(Temp b2f_val)
306    {
307       add_label(label_b2f);
308       temp = b2f_val;
309    }
310 
311    bool is_b2f() { return label & label_b2f; }
312 
313    void set_add_sub(Instruction* add_sub_instr)
314    {
315       add_label(label_add_sub);
316       instr = add_sub_instr;
317    }
318 
319    bool is_add_sub() { return label & label_add_sub; }
320 
321    void set_bitwise(Instruction* bitwise_instr)
322    {
323       add_label(label_bitwise);
324       instr = bitwise_instr;
325    }
326 
327    bool is_bitwise() { return label & label_bitwise; }
328 
329    void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
330 
331    bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
332 
333    void set_minmax(Instruction* minmax_instr)
334    {
335       add_label(label_minmax);
336       instr = minmax_instr;
337    }
338 
339    bool is_minmax() { return label & label_minmax; }
340 
341    void set_vopc(Instruction* vopc_instr)
342    {
343       add_label(label_vopc);
344       instr = vopc_instr;
345    }
346 
347    bool is_vopc() { return label & label_vopc; }
348 
349    void set_scc_needed() { add_label(label_scc_needed); }
350 
351    bool is_scc_needed() { return label & label_scc_needed; }
352 
353    void set_scc_invert(Temp scc_inv)
354    {
355       add_label(label_scc_invert);
356       temp = scc_inv;
357    }
358 
359    bool is_scc_invert() { return label & label_scc_invert; }
360 
361    void set_uniform_bool(Temp uniform_bool)
362    {
363       add_label(label_uniform_bool);
364       temp = uniform_bool;
365    }
366 
367    bool is_uniform_bool() { return label & label_uniform_bool; }
368 
369    void set_b2i(Temp b2i_val)
370    {
371       add_label(label_b2i);
372       temp = b2i_val;
373    }
374 
375    bool is_b2i() { return label & label_b2i; }
376 
377    void set_usedef(Instruction* label_instr)
378    {
379       add_label(label_usedef);
380       instr = label_instr;
381    }
382 
383    bool is_usedef() { return label & label_usedef; }
384 
385    void set_vop3p(Instruction* vop3p_instr)
386    {
387       add_label(label_vop3p);
388       instr = vop3p_instr;
389    }
390 
391    bool is_vop3p() { return label & label_vop3p; }
392 
393    void set_fcanonicalize(Temp tmp)
394    {
395       add_label(label_fcanonicalize);
396       temp = tmp;
397    }
398 
399    bool is_fcanonicalize() { return label & label_fcanonicalize; }
400 
401    void set_canonicalized() { add_label(label_canonicalized); }
402 
403    bool is_canonicalized() { return label & label_canonicalized; }
404 
405    void set_f2f32(Instruction* cvt)
406    {
407       add_label(label_f2f32);
408       instr = cvt;
409    }
410 
411    bool is_f2f32() { return label & label_f2f32; }
412 
413    void set_extract(Instruction* extract)
414    {
415       add_label(label_extract);
416       instr = extract;
417    }
418 
419    bool is_extract() { return label & label_extract; }
420 
421    void set_insert(Instruction* insert)
422    {
423       if (label & temp_labels)
424          return;
425       add_label(label_insert);
426       instr = insert;
427    }
428 
429    bool is_insert() { return label & label_insert; }
430 
431    void set_dpp16(Instruction* mov)
432    {
433       add_label(label_dpp16);
434       instr = mov;
435    }
436 
437    void set_dpp8(Instruction* mov)
438    {
439       add_label(label_dpp8);
440       instr = mov;
441    }
442 
443    bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
444    bool is_dpp16() { return label & label_dpp16; }
445    bool is_dpp8() { return label & label_dpp8; }
446 
447    void set_split(Instruction* split)
448    {
449       add_label(label_split);
450       instr = split;
451    }
452 
453    bool is_split() { return label & label_split; }
454 };
455 
456 struct opt_ctx {
457    Program* program;
458    float_mode fp_mode;
459    std::vector<aco_ptr<Instruction>> instructions;
460    std::vector<ssa_info> info;
461    std::pair<uint32_t, Temp> last_literal;
462    std::vector<mad_info> mad_infos;
463    std::vector<uint16_t> uses;
464 };
465 
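/* Whether the instruction already uses, or can be promoted to, the VOP3
 * encoding (e.g. to gain input/output modifiers or a non-VGPR operand). */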
466 bool
467 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
468 {
469    if (instr->isVOP3())
470       return true;
471 
472    if (instr->isVOP3P() || instr->isVINTERP_INREG())
473       return false;
474 
475    if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
476       return false;
477 
478    if (instr->isSDWA())
479       return false;
480 
481    if (instr->isDPP() && ctx.program->gfx_level < GFX11)
482       return false;
483 
484    return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
485           instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
486           instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
487           instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
488           instr->opcode != aco_opcode::v_permlane64_b32 &&
489           instr->opcode != aco_opcode::v_readlane_b32 &&
490           instr->opcode != aco_opcode::v_writelane_b32 &&
491           instr->opcode != aco_opcode::v_readfirstlane_b32;
492 }
493 
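/* Tries to replace operand 'index' of a pseudo instruction with 'temp',
 * adjusting the instruction where needed (e.g. shrinking p_split_vector
 * definitions or turning p_as_uniform into p_parallelcopy). Returns false
 * if the propagation isn't possible. */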
494 bool
495 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
496 {
497    if (instr->definitions.empty())
498       return false;
499 
500    const bool vgpr =
501       instr->opcode == aco_opcode::p_as_uniform ||
502       std::all_of(instr->definitions.begin(), instr->definitions.end(),
503                   [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
504 
505    /* don't propagate VGPRs into SGPR instructions */
506    if (temp.type() == RegType::vgpr && !vgpr)
507       return false;
508 
509    bool can_accept_sgpr =
510       ctx.program->gfx_level >= GFX9 ||
511       std::none_of(instr->definitions.begin(), instr->definitions.end(),
512                    [](const Definition& def) { return def.regClass().is_subdword(); });
513 
514    switch (instr->opcode) {
515    case aco_opcode::p_phi:
516    case aco_opcode::p_linear_phi:
517    case aco_opcode::p_parallelcopy:
518    case aco_opcode::p_create_vector:
519    case aco_opcode::p_start_linear_vgpr:
520       if (temp.bytes() != instr->operands[index].bytes())
521          return false;
522       break;
523    case aco_opcode::p_extract_vector:
524    case aco_opcode::p_extract:
525       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
526          return false;
527       break;
528    case aco_opcode::p_split_vector: {
529       if (temp.type() == RegType::sgpr && !can_accept_sgpr)
530          return false;
531       /* don't increase the vector size */
532       if (temp.bytes() > instr->operands[index].bytes())
533          return false;
534       /* We can decrease the vector size as smaller temporaries are only
535        * propagated by p_as_uniform instructions.
536        * If this propagation leads to invalid IR or hits the assertion below,
537        * it means that some undefined bytes within a dword are being accessed
538        * and a bug in instruction_selection is likely. */
539       int decrease = instr->operands[index].bytes() - temp.bytes();
540       while (decrease > 0) {
541          decrease -= instr->definitions.back().bytes();
542          instr->definitions.pop_back();
543       }
544       assert(decrease == 0);
545       break;
546    }
547    case aco_opcode::p_as_uniform:
548       if (temp.regClass() == instr->definitions[0].regClass())
549          instr->opcode = aco_opcode::p_parallelcopy;
550       break;
551    default: return false;
552    }
553 
554    instr->operands[index].setTemp(temp);
555    return true;
556 }
557 
558 /* This expects the DPP modifier to be removed. */
559 bool
560 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
561 {
562    assert(instr->isVALU());
563    if (instr->isSDWA() && ctx.program->gfx_level < GFX9)
564       return false;
565    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
566           instr->opcode != aco_opcode::v_readlane_b32 &&
567           instr->opcode != aco_opcode::v_readlane_b32_e64 &&
568           instr->opcode != aco_opcode::v_writelane_b32 &&
569           instr->opcode != aco_opcode::v_writelane_b32_e64 &&
570           instr->opcode != aco_opcode::v_permlane16_b32 &&
571           instr->opcode != aco_opcode::v_permlanex16_b32 &&
572           instr->opcode != aco_opcode::v_permlane64_b32 &&
573           instr->opcode != aco_opcode::v_interp_p1_f32 &&
574           instr->opcode != aco_opcode::v_interp_p2_f32 &&
575           instr->opcode != aco_opcode::v_interp_mov_f32 &&
576           instr->opcode != aco_opcode::v_interp_p1ll_f16 &&
577           instr->opcode != aco_opcode::v_interp_p1lv_f16 &&
578           instr->opcode != aco_opcode::v_interp_p2_legacy_f16 &&
579           instr->opcode != aco_opcode::v_interp_p2_f16 &&
580           instr->opcode != aco_opcode::v_interp_p2_hi_f16 &&
581           instr->opcode != aco_opcode::v_interp_p10_f32_inreg &&
582           instr->opcode != aco_opcode::v_interp_p2_f32_inreg &&
583           instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
584           instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
585           instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
586           instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
587           instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
588           instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
589           instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
590           instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
591           instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
592           instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
593 }
594 
595 /* only covers special cases */
596 bool
597 alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
598 {
599    /* Fixed operands can't accept constants because we need them
600     * to be in their fixed register.
601     */
602    assert(instr->operands.size() > operand);
603    if (instr->operands[operand].isFixed())
604       return false;
605 
606    /* SOPP instructions can't use constants. */
607    if (instr->isSOPP())
608       return false;
609 
610    switch (instr->opcode) {
611    case aco_opcode::s_fmac_f16:
612    case aco_opcode::s_fmac_f32:
613    case aco_opcode::v_mac_f32:
614    case aco_opcode::v_writelane_b32:
615    case aco_opcode::v_writelane_b32_e64:
616    case aco_opcode::v_cndmask_b32: return operand != 2;
617    case aco_opcode::s_addk_i32:
618    case aco_opcode::s_mulk_i32:
619    case aco_opcode::p_extract_vector:
620    case aco_opcode::p_split_vector:
621    case aco_opcode::v_readlane_b32:
622    case aco_opcode::v_readlane_b32_e64:
623    case aco_opcode::v_readfirstlane_b32:
624    case aco_opcode::p_extract:
625    case aco_opcode::p_insert: return operand != 0;
626    case aco_opcode::p_bpermute_readlane:
627    case aco_opcode::p_bpermute_shared_vgpr:
628    case aco_opcode::p_bpermute_permlane:
629    case aco_opcode::p_interp_gfx11:
630    case aco_opcode::p_dual_src_export_gfx11:
631    case aco_opcode::v_interp_p1_f32:
632    case aco_opcode::v_interp_p2_f32:
633    case aco_opcode::v_interp_mov_f32:
634    case aco_opcode::v_interp_p1ll_f16:
635    case aco_opcode::v_interp_p1lv_f16:
636    case aco_opcode::v_interp_p2_legacy_f16:
637    case aco_opcode::v_interp_p10_f32_inreg:
638    case aco_opcode::v_interp_p2_f32_inreg:
639    case aco_opcode::v_interp_p10_f16_f32_inreg:
640    case aco_opcode::v_interp_p2_f16_f32_inreg:
641    case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
642    case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
643    case aco_opcode::v_wmma_f32_16x16x16_f16:
644    case aco_opcode::v_wmma_f32_16x16x16_bf16:
645    case aco_opcode::v_wmma_f16_16x16x16_f16:
646    case aco_opcode::v_wmma_bf16_16x16x16_bf16:
647    case aco_opcode::v_wmma_i32_16x16x16_iu8:
648    case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
649    default: return true;
650    }
651 }
652 
653 bool
654 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
655 {
656    if (instr->opcode == aco_opcode::v_writelane_b32 ||
657        instr->opcode == aco_opcode::v_writelane_b32_e64)
658       return operand == 2;
659    if (instr->opcode == aco_opcode::v_permlane16_b32 ||
660        instr->opcode == aco_opcode::v_permlanex16_b32 ||
661        instr->opcode == aco_opcode::v_readlane_b32 ||
662        instr->opcode == aco_opcode::v_readlane_b32_e64)
663       return operand == 0;
664    return instr_info.classes[(int)instr->opcode] != instr_class::valu_pseudo_scalar_trans;
665 }
666 
667 /* check constant bus and literal limitations */
668 bool
669 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
670 {
671    int limit = ctx.program->gfx_level >= GFX10 ? 2 : 1;
672    Operand literal32(s1);
673    Operand literal64(s2);
674    unsigned num_sgprs = 0;
675    unsigned sgpr[] = {0, 0};
676 
677    for (unsigned i = 0; i < num_operands; i++) {
678       Operand op = operands[i];
679 
680       if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
681          /* two reads of the same SGPR count as 1 to the limit */
682          if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
683             if (num_sgprs < 2)
684                sgpr[num_sgprs++] = op.tempId();
685             limit--;
686             if (limit < 0)
687                return false;
688          }
689       } else if (op.isLiteral()) {
690          if (ctx.program->gfx_level < GFX10)
691             return false;
692 
693          if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
694             return false;
695          if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
696             return false;
697 
698          /* Any number of 32-bit literals counts as only 1 to the limit. Same
699           * (but separately) for 64-bit literals. */
700          if (op.size() == 1 && literal32.isUndefined()) {
701             limit--;
702             literal32 = op;
703          } else if (op.size() == 2 && literal64.isUndefined()) {
704             limit--;
705             literal64 = op;
706          }
707 
708          if (limit < 0)
709             return false;
710       }
711    }
712 
713    return true;
714 }
715 
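/* Tries to split operand 'op_index' into a base temporary plus a constant
 * offset by looking through a labeled add/sub instruction, e.g. an address
 * computed as v_add_u32 %base, 16 yields *base = %base and *offset = 16. */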
716 bool
717 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
718                   bool prevent_overflow)
719 {
720    Operand op = instr->operands[op_index];
721 
722    if (!op.isTemp())
723       return false;
724    Temp tmp = op.getTemp();
725    if (!ctx.info[tmp.id()].is_add_sub())
726       return false;
727 
728    Instruction* add_instr = ctx.info[tmp.id()].instr;
729 
730    unsigned mask = 0x3;
731    bool is_sub = false;
732    switch (add_instr->opcode) {
733    case aco_opcode::v_add_u32:
734    case aco_opcode::v_add_co_u32:
735    case aco_opcode::v_add_co_u32_e64:
736    case aco_opcode::s_add_i32:
737    case aco_opcode::s_add_u32: break;
738    case aco_opcode::v_sub_u32:
739    case aco_opcode::v_sub_i32:
740    case aco_opcode::v_sub_co_u32:
741    case aco_opcode::v_sub_co_u32_e64:
742    case aco_opcode::s_sub_u32:
743    case aco_opcode::s_sub_i32:
744       mask = 0x2;
745       is_sub = true;
746       break;
747    case aco_opcode::v_subrev_u32:
748    case aco_opcode::v_subrev_co_u32:
749    case aco_opcode::v_subrev_co_u32_e64:
750       mask = 0x1;
751       is_sub = true;
752       break;
753    default: return false;
754    }
755    if (prevent_overflow && !add_instr->definitions[0].isNUW())
756       return false;
757 
758    if (add_instr->usesModifiers())
759       return false;
760 
761    u_foreach_bit (i, mask) {
762       if (add_instr->operands[i].isConstant()) {
763          *offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
764       } else if (add_instr->operands[i].isTemp() &&
765                  ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
766          *offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
767       } else {
768          continue;
769       }
770       if (!add_instr->operands[!i].isTemp())
771          continue;
772 
773       uint32_t offset2 = 0;
774       if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
775          *offset += offset2;
776       } else {
777          *base = add_instr->operands[!i].getTemp();
778       }
779       return true;
780    }
781 
782    return false;
783 }
784 
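/* Looks through an 's_and_b32 offset, -4' on the SMEM (SOE) offset operand;
 * the masking is redundant because the address calculation aligns the offset
 * anyway (see the comment inside). */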
785 void
786 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
787 {
788    bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
789    if (soe && !smem->operands[1].isConstant())
790       return;
791    /* We don't need to check the constant offset because the address seems to be calculated with
792     * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
793     */
794 
795    Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
796    if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
797       return;
798 
799    Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
800    if (bitwise_instr->opcode != aco_opcode::s_and_b32)
801       return;
802 
803    if (bitwise_instr->operands[0].constantEquals(-4) &&
804        bitwise_instr->operands[1].isOfType(op.regClass().type()))
805       op.setTemp(bitwise_instr->operands[1].getTemp());
806    else if (bitwise_instr->operands[1].constantEquals(-4) &&
807             bitwise_instr->operands[0].isOfType(op.regClass().type()))
808       op.setTemp(bitwise_instr->operands[0].getTemp());
809 }
810 
811 void
812 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
813 {
814    /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
815    if (!instr->operands.empty())
816       skip_smem_offset_align(ctx, &instr->smem());
817 
818    /* propagate constants and combine additions */
819    if (!instr->operands.empty() && instr->operands[1].isTemp()) {
820       SMEM_instruction& smem = instr->smem();
821       ssa_info info = ctx.info[instr->operands[1].tempId()];
822 
823       Temp base;
824       uint32_t offset;
825       if (info.is_constant_or_literal(32) &&
826           ((ctx.program->gfx_level == GFX6 && info.val <= 0x3FF) ||
827            (ctx.program->gfx_level == GFX7 && info.val <= 0xFFFFFFFF) ||
828            (ctx.program->gfx_level >= GFX8 && info.val <= 0xFFFFF))) {
829          instr->operands[1] = Operand::c32(info.val);
830       } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) &&
831                  base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->gfx_level >= GFX9 &&
832                  offset % 4u == 0) {
833          bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
834          if (soe) {
835             if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
836                 ctx.info[smem.operands.back().tempId()].val == 0) {
837                smem.operands[1] = Operand::c32(offset);
838                smem.operands.back() = Operand(base);
839             }
840          } else {
841             Instruction* new_instr = create_instruction(
842                smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
843             new_instr->operands[0] = smem.operands[0];
844             new_instr->operands[1] = Operand::c32(offset);
845             if (smem.definitions.empty())
846                new_instr->operands[2] = smem.operands[2];
847             new_instr->operands.back() = Operand(base);
848             if (!smem.definitions.empty())
849                new_instr->definitions[0] = smem.definitions[0];
850             new_instr->smem().sync = smem.sync;
851             new_instr->smem().cache = smem.cache;
852             instr.reset(new_instr);
853          }
854       }
855    }
856 
857    /* skip &-4 after offset additions: load(a & -4, 16) */
858    if (!instr->operands.empty())
859       skip_smem_offset_align(ctx, &instr->smem());
860 }
861 
862 Operand
863 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
864 {
865    if (bits == 64)
866       return Operand::c32_or_c64(info.val, true);
867    return Operand::get_const(ctx.program->gfx_level, info.val, bits / 8u);
868 }
869 
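/* Folds a 32-bit constant operand into a VOP3P instruction: each used 16-bit
 * half is re-expressed as an inline constant where possible, adjusting
 * opsel_lo/opsel_hi (and the neg bits for the +/- case) to match. */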
870 void
871 propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
872 {
873    if (!info.is_constant_or_literal(32))
874       return;
875 
876    assert(instr->operands[i].isTemp());
877    unsigned bits = get_operand_size(instr, i);
878    if (info.is_constant(bits)) {
879       instr->operands[i] = get_constant_op(ctx, info, bits);
880       return;
881    }
882 
883    /* The accumulation operand of dot product instructions ignores opsel. */
884    bool cannot_use_opsel =
885       (instr->opcode == aco_opcode::v_dot4_i32_i8 || instr->opcode == aco_opcode::v_dot2_i32_i16 ||
886        instr->opcode == aco_opcode::v_dot4_i32_iu8 || instr->opcode == aco_opcode::v_dot4_u32_u8 ||
887        instr->opcode == aco_opcode::v_dot2_u32_u16) &&
888       i == 2;
889    if (cannot_use_opsel)
890       return;
891 
892    /* try to fold inline constants */
893    VALU_instruction* vop3p = &instr->valu();
894    bool opsel_lo = vop3p->opsel_lo[i];
895    bool opsel_hi = vop3p->opsel_hi[i];
896 
897    Operand const_op[2];
898    bool const_opsel[2] = {false, false};
899    for (unsigned j = 0; j < 2; j++) {
900       if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
901          continue; /* this half is unused */
902 
903       uint16_t val = info.val >> (j ? 16 : 0);
904       Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
905       if (bits == 32 && op.isLiteral()) /* try sign extension */
906          op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
907       if (bits == 32 && op.isLiteral()) { /* try shifting left */
908          op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
909          const_opsel[j] = true;
910       }
911       if (op.isLiteral())
912          return;
913       const_op[j] = op;
914    }
915 
916    Operand const_lo = const_op[0];
917    Operand const_hi = const_op[1];
918    bool const_lo_opsel = const_opsel[0];
919    bool const_hi_opsel = const_opsel[1];
920 
921    if (opsel_lo == opsel_hi) {
922       /* use the single 16bit value */
923       instr->operands[i] = opsel_lo ? const_hi : const_lo;
924 
925       /* opsel must point the same for both halves */
926       opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
927       opsel_hi = opsel_lo;
928    } else if (const_lo == const_hi) {
929       /* both constants are the same */
930       instr->operands[i] = const_lo;
931 
932       /* opsel must point the same for both halves */
933       opsel_lo = const_lo_opsel;
934       opsel_hi = const_lo_opsel;
935    } else if (const_lo.constantValue16(const_lo_opsel) ==
936               const_hi.constantValue16(!const_hi_opsel)) {
937       instr->operands[i] = const_hi;
938 
939       /* redirect opsel selection */
940       opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
941       opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
942    } else if (const_hi.constantValue16(const_hi_opsel) ==
943               const_lo.constantValue16(!const_lo_opsel)) {
944       instr->operands[i] = const_lo;
945 
946       /* redirect opsel selection */
947       opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
948       opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
949    } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
950       assert(const_lo_opsel == false && const_hi_opsel == false);
951 
952       /* const_lo == -const_hi */
953       if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
954          return;
955 
956       instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
957       bool neg_lo = const_lo.constantValue() & (1 << 15);
958       vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
959       vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;
960 
961       /* opsel must point to lo for both operands */
962       opsel_lo = false;
963       opsel_hi = false;
964    }
965 
966    vop3p->opsel_lo[i] = opsel_lo;
967    vop3p->opsel_hi[i] = opsel_hi;
968 }
969 
970 bool
971 fixed_to_exec(Operand op)
972 {
973    return op.isFixed() && op.physReg() == exec;
974 }
975 
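/* Returns the subdword selection (size, offset, sign-extension) performed by
 * an extract-like instruction, or an invalid SubdwordSel if the instruction
 * isn't recognized as one. */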
976 SubdwordSel
977 parse_extract(Instruction* instr)
978 {
979    if (instr->opcode == aco_opcode::p_extract) {
980       unsigned size = instr->operands[2].constantValue() / 8;
981       unsigned offset = instr->operands[1].constantValue() * size;
982       bool sext = instr->operands[3].constantEquals(1);
983       return SubdwordSel(size, offset, sext);
984    } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
985       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
986    } else if (instr->opcode == aco_opcode::p_extract_vector) {
987       unsigned size = instr->definitions[0].bytes();
988       unsigned offset = instr->operands[1].constantValue() * size;
989       if (size <= 2)
990          return SubdwordSel(size, offset, false);
991    } else if (instr->opcode == aco_opcode::p_split_vector) {
992       assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
993       return SubdwordSel(2, 2, false);
994    }
995 
996    return SubdwordSel();
997 }
998 
999 SubdwordSel
1000 parse_insert(Instruction* instr)
1001 {
1002    if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
1003        instr->operands[1].constantEquals(0)) {
1004       return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1005    } else if (instr->opcode == aco_opcode::p_insert) {
1006       unsigned size = instr->operands[2].constantValue() / 8;
1007       unsigned offset = instr->operands[1].constantValue() * size;
1008       return SubdwordSel(size, offset, false);
1009    } else {
1010       return SubdwordSel();
1011    }
1012 }
1013 
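/* Checks whether the extract feeding operand 'idx' can be folded directly
 * into 'instr' (via dedicated opcodes, SDWA selections or opsel) instead of
 * keeping a separate p_extract. */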
1014 bool
1015 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1016 {
1017    Temp tmp = info.instr->operands[0].getTemp();
1018    SubdwordSel sel = parse_extract(info.instr);
1019 
1020    if (!sel) {
1021       return false;
1022    } else if (sel.size() == 4) {
1023       return true;
1024    } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1025                instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1026               sel.size() == 1 && !sel.sign_extend()) {
1027       return true;
1028    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1029               sel.offset() == 0 &&
1030               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1031                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1032       return true;
1033    } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1034               !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1035               (instr->operands[!idx].is16bit() ||
1036                (instr->operands[!idx].isConstant() &&
1037                 instr->operands[!idx].constantValue() <= UINT16_MAX))) {
1038       return true;
1039    } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1040               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1041       if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1042          return false;
1043       return true;
1044    } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] &&
1045               can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) {
1046       return true;
1047    } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16 && sel.size() == 2 &&
1048               (idx == 1 || ctx.program->gfx_level >= GFX11 || !sel.offset())) {
1049       return true;
1050    } else if (sel.size() == 2 &&
1051               ((instr->opcode == aco_opcode::s_pack_lh_b32_b16 && idx == 0) ||
1052                (instr->opcode == aco_opcode::s_pack_hl_b32_b16 && idx == 1))) {
1053       return true;
1054    } else if (instr->opcode == aco_opcode::p_extract) {
1055       SubdwordSel instrSel = parse_extract(instr.get());
1056 
1057       /* the outer offset must be within extracted range */
1058       if (instrSel.offset() >= sel.size())
1059          return false;
1060 
1061       /* don't remove the sign-extension when increasing the size further */
1062       if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1063          return false;
1064 
1065       return true;
1066    }
1067 
1068    return false;
1069 }
1070 
1071 /* Combine a p_extract (or p_insert, in some cases) instruction with instr.
1072  * instr(p_extract(...)) -> instr()
1073  */
1074 void
1075 apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1076 {
1077    Temp tmp = info.instr->operands[0].getTemp();
1078    SubdwordSel sel = parse_extract(info.instr);
1079    assert(sel);
1080 
1081    instr->operands[idx].set16bit(false);
1082    instr->operands[idx].set24bit(false);
1083 
1084    ctx.info[tmp.id()].label &= ~label_insert;
1085 
1086    if (sel.size() == 4) {
1087       /* full dword selection */
1088    } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1089                instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1090               sel.size() == 1 && !sel.sign_extend()) {
1091       switch (sel.offset()) {
1092       case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
1093       case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
1094       case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
1095       case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
1096       }
1097    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1098               sel.offset() == 0 &&
1099               ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1100                (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1101       /* The undesirable upper bits are already shifted out. */
1102       return;
1103    } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1104               !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1105               (instr->operands[!idx].is16bit() ||
1106                instr->operands[!idx].constantValue() <= UINT16_MAX)) {
1107       Instruction* mad = create_instruction(aco_opcode::v_mad_u32_u16, Format::VOP3, 3, 1);
1108       mad->definitions[0] = instr->definitions[0];
1109       mad->operands[0] = instr->operands[0];
1110       mad->operands[1] = instr->operands[1];
1111       mad->operands[2] = Operand::zero();
1112       mad->valu().opsel[idx] = sel.offset();
1113       mad->pass_flags = instr->pass_flags;
1114       instr.reset(mad);
1115    } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1116               (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1117       convert_to_SDWA(ctx.program->gfx_level, instr);
1118       instr->sdwa().sel[idx] = sel;
1119    } else if (instr->isVALU()) {
1120       if (sel.offset()) {
1121          instr->valu().opsel[idx] = true;
1122 
1123          /* VOP12C cannot use opsel with SGPRs. */
1124          if (!instr->isVOP3() && !instr->isVINTERP_INREG() &&
1125              !info.instr->operands[0].isOfType(RegType::vgpr))
1126             instr->format = asVOP3(instr->format);
1127       }
1128    } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16) {
1129       if (sel.offset())
1130          instr->opcode = idx ? aco_opcode::s_pack_lh_b32_b16 : aco_opcode::s_pack_hl_b32_b16;
1131    } else if (instr->opcode == aco_opcode::s_pack_lh_b32_b16 ||
1132               instr->opcode == aco_opcode::s_pack_hl_b32_b16) {
1133       if (sel.offset())
1134          instr->opcode = aco_opcode::s_pack_hh_b32_b16;
1135    } else if (instr->opcode == aco_opcode::p_extract) {
1136       SubdwordSel instrSel = parse_extract(instr.get());
1137 
1138       unsigned size = std::min(sel.size(), instrSel.size());
1139       unsigned offset = sel.offset() + instrSel.offset();
1140       unsigned sign_extend =
1141          instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());
1142 
1143       instr->operands[1] = Operand::c32(offset / size);
1144       instr->operands[2] = Operand::c32(size * 8u);
1145       instr->operands[3] = Operand::c32(sign_extend);
1146       return;
1147    }
1148 
1149    /* These are the only labels worth keeping at the moment. */
1150    for (Definition& def : instr->definitions) {
1151       ctx.info[def.tempId()].label &=
1152          (label_mul | label_minmax | label_usedef | label_vopc | label_f2f32 | instr_mod_labels);
1153       if (ctx.info[def.tempId()].label & instr_usedef_labels)
1154          ctx.info[def.tempId()].instr = instr.get();
1155    }
1156 }
1157 
1158 void
1159 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1160 {
1161    for (unsigned i = 0; i < instr->operands.size(); i++) {
1162       Operand op = instr->operands[i];
1163       if (!op.isTemp())
1164          continue;
1165       ssa_info& info = ctx.info[op.tempId()];
1166       if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1167                                 op.getTemp().type() == RegType::sgpr)) {
1168          if (!can_apply_extract(ctx, instr, i, info))
1169             info.label &= ~label_extract;
1170       }
1171    }
1172 }
1173 
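/* Whether the FP operation itself flushes denormal inputs; used below to
 * decide if an explicit canonicalization of an operand can be elided. */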
1174 bool
1175 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1176 {
1177    switch (op) {
1178    case aco_opcode::v_min_f32:
1179    case aco_opcode::v_max_f32:
1180    case aco_opcode::v_med3_f32:
1181    case aco_opcode::v_min3_f32:
1182    case aco_opcode::v_max3_f32:
1183    case aco_opcode::v_min_f16:
1184    case aco_opcode::v_max_f16: return ctx.program->gfx_level > GFX8;
1185    case aco_opcode::v_cndmask_b32:
1186    case aco_opcode::v_cndmask_b16:
1187    case aco_opcode::v_mov_b32:
1188    case aco_opcode::v_mov_b16: return false;
1189    default: return true;
1190    }
1191 }
1192 
1193 bool
1194 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp, unsigned idx)
1195 {
1196    float_mode* fp = &ctx.fp_mode;
1197    if (ctx.info[tmp.id()].is_canonicalized() ||
1198        (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1199       return true;
1200 
1201    aco_opcode op = instr->opcode;
1202    return can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx) &&
1203           does_fp_op_flush_denorms(ctx, op);
1204 }
1205 
1206 bool
1207 can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags)
1208 {
1209    if (ctx.info[tmp.id()].is_vopc()) {
1210       Instruction* vopc_instr = ctx.info[tmp.id()].instr;
1211       /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1212        * already produces the same result */
1213       return vopc_instr->pass_flags == pass_flags;
1214    }
1215    if (ctx.info[tmp.id()].is_bitwise()) {
1216       Instruction* instr = ctx.info[tmp.id()].instr;
1217       if (instr->operands.size() != 2 || instr->pass_flags != pass_flags)
1218          return false;
1219       if (!(instr->operands[0].isTemp() && instr->operands[1].isTemp()))
1220          return false;
1221       if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
1222          return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) ||
1223                 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1224       } else {
1225          return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) &&
1226                 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1227       }
1228    }
1229    return false;
1230 }
1231 
1232 bool
1233 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned idx)
1234 {
1235    return info.is_temp() ||
1236           (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp, idx));
1237 }
1238 
1239 bool
1240 is_op_canonicalized(opt_ctx& ctx, Operand op)
1241 {
1242    float_mode* fp = &ctx.fp_mode;
1243    if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1244        (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1245       return true;
1246 
1247    if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1248       uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1249       if (op.bytes() == 2)
1250          return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1251       else if (op.bytes() == 4)
1252          return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1253    }
1254    return false;
1255 }
1256 
1257 bool
1258 is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int64_t offset0, int64_t offset1)
1259 {
1260    bool negative_unaligned_scratch_offset_bug = ctx.program->gfx_level == GFX10;
1261    int32_t min = ctx.program->dev.scratch_global_offset_min;
1262    int32_t max = ctx.program->dev.scratch_global_offset_max;
1263 
1264    int64_t offset = offset0 + offset1;
1265 
1266    bool has_vgpr_offset = instr && !instr->operands[0].isUndefined();
1267    if (negative_unaligned_scratch_offset_bug && has_vgpr_offset && offset < 0 && offset % 4)
1268       return false;
1269 
1270    return offset >= min && offset <= max;
1271 }
1272 
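/* Recognizes a clamp written as v_med3_f16/f32 with the constants 0.0 and
 * 1.0; stores the index of the remaining (clamped) operand in *clamped_idx. */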
1273 bool
1274 detect_clamp(Instruction* instr, unsigned* clamped_idx)
1275 {
1276    VALU_instruction& valu = instr->valu();
1277    if (valu.omod != 0 || valu.opsel != 0)
1278       return false;
1279 
1280    unsigned idx = 0;
1281    bool found_zero = false, found_one = false;
1282    bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1283    for (unsigned i = 0; i < 3; i++) {
1284       if (!valu.neg[i] && instr->operands[i].constantEquals(0))
1285          found_zero = true;
1286       else if (!valu.neg[i] &&
1287                instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1288          found_one = true;
1289       else
1290          idx = i;
1291    }
1292    if (found_zero && found_one && instr->operands[idx].isTemp()) {
1293       *clamped_idx = idx;
1294       return true;
1295    } else {
1296       return false;
1297    }
1298 }
1299 
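/* First-pass worker (phase 1 above): per operand, propagates temporaries,
 * inline constants and neg/abs modifiers using the ssa_info gathered so far,
 * and folds SMEM/MUBUF address additions. */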
1300 void
1301 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1302 {
1303    if (instr->isSMEM())
1304       smem_combine(ctx, instr);
1305 
1306    for (unsigned i = 0; i < instr->operands.size(); i++) {
1307       if (!instr->operands[i].isTemp())
1308          continue;
1309 
1310       ssa_info info = ctx.info[instr->operands[i].tempId()];
1311       /* propagate reg->reg of same type */
1312       while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1313          instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1314          info = ctx.info[info.temp.id()];
1315       }
1316 
1317       /* PSEUDO: propagate temporaries */
1318       if (instr->isPseudo()) {
1319          while (info.is_temp()) {
1320             pseudo_propagate_temp(ctx, instr, info.temp, i);
1321             info = ctx.info[info.temp.id()];
1322          }
1323       }
1324 
1325       /* SALU / PSEUDO: propagate inline constants */
1326       if (instr->isSALU() || instr->isPseudo()) {
1327          unsigned bits = get_operand_size(instr, i);
1328          if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1329              alu_can_accept_constant(instr, i)) {
1330             instr->operands[i] = get_constant_op(ctx, info, bits);
1331             continue;
1332          }
1333       }
1334 
1335       /* VALU: propagate neg, abs & inline constants */
1336       else if (instr->isVALU()) {
1337          if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::vgpr &&
1338              valu_can_accept_vgpr(instr, i)) {
1339             instr->operands[i].setTemp(info.temp);
1340             info = ctx.info[info.temp.id()];
1341          }
1342          /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1343          if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1344              instr->operands.size() == 1) {
1345             instr->format = withoutDPP(instr->format);
1346             instr->operands[i].setTemp(info.temp);
1347             info = ctx.info[info.temp.id()];
1348          }
1349 
1350          /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1351           * operand size */
1352          bool can_use_mod =
1353             instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1354          can_use_mod &= can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i);
1355 
1356          bool packed_math = instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
1357                             instr->opcode != aco_opcode::v_fma_mixlo_f16 &&
1358                             instr->opcode != aco_opcode::v_fma_mixhi_f16;
1359 
1360          if (instr->isSDWA())
1361             can_use_mod &= instr->sdwa().sel[i].size() == 4;
1362          else if (instr->isVOP3P())
1363             can_use_mod &= !packed_math || !info.is_abs();
1364          else if (instr->isVINTERP_INREG())
1365             can_use_mod &= !info.is_abs();
1366          else
1367             can_use_mod &= instr->isDPP16() || can_use_VOP3(ctx, instr);
1368 
1369          unsigned bits = get_operand_size(instr, i);
1370          can_use_mod &= instr->operands[i].bytes() * 8 == bits;
1371 
1372          if (info.is_neg() && can_use_mod &&
1373              can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1374             instr->operands[i].setTemp(info.temp);
1375             if (!packed_math && instr->valu().abs[i]) {
1376                /* fabs(fneg(a)) -> fabs(a) */
1377             } else if (instr->opcode == aco_opcode::v_add_f32) {
1378                instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1379             } else if (instr->opcode == aco_opcode::v_add_f16) {
1380                instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1381             } else if (packed_math) {
1382                /* Bit size compat should ensure this. */
1383                assert(!instr->valu().opsel_lo[i] && !instr->valu().opsel_hi[i]);
1384                instr->valu().neg_lo[i] ^= true;
1385                instr->valu().neg_hi[i] ^= true;
1386             } else {
1387                if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1388                   instr->format = asVOP3(instr->format);
1389                instr->valu().neg[i] ^= true;
1390             }
1391          }
1392          if (info.is_abs() && can_use_mod &&
1393              can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1394             if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1395                instr->format = asVOP3(instr->format);
1396             instr->operands[i] = Operand(info.temp);
1397             instr->valu().abs[i] = true;
1398             continue;
1399          }
1400 
1401          if (instr->isVOP3P()) {
1402             propagate_constants_vop3p(ctx, instr, info, i);
1403             continue;
1404          }
1405 
1406          if (info.is_constant(bits) && alu_can_accept_constant(instr, i) &&
1407              (!instr->isSDWA() || ctx.program->gfx_level >= GFX9) && (!instr->isDPP() || i != 1)) {
1408             Operand op = get_constant_op(ctx, info, bits);
1409             if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1410                 instr->opcode == aco_opcode::v_writelane_b32) {
1411                instr->format = withoutDPP(instr->format);
1412                instr->operands[i] = op;
1413                continue;
1414             } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1415                instr->operands[i] = op;
1416                instr->valu().swapOperands(0, i);
1417                continue;
1418             } else if (can_use_VOP3(ctx, instr)) {
1419                instr->format = asVOP3(instr->format);
1420                instr->operands[i] = op;
1421                continue;
1422             }
1423          }
1424       }
1425 
1426       /* MUBUF: propagate constants and combine additions */
1427       else if (instr->isMUBUF()) {
1428          MUBUF_instruction& mubuf = instr->mubuf();
1429          Temp base;
1430          uint32_t offset;
1431          while (info.is_temp())
1432             info = ctx.info[info.temp.id()];
1433 
1434          bool swizzled = ctx.program->gfx_level >= GFX12 ? mubuf.cache.gfx12.swizzled
1435                                                          : (mubuf.cache.value & ac_swizzled);
1436          /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1437           * overflow for scratch accesses works only on GFX9+ and saddr overflow
1438           * never works. Since swizzling is the only thing that separates
1439           * scratch accesses from other accesses, and swizzling significantly
1440           * changes how addressing works, this probably applies to swizzled
1441           * MUBUF accesses as well. */
1442          bool vaddr_prevent_overflow = swizzled && ctx.program->gfx_level < GFX9;
1443 
1444          if (mubuf.offen && mubuf.idxen && i == 1 && info.is_vec() &&
1445              info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1446              info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1447              mubuf.offset + info.instr->operands[1].constantValue() < 4096) {
1448             instr->operands[1] = info.instr->operands[0];
1449             mubuf.offset += info.instr->operands[1].constantValue();
1450             mubuf.offen = false;
1451             continue;
1452          } else if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1453                     mubuf.offset + info.val < 4096) {
1454             assert(!mubuf.idxen);
1455             instr->operands[1] = Operand(v1);
1456             mubuf.offset += info.val;
1457             mubuf.offen = false;
1458             continue;
1459          } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1460             instr->operands[2] = Operand::c32(0);
1461             mubuf.offset += info.val;
1462             continue;
1463          } else if (mubuf.offen && i == 1 &&
1464                     parse_base_offset(ctx, instr.get(), i, &base, &offset,
1465                                       vaddr_prevent_overflow) &&
1466                     base.regClass() == v1 && mubuf.offset + offset < 4096) {
1467             assert(!mubuf.idxen);
1468             instr->operands[1].setTemp(base);
1469             mubuf.offset += offset;
1470             continue;
1471          } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1472                     base.regClass() == s1 && mubuf.offset + offset < 4096 && !swizzled) {
1473             instr->operands[i].setTemp(base);
1474             mubuf.offset += offset;
1475             continue;
1476          }
1477       }
1478 
1479       else if (instr->isMTBUF()) {
1480          MTBUF_instruction& mtbuf = instr->mtbuf();
1481          while (info.is_temp())
1482             info = ctx.info[info.temp.id()];
1483 
1484          if (mtbuf.offen && mtbuf.idxen && i == 1 && info.is_vec() &&
1485              info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1486              info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1487              mtbuf.offset + info.instr->operands[1].constantValue() < 4096) {
1488             instr->operands[1] = info.instr->operands[0];
1489             mtbuf.offset += info.instr->operands[1].constantValue();
1490             mtbuf.offen = false;
1491             continue;
1492          }
1493       }
1494 
1495       /* SCRATCH: propagate constants and combine additions */
1496       else if (instr->isScratch()) {
1497          FLAT_instruction& scratch = instr->scratch();
1498          Temp base;
1499          uint32_t offset;
1500          while (info.is_temp())
1501             info = ctx.info[info.temp.id()];
1502 
1503          /* The hardware probably does: 'scratch_base + u2u64(saddr) + i2i64(offset)'. This means
1504           * we can't combine the addition if the unsigned addition overflows and offset is
1505           * positive. In theory, there is also issues if
1506           * positive. In theory, there are also issues if
1507           * replaces an already out-of-bounds access with a larger one since 'saddr + offset'
1508           * would be larger than INT32_MAX.
1509           */
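          /* Illustrative example (hypothetical values): if saddr was computed as
           * 0xfffffff0 + 0x20, the original code addresses u2u64(0x10), while folding
           * the add would address u2u64(0xfffffff0) + i2i64(offset + 0x20), i.e.
           * 2^32 bytes higher. That is why the first case below asks
           * parse_base_offset() to prevent unsigned overflow. */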
1510          if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1511              base.regClass() == instr->operands[i].regClass() &&
1512              is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1513             instr->operands[i].setTemp(base);
1514             scratch.offset += (int32_t)offset;
1515             continue;
1516          } else if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1517                     base.regClass() == instr->operands[i].regClass() && (int32_t)offset < 0 &&
1518                     is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1519             instr->operands[i].setTemp(base);
1520             scratch.offset += (int32_t)offset;
1521             continue;
1522          } else if (i <= 1 && info.is_constant_or_literal(32) &&
1523                     ctx.program->gfx_level >= GFX10_3 &&
1524                     is_scratch_offset_valid(ctx, NULL, scratch.offset, (int32_t)info.val)) {
1525             /* GFX10.3+ can disable both SADDR and ADDR. */
1526             instr->operands[i] = Operand(instr->operands[i].regClass());
1527             scratch.offset += (int32_t)info.val;
1528             continue;
1529          }
1530       }
1531 
1532       /* DS: combine additions */
1533       else if (instr->isDS()) {
1534 
1535          DS_instruction& ds = instr->ds();
1536          Temp base;
1537          uint32_t offset;
1538          bool has_usable_ds_offset = ctx.program->gfx_level >= GFX7;
1539          if (has_usable_ds_offset && i == 0 &&
1540              parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1541              base.regClass() == instr->operands[i].regClass() &&
1542              instr->opcode != aco_opcode::ds_swizzle_b32) {
1543             if (instr->opcode == aco_opcode::ds_write2_b32 ||
1544                 instr->opcode == aco_opcode::ds_read2_b32 ||
1545                 instr->opcode == aco_opcode::ds_write2_b64 ||
1546                 instr->opcode == aco_opcode::ds_read2_b64 ||
1547                 instr->opcode == aco_opcode::ds_write2st64_b32 ||
1548                 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1549                 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1550                 instr->opcode == aco_opcode::ds_read2st64_b64) {
1551                bool is64bit = instr->opcode == aco_opcode::ds_write2_b64 ||
1552                               instr->opcode == aco_opcode::ds_read2_b64 ||
1553                               instr->opcode == aco_opcode::ds_write2st64_b64 ||
1554                               instr->opcode == aco_opcode::ds_read2st64_b64;
1555                bool st64 = instr->opcode == aco_opcode::ds_write2st64_b32 ||
1556                            instr->opcode == aco_opcode::ds_read2st64_b32 ||
1557                            instr->opcode == aco_opcode::ds_write2st64_b64 ||
1558                            instr->opcode == aco_opcode::ds_read2st64_b64;
1559                unsigned shifts = (is64bit ? 3 : 2) + (st64 ? 6 : 0);
1560                unsigned mask = BITFIELD_MASK(shifts);
1561 
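               /* Illustrative example: for ds_read2_b64 (8-byte elements, no st64),
                * shifts == 3, so a byte offset of 24 becomes an offset0/offset1
                * increment of 3; offsets that are not a multiple of the element
                * stride cannot be folded. */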
1562                if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1563                    ds.offset1 + (offset >> shifts) <= 255) {
1564                   instr->operands[i].setTemp(base);
1565                   ds.offset0 += offset >> shifts;
1566                   ds.offset1 += offset >> shifts;
1567                }
1568             } else {
1569                if (ds.offset0 + offset <= 65535) {
1570                   instr->operands[i].setTemp(base);
1571                   ds.offset0 += offset;
1572                }
1573             }
1574          }
1575       }
1576 
1577       else if (instr->isBranch()) {
1578          if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1579             /* Flip the branch instruction to get rid of the scc_invert instruction */
1580             instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1581                                                                      : aco_opcode::p_cbranch_z;
1582             instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1583          }
1584       }
1585    }
1586 
1587    /* if this instruction doesn't define anything, return */
1588    if (instr->definitions.empty()) {
1589       check_sdwa_extract(ctx, instr);
1590       return;
1591    }
1592 
1593    if (instr->isVALU() || instr->isVINTRP()) {
1594       if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1595           instr->opcode == aco_opcode::v_cndmask_b32) {
1596          bool canonicalized = true;
1597          if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1598             unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1599             for (unsigned i = 0; canonicalized && (i < ops); i++)
1600                canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1601          }
1602          if (canonicalized)
1603             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1604       }
1605 
1606       if (instr->isVOPC()) {
1607          ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1608          check_sdwa_extract(ctx, instr);
1609          return;
1610       }
1611       if (instr->isVOP3P()) {
1612          ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1613          return;
1614       }
1615    }
1616 
1617    switch (instr->opcode) {
1618    case aco_opcode::p_create_vector: {
1619       bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1620                        instr->operands[0].regClass() == instr->definitions[0].regClass();
1621       if (copy_prop) {
1622          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1623          break;
1624       }
1625 
1626       /* expand vector operands */
1627       std::vector<Operand> ops;
1628       unsigned offset = 0;
1629       for (const Operand& op : instr->operands) {
1630          /* ensure that any expanded operands are properly aligned */
1631          bool aligned = offset % 4 == 0 || op.bytes() < 4;
1632          offset += op.bytes();
1633          if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1634             Instruction* vec = ctx.info[op.tempId()].instr;
1635             for (const Operand& vec_op : vec->operands)
1636                ops.emplace_back(vec_op);
1637          } else {
1638             ops.emplace_back(op);
1639          }
1640       }
1641 
1642       /* combine expanded operands to new vector */
1643       if (ops.size() != instr->operands.size()) {
1644          assert(ops.size() > instr->operands.size());
1645          Definition def = instr->definitions[0];
1646          instr.reset(
1647             create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, ops.size(), 1));
1648          for (unsigned i = 0; i < ops.size(); i++) {
1649             if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1650                 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1651                ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1652             instr->operands[i] = ops[i];
1653          }
1654          instr->definitions[0] = def;
1655       } else {
1656          for (unsigned i = 0; i < ops.size(); i++) {
1657             assert(instr->operands[i] == ops[i]);
1658          }
1659       }
1660       ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1661 
1662       if (instr->operands.size() == 2) {
1663          /* check if this is created from split_vector */
1664          if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_split()) {
1665             Instruction* split = ctx.info[instr->operands[1].tempId()].instr;
1666             if (instr->operands[0].isTemp() &&
1667                 instr->operands[0].getTemp() == split->definitions[0].getTemp())
1668                ctx.info[instr->definitions[0].tempId()].set_temp(split->operands[0].getTemp());
1669          }
1670       }
1671       break;
1672    }
1673    case aco_opcode::p_split_vector: {
1674       ssa_info& info = ctx.info[instr->operands[0].tempId()];
1675 
1676       if (info.is_constant_or_literal(32)) {
1677          uint64_t val = info.val;
1678          for (Definition def : instr->definitions) {
1679             uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1680             ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
1681             val >>= def.bytes() * 8u;
1682          }
1683          break;
1684       } else if (!info.is_vec()) {
1685          if (instr->definitions.size() == 2 && instr->operands[0].isTemp() &&
1686              instr->definitions[0].bytes() == instr->definitions[1].bytes()) {
1687             ctx.info[instr->definitions[1].tempId()].set_split(instr.get());
1688             if (instr->operands[0].bytes() == 4) {
1689                /* D16 subdword split */
1690                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1691                ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1692             }
1693          }
1694          break;
1695       }
1696 
1697       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1698       unsigned split_offset = 0;
1699       unsigned vec_offset = 0;
1700       unsigned vec_index = 0;
1701       for (unsigned i = 0; i < instr->definitions.size();
1702            split_offset += instr->definitions[i++].bytes()) {
1703          while (vec_offset < split_offset && vec_index < vec->operands.size())
1704             vec_offset += vec->operands[vec_index++].bytes();
1705 
1706          if (vec_offset != split_offset ||
1707              vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1708             continue;
1709 
1710          Operand vec_op = vec->operands[vec_index];
1711          if (vec_op.isConstant()) {
1712             ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
1713                                                                   vec_op.constantValue64());
1714          } else if (vec_op.isTemp()) {
1715             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1716          }
1717       }
1718       break;
1719    }
1720    case aco_opcode::p_extract_vector: { /* mov */
1721       const unsigned index = instr->operands[1].constantValue();
1722 
1723       if (instr->operands[0].isTemp()) {
1724          ssa_info& info = ctx.info[instr->operands[0].tempId()];
1725          const unsigned dst_offset = index * instr->definitions[0].bytes();
1726 
1727          if (info.is_vec()) {
1728             /* check if we index directly into a vector element */
1729             Instruction* vec = info.instr;
1730             unsigned offset = 0;
1731 
1732             for (const Operand& op : vec->operands) {
1733                if (offset < dst_offset) {
1734                   offset += op.bytes();
1735                   continue;
1736                } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1737                   break;
1738                }
1739                instr->operands[0] = op;
1740                break;
1741             }
1742          } else if (info.is_constant_or_literal(32)) {
1743             /* propagate constants */
1744             uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1745             uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1746             instr->operands[0] =
1747                Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
1749          }
1750       }
1751 
1752       if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1753          if (instr->operands[0].size() != 1 || !instr->operands[0].isTemp())
1754             break;
1755 
1756          if (index == 0)
1757             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1758          else
1759             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1760          break;
1761       }
1762 
1763       /* convert this extract into a copy instruction */
1764       instr->opcode = aco_opcode::p_parallelcopy;
1765       instr->operands.pop_back();
1766       FALLTHROUGH;
1767    }
1768    case aco_opcode::p_parallelcopy: /* propagate */
1769       if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1770           instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1771          /* We might not be able to copy-propagate if it's an SGPR->VGPR copy, so
1772           * duplicate the vector instead.
1773           */
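         /* Illustrative sketch (hypothetical SSA names): for
          *    v2: %d = p_parallelcopy s2: %v
          * where s2: %v = p_create_vector %a, %b, we rebuild
          *    v2: %d = p_create_vector %a, %b
          * so %d is still known to be a vector. */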
1774          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1775          aco_ptr<Instruction> old_copy = std::move(instr);
1776 
1777          instr.reset(create_instruction(aco_opcode::p_create_vector, Format::PSEUDO,
1778                                         vec->operands.size(), 1));
1779          instr->definitions[0] = old_copy->definitions[0];
1780          std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1781          for (unsigned i = 0; i < vec->operands.size(); i++) {
1782             Operand& op = instr->operands[i];
1783             if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1784                 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1785                op.setTemp(ctx.info[op.tempId()].temp);
1786          }
1787          ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1788          break;
1789       }
1790       FALLTHROUGH;
1791    case aco_opcode::p_as_uniform:
1792       if (instr->definitions[0].isFixed()) {
1793          /* don't copy-propagate copies into fixed registers */
1794       } else if (instr->operands[0].isConstant()) {
1795          ctx.info[instr->definitions[0].tempId()].set_constant(
1796             ctx.program->gfx_level, instr->operands[0].constantValue64());
1797       } else if (instr->operands[0].isTemp()) {
1798          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1799          if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1800             ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1801       } else {
1802          assert(instr->operands[0].isFixed());
1803       }
1804       break;
1805    case aco_opcode::v_mov_b32:
1806       if (instr->isDPP16()) {
1807          /* anything else doesn't make sense in SSA */
1808          assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1809          ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1810       } else if (instr->isDPP8()) {
1811          ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1812       }
1813       break;
1814    case aco_opcode::p_is_helper:
1815       if (!ctx.program->needs_wqm)
1816          ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1817       break;
1818    case aco_opcode::v_mul_f64_e64:
1819    case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1820    case aco_opcode::v_mul_f16:
1821    case aco_opcode::v_mul_f32:
1822    case aco_opcode::v_mul_legacy_f32: { /* omod */
1823       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1824 
1825       /* TODO: try to move the negate/abs modifier to the consumer instead */
1826       bool uses_mods = instr->usesModifiers();
1827       bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1828 
1829       for (unsigned i = 0; i < 2; i++) {
1830          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1831             if (!instr->isDPP() && !instr->isSDWA() && !instr->valu().opsel &&
1832                 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) ||   /* 1.0 */
1833                  instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1834                bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1835 
1836                VALU_instruction* valu = &instr->valu();
1837                if (valu->abs[!i] || valu->neg[!i] || valu->omod)
1838                   continue;
1839 
1840                bool abs = valu->abs[i];
1841                bool neg = neg1 ^ valu->neg[i];
1842                Temp other = instr->operands[i].getTemp();
1843 
1844                if (valu->clamp) {
1845                   if (!abs && !neg && other.type() == RegType::vgpr)
1846                      ctx.info[other.id()].set_clamp(instr.get());
1847                   continue;
1848                }
1849 
1850                if (abs && neg && other.type() == RegType::vgpr)
1851                   ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1852                else if (abs && !neg && other.type() == RegType::vgpr)
1853                   ctx.info[instr->definitions[0].tempId()].set_abs(other);
1854                else if (!abs && neg && other.type() == RegType::vgpr)
1855                   ctx.info[instr->definitions[0].tempId()].set_neg(other);
1856                else if (!abs && !neg)
1857                   ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1858             } else if (uses_mods || (instr->definitions[0].isSZPreserve() &&
1859                                      instr->opcode != aco_opcode::v_mul_legacy_f32)) {
1860                continue; /* omod uses a legacy multiplication. */
1861             } else if (instr->operands[!i].constantValue() == 0u &&
1862                        ((!instr->definitions[0].isNaNPreserve() &&
1863                          !instr->definitions[0].isInfPreserve()) ||
1864                         instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
1865                ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1866             } else if ((fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32) != fp_denorm_flush) {
1867                /* omod has no effect if denormals are enabled. */
1868                continue;
1869             } else if (instr->operands[!i].constantValue() ==
1870                        (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1871                ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1872             } else if (instr->operands[!i].constantValue() ==
1873                        (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1874                ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1875             } else if (instr->operands[!i].constantValue() ==
1876                        (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1877                ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1878             } else {
1879                continue;
1880             }
1881             break;
1882          }
1883       }
1884       break;
1885    }
1886    case aco_opcode::v_mul_lo_u16:
1887    case aco_opcode::v_mul_lo_u16_e64:
1888    case aco_opcode::v_mul_u32_u24:
1889       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1890       break;
1891    case aco_opcode::v_med3_f16:
1892    case aco_opcode::v_med3_f32: { /* clamp */
1893       unsigned idx;
1894       if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg)
1895          ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1896       break;
1897    }
1898    case aco_opcode::v_cndmask_b32:
1899       if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0x3f800000u))
1900          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1901       else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1902          ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1903 
1904       break;
1905    case aco_opcode::v_add_u32:
1906    case aco_opcode::v_add_co_u32:
1907    case aco_opcode::v_add_co_u32_e64:
1908    case aco_opcode::s_add_i32:
1909    case aco_opcode::s_add_u32:
1910    case aco_opcode::v_subbrev_co_u32:
1911    case aco_opcode::v_sub_u32:
1912    case aco_opcode::v_sub_i32:
1913    case aco_opcode::v_sub_co_u32:
1914    case aco_opcode::v_sub_co_u32_e64:
1915    case aco_opcode::s_sub_u32:
1916    case aco_opcode::s_sub_i32:
1917    case aco_opcode::v_subrev_u32:
1918    case aco_opcode::v_subrev_co_u32:
1919    case aco_opcode::v_subrev_co_u32_e64:
1920       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1921       break;
1922    case aco_opcode::s_not_b32:
1923    case aco_opcode::s_not_b64:
1924       if (!instr->operands[0].isTemp()) {
1925       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1926          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1927          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1928             ctx.info[instr->operands[0].tempId()].temp);
1929       } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1930          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1931          ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1932             ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1933       }
1934       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1935       break;
1936    case aco_opcode::s_and_b32:
1937    case aco_opcode::s_and_b64:
1938       if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
1939          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1940             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
1941              * uniform bool into a divergent one. */
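            /* Illustrative sketch (hypothetical SSA names): the pattern
             *    s2: %div           = s_cselect_b64 -1, 0, scc
             *    s2: %res, s1: %scc = s_and_b64 %div, %exec
             * can reuse the original uniform bool for both %res and %scc. */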
1942             ctx.info[instr->definitions[1].tempId()].set_temp(
1943                ctx.info[instr->operands[0].tempId()].temp);
1944             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1945                ctx.info[instr->operands[0].tempId()].temp);
1946             break;
1947          } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1948             /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
1949              * already produces the same SCC */
1950             ctx.info[instr->definitions[1].tempId()].set_temp(
1951                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1952             ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1953                ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1954             break;
1955          } else if ((ctx.program->stage.num_sw_stages() > 1 ||
1956                      ctx.program->stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER) &&
1957                     instr->pass_flags == 1) {
1958             /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
1959              * s_and is unnecessary. */
1960             ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1961             break;
1962          }
1963       }
1964       FALLTHROUGH;
1965    case aco_opcode::s_or_b32:
1966    case aco_opcode::s_or_b64:
1967    case aco_opcode::s_xor_b32:
1968    case aco_opcode::s_xor_b64:
1969       if (std::all_of(instr->operands.begin(), instr->operands.end(),
1970                       [&ctx](const Operand& op)
1971                       {
1972                          return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
1973                                                 ctx.info[op.tempId()].is_uniform_bitwise());
1974                       })) {
1975          ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1976       }
1977       ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1978       break;
1979    case aco_opcode::s_lshl_b32:
1980    case aco_opcode::v_or_b32:
1981    case aco_opcode::v_lshlrev_b32:
1982    case aco_opcode::v_bcnt_u32_b32:
1983    case aco_opcode::v_and_b32:
1984    case aco_opcode::v_xor_b32:
1985    case aco_opcode::v_not_b32:
1986       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1987       break;
1988    case aco_opcode::v_min_f32:
1989    case aco_opcode::v_min_f16:
1990    case aco_opcode::v_min_u32:
1991    case aco_opcode::v_min_i32:
1992    case aco_opcode::v_min_u16:
1993    case aco_opcode::v_min_i16:
1994    case aco_opcode::v_min_u16_e64:
1995    case aco_opcode::v_min_i16_e64:
1996    case aco_opcode::v_max_f32:
1997    case aco_opcode::v_max_f16:
1998    case aco_opcode::v_max_u32:
1999    case aco_opcode::v_max_i32:
2000    case aco_opcode::v_max_u16:
2001    case aco_opcode::v_max_i16:
2002    case aco_opcode::v_max_u16_e64:
2003    case aco_opcode::v_max_i16_e64:
2004       ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
2005       break;
2006    case aco_opcode::s_cselect_b64:
2007    case aco_opcode::s_cselect_b32:
2008       if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
2009          /* Found a cselect that operates on a uniform bool that comes from e.g. s_cmp */
2010          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
2011       }
2012       if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
2013          /* Flip the operands to get rid of the scc_invert instruction */
2014          std::swap(instr->operands[0], instr->operands[1]);
2015          instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
2016       }
2017       break;
2018    case aco_opcode::s_mul_i32:
2019       /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
2020        * This pattern is created from a uniform nir_op_b2f. */
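      /* Illustrative example: with a 0/1 boolean operand, the result is 0x00000000
       * or 0x3f800000, i.e. the canonical encodings of +0.0f and 1.0f. */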
2021       if (instr->operands[0].constantEquals(0x3f800000u))
2022          ctx.info[instr->definitions[0].tempId()].set_canonicalized();
2023       break;
2024    case aco_opcode::p_extract: {
2025       if (instr->definitions[0].bytes() == 4 && instr->operands[0].isTemp()) {
2026          ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2027          if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
2028             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2029       }
2030       break;
2031    }
2032    case aco_opcode::p_insert: {
2033       if (instr->operands[0].bytes() == 4 && instr->operands[0].isTemp()) {
2034          if (instr->operands[0].regClass() == v1)
2035             ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2036          if (parse_extract(instr.get()))
2037             ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2038          ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2039       }
2040       break;
2041    }
2042    case aco_opcode::ds_read_u8:
2043    case aco_opcode::ds_read_u8_d16:
2044    case aco_opcode::ds_read_u16:
2045    case aco_opcode::ds_read_u16_d16: {
2046       ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2047       break;
2048    }
2049    case aco_opcode::v_cvt_f16_f32: {
2050       if (instr->operands[0].isTemp()) {
2051          ssa_info& info = ctx.info[instr->operands[0].tempId()];
2052          if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
2053             info.set_f2f16(instr.get());
2054       }
2055       break;
2056    }
2057    case aco_opcode::v_cvt_f32_f16: {
2058       if (instr->operands[0].isTemp())
2059          ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
2060       break;
2061    }
2062    default: break;
2063    }
2064 
2065    /* Don't remove label_extract just because the extract can't be applied to
2066     * neg/abs instructions: the neg/abs will likely be combined into another VALU. */
2067    if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
2068       check_sdwa_extract(ctx, instr);
2069 }
2070 
2071 unsigned
2072 original_temp_id(opt_ctx& ctx, Temp tmp)
2073 {
2074    if (ctx.info[tmp.id()].is_temp())
2075       return ctx.info[tmp.id()].temp.id();
2076    else
2077       return tmp.id();
2078 }
2079 
2080 void
2081 decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr)
2082 {
2083    if (is_dead(ctx.uses, instr)) {
2084       for (const Operand& op : instr->operands) {
2085          if (op.isTemp())
2086             ctx.uses[op.tempId()]--;
2087       }
2088    }
2089 }
2090 
2091 void
2092 decrease_uses(opt_ctx& ctx, Instruction* instr)
2093 {
2094    ctx.uses[instr->definitions[0].tempId()]--;
2095    decrease_op_uses_if_dead(ctx, instr);
2096 }
2097 
2098 Operand
2099 copy_operand(opt_ctx& ctx, Operand op)
2100 {
2101    if (op.isTemp())
2102       ctx.uses[op.tempId()]++;
2103    return op;
2104 }
2105 
2106 Instruction*
2107 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
2108 {
2109    if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
2110       return nullptr;
2111    if (!ignore_uses && ctx.uses[op.tempId()] > 1)
2112       return nullptr;
2113 
2114    Instruction* instr = ctx.info[op.tempId()].instr;
2115 
2116    if (instr->definitions.size() == 2) {
2117       unsigned idx = ctx.info[op.tempId()].label & label_split ? 1 : 0;
2118       assert(instr->definitions[idx].isTemp() && instr->definitions[idx].tempId() == op.tempId());
2119       if (instr->definitions[!idx].isTemp() && ctx.uses[instr->definitions[!idx].tempId()])
2120          return nullptr;
2121    }
2122 
2123    for (Operand& operand : instr->operands) {
2124       if (fixed_to_exec(operand))
2125          return nullptr;
2126    }
2127 
2128    return instr;
2129 }
2130 
2131 bool
2132 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2133 {
2134    if (op.isConstant()) {
2135       *value = op.constantValue64();
2136       return true;
2137    } else if (op.isTemp()) {
2138       unsigned id = original_temp_id(ctx, op.getTemp());
2139       if (!ctx.info[id].is_constant_or_literal(bit_size))
2140          return false;
2141       *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2142       return true;
2143    }
2144    return false;
2145 }
2146 
2147 /* s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
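/* e.g., assuming get_vcmp_inverse() pairs each ordered comparison with its
 * unordered inverse: s_not_b64(v_cmp_lt_f32(a, b)) -> v_cmp_nlt_f32(a, b) */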
2148 bool
2149 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2150 {
2151    if (ctx.uses[instr->definitions[1].tempId()])
2152       return false;
2153    if (!instr->operands[0].isTemp() || ctx.uses[instr->operands[0].tempId()] != 1)
2154       return false;
2155 
2156    Instruction* cmp = follow_operand(ctx, instr->operands[0]);
2157    if (!cmp)
2158       return false;
2159 
2160    aco_opcode new_opcode = get_vcmp_inverse(cmp->opcode);
2161    if (new_opcode == aco_opcode::num_opcodes)
2162       return false;
2163 
2164    /* Invert compare instruction and assign this instruction's definition */
2165    cmp->opcode = new_opcode;
2166    ctx.info[instr->definitions[0].tempId()] = ctx.info[cmp->definitions[0].tempId()];
2167    std::swap(instr->definitions[0], cmp->definitions[0]);
2168 
2169    ctx.uses[instr->operands[0].tempId()]--;
2170    return true;
2171 }
2172 
2173 /* op1(op2(1, 2), 0) if swap = false
2174  * op1(0, op2(1, 2)) if swap = true */
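/* Note on shuffle_str (see the shuffle[] setup below): digit 0 denotes op1's
 * remaining operand and digits 1/2 denote op2's two operands; the character at
 * position k selects which of these becomes operand k of the new instruction.
 * E.g. "120" turns min(min(a, b), c) into min3(a, b, c), and "210" turns
 * v_or_b32(v_lshlrev_b32(s, a), b) into v_lshl_or_b32(a, s, b). */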
2175 bool
2176 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2177                    const char* shuffle_str, Operand operands[3], bitarray8& neg, bitarray8& abs,
2178                    bitarray8& opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2179                    bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2180 {
2181    /* checks */
2182    if (op1_instr->opcode != op1)
2183       return false;
2184 
2185    Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2186    if (!op2_instr || op2_instr->opcode != op2)
2187       return false;
2188 
2189    VALU_instruction* op1_valu = op1_instr->isVALU() ? &op1_instr->valu() : NULL;
2190    VALU_instruction* op2_valu = op2_instr->isVALU() ? &op2_instr->valu() : NULL;
2191 
2192    if (op1_instr->isSDWA() || op2_instr->isSDWA())
2193       return false;
2194    if (op1_instr->isDPP() || op2_instr->isDPP())
2195       return false;
2196 
2197    /* don't support inbetween clamp/omod */
2198    if (op2_valu && (op2_valu->clamp || op2_valu->omod))
2199       return false;
2200 
2201    /* get operands and modifiers and check inbetween modifiers */
2202    *op1_clamp = op1_valu ? (bool)op1_valu->clamp : false;
2203    *op1_omod = op1_valu ? (unsigned)op1_valu->omod : 0u;
2204 
2205    if (inbetween_neg)
2206       *inbetween_neg = op1_valu ? op1_valu->neg[swap] : false;
2207    else if (op1_valu && op1_valu->neg[swap])
2208       return false;
2209 
2210    if (inbetween_abs)
2211       *inbetween_abs = op1_valu ? op1_valu->abs[swap] : false;
2212    else if (op1_valu && op1_valu->abs[swap])
2213       return false;
2214 
2215    if (inbetween_opsel)
2216       *inbetween_opsel = op1_valu ? op1_valu->opsel[swap] : false;
2217    else if (op1_valu && op1_valu->opsel[swap])
2218       return false;
2219 
2220    *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2221 
2222    int shuffle[3];
2223    shuffle[shuffle_str[0] - '0'] = 0;
2224    shuffle[shuffle_str[1] - '0'] = 1;
2225    shuffle[shuffle_str[2] - '0'] = 2;
2226 
2227    operands[shuffle[0]] = op1_instr->operands[!swap];
2228    neg[shuffle[0]] = op1_valu ? op1_valu->neg[!swap] : false;
2229    abs[shuffle[0]] = op1_valu ? op1_valu->abs[!swap] : false;
2230    opsel[shuffle[0]] = op1_valu ? op1_valu->opsel[!swap] : false;
2231 
2232    for (unsigned i = 0; i < 2; i++) {
2233       operands[shuffle[i + 1]] = op2_instr->operands[i];
2234       neg[shuffle[i + 1]] = op2_valu ? op2_valu->neg[i] : false;
2235       abs[shuffle[i + 1]] = op2_valu ? op2_valu->abs[i] : false;
2236       opsel[shuffle[i + 1]] = op2_valu ? op2_valu->opsel[i] : false;
2237    }
2238 
2239    /* check operands */
2240    if (!check_vop3_operands(ctx, 3, operands))
2241       return false;
2242 
2243    return true;
2244 }
2245 
2246 void
2247 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2248                     Operand operands[3], uint8_t neg, uint8_t abs, uint8_t opsel, bool clamp,
2249                     unsigned omod)
2250 {
2251    Instruction* new_instr = create_instruction(opcode, Format::VOP3, 3, 1);
2252    new_instr->valu().neg = neg;
2253    new_instr->valu().abs = abs;
2254    new_instr->valu().clamp = clamp;
2255    new_instr->valu().omod = omod;
2256    new_instr->valu().opsel = opsel;
2257    new_instr->operands[0] = operands[0];
2258    new_instr->operands[1] = operands[1];
2259    new_instr->operands[2] = operands[2];
2260    new_instr->definitions[0] = instr->definitions[0];
2261    new_instr->pass_flags = instr->pass_flags;
2262    ctx.info[instr->definitions[0].tempId()].label = 0;
2263 
2264    instr.reset(new_instr);
2265 }
2266 
2267 bool
2268 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2269                       const char* shuffle, uint8_t ops)
2270 {
2271    for (unsigned swap = 0; swap < 2; swap++) {
2272       if (!((1 << swap) & ops))
2273          continue;
2274 
2275       Operand operands[3];
2276       bool clamp, precise;
2277       bitarray8 neg = 0, abs = 0, opsel = 0;
2278       uint8_t omod = 0;
2279       if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2280                              abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2281          ctx.uses[instr->operands[swap].tempId()]--;
2282          create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2283          return true;
2284       }
2285    }
2286    return false;
2287 }
2288 
2289 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2290 bool
2291 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2292 {
2293    bool is_or = instr->opcode == aco_opcode::v_or_b32;
2294    aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2295 
2296    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2297                                       "120", 1 | 2))
2298       return true;
2299    if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2300                                       "120", 1 | 2))
2301       return true;
2302    if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2303       return true;
2304    if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2305       return true;
2306 
2307    if (instr->isSDWA() || instr->isDPP())
2308       return false;
2309 
2310    /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2311     * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2312     * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2313  * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_u32(a, 24/16, b)
2314     */
2315    for (unsigned i = 0; i < 2; i++) {
2316       Instruction* extins = follow_operand(ctx, instr->operands[i]);
2317       if (!extins)
2318          continue;
2319 
2320       aco_opcode op;
2321       Operand operands[3];
2322 
2323       if (extins->opcode == aco_opcode::p_insert &&
2324           (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2325          op = new_op_lshl;
2326          operands[1] =
2327             Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2328       } else if (is_or &&
2329                  (extins->opcode == aco_opcode::p_insert ||
2330                   (extins->opcode == aco_opcode::p_extract &&
2331                    extins->operands[3].constantEquals(0))) &&
2332                  extins->operands[1].constantEquals(0)) {
2333          op = aco_opcode::v_and_or_b32;
2334          operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2335       } else {
2336          continue;
2337       }
2338 
2339       operands[0] = extins->operands[0];
2340       operands[2] = instr->operands[!i];
2341 
2342       if (!check_vop3_operands(ctx, 3, operands))
2343          continue;
2344 
2345       uint8_t neg = 0, abs = 0, opsel = 0, omod = 0;
2346       bool clamp = false;
2347       if (instr->isVOP3())
2348          clamp = instr->valu().clamp;
2349 
2350       ctx.uses[instr->operands[i].tempId()]--;
2351       create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2352       return true;
2353    }
2354 
2355    return false;
2356 }
2357 
2358 /* v_xor(a, s_not(b)) -> v_xnor(a, b)
2359  * v_xor(a, v_not(b)) -> v_xnor(a, b)
2360  */
2361 bool
2362 combine_xor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2363 {
2364    if (instr->usesModifiers())
2365       return false;
2366 
2367    for (unsigned i = 0; i < 2; i++) {
2368       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
2369       if (!op_instr ||
2370           (op_instr->opcode != aco_opcode::v_not_b32 &&
2371            op_instr->opcode != aco_opcode::s_not_b32) ||
2372           op_instr->usesModifiers() || op_instr->operands[0].isLiteral())
2373          continue;
2374 
2375       instr->opcode = aco_opcode::v_xnor_b32;
2376       instr->operands[i] = copy_operand(ctx, op_instr->operands[0]);
2377       decrease_uses(ctx, op_instr);
2378       if (instr->operands[0].isOfType(RegType::vgpr))
2379          std::swap(instr->operands[0], instr->operands[1]);
2380       if (!instr->operands[1].isOfType(RegType::vgpr))
2381          instr->format = asVOP3(instr->format);
2382 
2383       return true;
2384    }
2385 
2386    return false;
2387 }
2388 
2389 /* v_not(v_xor(a, b)) -> v_xnor(a, b) */
2390 bool
2391 combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2392 {
2393    if (instr->usesModifiers())
2394       return false;
2395 
2396    Instruction* op_instr = follow_operand(ctx, instr->operands[0]);
2397    if (!op_instr || op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
2398       return false;
2399 
2400    ctx.uses[instr->operands[0].tempId()]--;
2401    std::swap(instr->definitions[0], op_instr->definitions[0]);
2402    op_instr->opcode = aco_opcode::v_xnor_b32;
2403    ctx.info[op_instr->definitions[0].tempId()].label = 0;
2404 
2405    return true;
2406 }
2407 
2408 bool
2409 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode op3src,
2410                aco_opcode minmax)
2411 {
2412    /* TODO: this can handle SDWA min/max instructions by using opsel */
2413 
2414    /* min(min(a, b), c) -> min3(a, b, c)
2415     * max(max(a, b), c) -> max3(a, b, c)
2416     * gfx11: min(-min(a, b), c) -> maxmin(-a, -b, c)
2417     * gfx11: max(-max(a, b), c) -> minmax(-a, -b, c)
2418     */
2419    for (unsigned swap = 0; swap < 2; swap++) {
2420       Operand operands[3];
2421       bool clamp, precise;
2422       bitarray8 opsel = 0, neg = 0, abs = 0;
2423       uint8_t omod = 0;
2424       bool inbetween_neg;
2425       if (match_op3_for_vop3(ctx, instr->opcode, instr->opcode, instr.get(), swap, "120", operands,
2426                              neg, abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL,
2427                              &precise) &&
2428           (!inbetween_neg ||
2429            (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2430          ctx.uses[instr->operands[swap].tempId()]--;
2431          if (inbetween_neg) {
2432             neg[0] = !neg[0];
2433             neg[1] = !neg[1];
2434             create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2435          } else {
2436             create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2437          }
2438          return true;
2439       }
2440    }
2441 
2442    /* min(-max(a, b), c) -> min3(-a, -b, c)
2443     * max(-min(a, b), c) -> max3(-a, -b, c)
2444     * gfx11: min(max(a, b), c) -> maxmin(a, b, c)
2445     * gfx11: max(min(a, b), c) -> minmax(a, b, c)
2446     */
2447    for (unsigned swap = 0; swap < 2; swap++) {
2448       Operand operands[3];
2449       bool clamp, precise;
2450       bitarray8 opsel = 0, neg = 0, abs = 0;
2451       uint8_t omod = 0;
2452       bool inbetween_neg;
2453       if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "120", operands, neg,
2454                              abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2455           (inbetween_neg ||
2456            (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2457          ctx.uses[instr->operands[swap].tempId()]--;
2458          if (inbetween_neg) {
2459             neg[0] = !neg[0];
2460             neg[1] = !neg[1];
2461             create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2462          } else {
2463             create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2464          }
2465          return true;
2466       }
2467    }
2468    return false;
2469 }
2470 
2471 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2472  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2473  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2474  * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2475  * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2476  * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2477 bool
2478 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2479 {
2480    /* checks */
2481    if (!instr->operands[0].isTemp())
2482       return false;
2483    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2484       return false;
2485 
2486    Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2487    if (!op2_instr)
2488       return false;
2489    switch (op2_instr->opcode) {
2490    case aco_opcode::s_and_b32:
2491    case aco_opcode::s_or_b32:
2492    case aco_opcode::s_xor_b32:
2493    case aco_opcode::s_and_b64:
2494    case aco_opcode::s_or_b64:
2495    case aco_opcode::s_xor_b64: break;
2496    default: return false;
2497    }
2498 
2499    /* create instruction */
2500    std::swap(instr->definitions[0], op2_instr->definitions[0]);
2501    std::swap(instr->definitions[1], op2_instr->definitions[1]);
2502    ctx.uses[instr->operands[0].tempId()]--;
2503    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2504 
2505    switch (op2_instr->opcode) {
2506    case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2507    case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2508    case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2509    case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2510    case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2511    case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2512    default: break;
2513    }
2514 
2515    return true;
2516 }
2517 
2518 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2519  * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2520  * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2521  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2522 bool
2523 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2524 {
2525    if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2526       return false;
2527 
2528    for (unsigned i = 0; i < 2; i++) {
2529       Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2530       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2531                          op2_instr->opcode != aco_opcode::s_not_b64))
2532          continue;
2533       if (ctx.uses[op2_instr->definitions[1].tempId()])
2534          continue;
2535 
2536       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2537           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2538          continue;
2539 
2540       ctx.uses[instr->operands[i].tempId()]--;
2541       instr->operands[0] = instr->operands[!i];
2542       instr->operands[1] = op2_instr->operands[0];
2543       ctx.info[instr->definitions[0].tempId()].label = 0;
2544 
2545       switch (instr->opcode) {
2546       case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2547       case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2548       case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2549       case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2550       default: break;
2551       }
2552 
2553       return true;
2554    }
2555    return false;
2556 }
2557 
2558 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
2559 bool
2560 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2561 {
2562    if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2563       return false;
2564 
2565    for (unsigned i = 0; i < 2; i++) {
2566       Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2567       if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2568           ctx.uses[op2_instr->definitions[1].tempId()])
2569          continue;
2570       if (!op2_instr->operands[1].isConstant())
2571          continue;
2572 
2573       uint32_t shift = op2_instr->operands[1].constantValue();
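           /* only s_lshl1_add_u32 .. s_lshl4_add_u32 exist, so the shift amount has to be 1..4 */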
2574       if (shift < 1 || shift > 4)
2575          continue;
2576 
2577       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2578           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2579          continue;
2580 
2581       instr->operands[1] = instr->operands[!i];
2582       instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]);
2583       decrease_uses(ctx, op2_instr);
2584       ctx.info[instr->definitions[0].tempId()].label = 0;
2585 
2586       instr->opcode = std::array<aco_opcode, 4>{
2587          aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2588          aco_opcode::s_lshl4_add_u32}[shift - 1];
2589 
2590       return true;
2591    }
2592    return false;
2593 }
2594 
2595 /* s_abs_i32(s_sub_[iu]32(a, b)) -> s_absdiff_i32(a, b)
2596  * s_abs_i32(s_add_[iu]32(a, #b)) -> s_absdiff_i32(a, -b)
2597  */
2598 bool
2599 combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2600 {
2601    if (!instr->operands[0].isTemp() || !ctx.info[instr->operands[0].tempId()].is_add_sub())
2602       return false;
2603 
2604    Instruction* op_instr = follow_operand(ctx, instr->operands[0], false);
2605    if (!op_instr)
2606       return false;
2607 
2608    if (op_instr->opcode == aco_opcode::s_add_i32 || op_instr->opcode == aco_opcode::s_add_u32) {
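           /* rewrite the add with a negated constant so it matches the difference pattern:
            * |a + b| == |a - (-b)|, then the opcode is switched to s_absdiff_i32 below */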
2609       for (unsigned i = 0; i < 2; i++) {
2610          uint64_t constant;
2611          if (op_instr->operands[!i].isLiteral() ||
2612              !is_operand_constant(ctx, op_instr->operands[i], 32, &constant))
2613             continue;
2614 
2615          if (op_instr->operands[i].isTemp())
2616             ctx.uses[op_instr->operands[i].tempId()]--;
2617          op_instr->operands[0] = op_instr->operands[!i];
2618          op_instr->operands[1] = Operand::c32(-int32_t(constant));
2619          goto use_absdiff;
2620       }
2621       return false;
2622    }
2623 
2624 use_absdiff:
2625    op_instr->opcode = aco_opcode::s_absdiff_i32;
2626    std::swap(instr->definitions[0], op_instr->definitions[0]);
2627    std::swap(instr->definitions[1], op_instr->definitions[1]);
2628    ctx.uses[instr->operands[0].tempId()]--;
2629    ctx.info[op_instr->definitions[0].tempId()].label = 0;
2630 
2631    return true;
2632 }
2633 
2634 bool
2635 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
2636 {
2637    if (instr->usesModifiers())
2638       return false;
2639 
2640    for (unsigned i = 0; i < 2; i++) {
2641       if (!((1 << i) & ops))
2642          continue;
2643       if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
2644           ctx.uses[instr->operands[i].tempId()] == 1) {
2645 
2646          aco_ptr<Instruction> new_instr;
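              /* the VOP2 form needs a VGPR as src1; otherwise fall back to VOP3, which can take
               * inline constants anywhere but cannot encode a literal before GFX10 */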
2647          if (instr->operands[!i].isTemp() &&
2648              instr->operands[!i].getTemp().type() == RegType::vgpr) {
2649             new_instr.reset(create_instruction(new_op, Format::VOP2, 3, 2));
2650          } else if (ctx.program->gfx_level >= GFX10 ||
2651                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2652             new_instr.reset(create_instruction(new_op, asVOP3(Format::VOP2), 3, 2));
2653          } else {
2654             return false;
2655          }
2656          ctx.uses[instr->operands[i].tempId()]--;
2657          new_instr->definitions[0] = instr->definitions[0];
2658          if (instr->definitions.size() == 2) {
2659             new_instr->definitions[1] = instr->definitions[1];
2660          } else {
2661             new_instr->definitions[1] =
2662                Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
2663             /* Make sure the uses vector is large enough and the number of
2664              * uses properly initialized to 0.
2665              */
2666             ctx.uses.push_back(0);
2667             ctx.info.push_back(ssa_info{});
2668          }
2669          new_instr->operands[0] = Operand::zero();
2670          new_instr->operands[1] = instr->operands[!i];
2671          new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
2672          new_instr->pass_flags = instr->pass_flags;
2673          instr = std::move(new_instr);
2674          ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2675          return true;
2676       }
2677    }
2678 
2679    return false;
2680 }
2681 
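     /* v_add(a, v_bcnt_u32_b32(b, 0)) -> v_bcnt_u32_b32(b, a), since v_bcnt computes popcount(src0) + src1 */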
2682 bool
2683 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2684 {
2685    if (instr->usesModifiers())
2686       return false;
2687 
2688    for (unsigned i = 0; i < 2; i++) {
2689       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2690       if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2691           !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2692           op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2693           op_instr->operands[1].constantEquals(0)) {
2694          aco_ptr<Instruction> new_instr{
2695             create_instruction(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2696          ctx.uses[instr->operands[i].tempId()]--;
2697          new_instr->operands[0] = op_instr->operands[0];
2698          new_instr->operands[1] = instr->operands[!i];
2699          new_instr->definitions[0] = instr->definitions[0];
2700          new_instr->pass_flags = instr->pass_flags;
2701          instr = std::move(new_instr);
2702          ctx.info[instr->definitions[0].tempId()].label = 0;
2703 
2704          return true;
2705       }
2706    }
2707 
2708    return false;
2709 }
2710 
2711 bool
2712 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2713                 aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
2714 {
2715    switch (op) {
2716 #define MINMAX(type, gfx9)                                                                         \
2717    case aco_opcode::v_min_##type:                                                                  \
2718    case aco_opcode::v_max_##type:                                                                  \
2719       *min = aco_opcode::v_min_##type;                                                             \
2720       *max = aco_opcode::v_max_##type;                                                             \
2721       *med3 = aco_opcode::v_med3_##type;                                                           \
2722       *min3 = aco_opcode::v_min3_##type;                                                           \
2723       *max3 = aco_opcode::v_max3_##type;                                                           \
2724       *minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type;            \
2725       *some_gfx9_only = gfx9;                                                                      \
2726       return true;
2727 #define MINMAX_INT16(type, gfx9)                                                                   \
2728    case aco_opcode::v_min_##type:                                                                  \
2729    case aco_opcode::v_max_##type:                                                                  \
2730       *min = aco_opcode::v_min_##type;                                                             \
2731       *max = aco_opcode::v_max_##type;                                                             \
2732       *med3 = aco_opcode::v_med3_##type;                                                           \
2733       *min3 = aco_opcode::v_min3_##type;                                                           \
2734       *max3 = aco_opcode::v_max3_##type;                                                           \
2735       *minmax = aco_opcode::num_opcodes;                                                           \
2736       *some_gfx9_only = gfx9;                                                                      \
2737       return true;
2738 #define MINMAX_INT16_E64(type, gfx9)                                                               \
2739    case aco_opcode::v_min_##type##_e64:                                                            \
2740    case aco_opcode::v_max_##type##_e64:                                                            \
2741       *min = aco_opcode::v_min_##type##_e64;                                                       \
2742       *max = aco_opcode::v_max_##type##_e64;                                                       \
2743       *med3 = aco_opcode::v_med3_##type;                                                           \
2744       *min3 = aco_opcode::v_min3_##type;                                                           \
2745       *max3 = aco_opcode::v_max3_##type;                                                           \
2746       *minmax = aco_opcode::num_opcodes;                                                           \
2747       *some_gfx9_only = gfx9;                                                                      \
2748       return true;
2749       MINMAX(f32, false)
2750       MINMAX(u32, false)
2751       MINMAX(i32, false)
2752       MINMAX(f16, true)
2753       MINMAX_INT16(u16, true)
2754       MINMAX_INT16(i16, true)
2755       MINMAX_INT16_E64(u16, true)
2756       MINMAX_INT16_E64(i16, true)
2757 #undef MINMAX_INT16_E64
2758 #undef MINMAX_INT16
2759 #undef MINMAX
2760    default: return false;
2761    }
2762 }
2763 
2764 /* when ub > lb:
2765  * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2766  * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2767  */
2768 bool
2769 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
2770               aco_opcode med)
2771 {
2772    /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
2773     * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
2774     * minVal > maxVal, which means we can always select a v_med3_f32 for it */
2775    aco_opcode other_op;
2776    if (instr->opcode == min)
2777       other_op = max;
2778    else if (instr->opcode == max)
2779       other_op = min;
2780    else
2781       return false;
2782 
2783    for (unsigned swap = 0; swap < 2; swap++) {
2784       Operand operands[3];
2785       bool clamp, precise;
2786       bitarray8 opsel = 0, neg = 0, abs = 0;
2787       uint8_t omod = 0;
2788       if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
2789                              abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2790          /* max(min(src, upper), lower) returns upper if src is NaN, but
2791           * med3(src, lower, upper) returns lower.
2792           */
2793          if (precise && instr->opcode != min &&
2794              (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
2795             continue;
2796 
2797          int const0_idx = -1, const1_idx = -1;
2798          uint32_t const0 = 0, const1 = 0;
2799          for (int i = 0; i < 3; i++) {
2800             uint32_t val;
2801             bool hi16 = opsel & (1 << i);
2802             if (operands[i].isConstant()) {
2803                val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
2804             } else if (operands[i].isTemp() &&
2805                        ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
2806                val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
2807             } else {
2808                continue;
2809             }
2810             if (const0_idx >= 0) {
2811                const1_idx = i;
2812                const1 = val;
2813             } else {
2814                const0_idx = i;
2815                const0 = val;
2816             }
2817          }
2818          if (const0_idx < 0 || const1_idx < 0)
2819             continue;
2820 
2821          int lower_idx = const0_idx;
2822          switch (min) {
2823          case aco_opcode::v_min_f32:
2824          case aco_opcode::v_min_f16: {
2825             float const0_f, const1_f;
2826             if (min == aco_opcode::v_min_f32) {
2827                memcpy(&const0_f, &const0, 4);
2828                memcpy(&const1_f, &const1, 4);
2829             } else {
2830                const0_f = _mesa_half_to_float(const0);
2831                const1_f = _mesa_half_to_float(const1);
2832             }
2833             if (abs[const0_idx])
2834                const0_f = fabsf(const0_f);
2835             if (abs[const1_idx])
2836                const1_f = fabsf(const1_f);
2837             if (neg[const0_idx])
2838                const0_f = -const0_f;
2839             if (neg[const1_idx])
2840                const1_f = -const1_f;
2841             lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
2842             break;
2843          }
2844          case aco_opcode::v_min_u32: {
2845             lower_idx = const0 < const1 ? const0_idx : const1_idx;
2846             break;
2847          }
2848          case aco_opcode::v_min_u16:
2849          case aco_opcode::v_min_u16_e64: {
2850             lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
2851             break;
2852          }
2853          case aco_opcode::v_min_i32: {
2854             int32_t const0_i =
2855                const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
2856             int32_t const1_i =
2857                const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
2858             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2859             break;
2860          }
2861          case aco_opcode::v_min_i16:
2862          case aco_opcode::v_min_i16_e64: {
2863             int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
2864             int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
2865             lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2866             break;
2867          }
2868          default: break;
2869          }
2870          int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
2871 
2872          if (instr->opcode == min) {
2873             if (upper_idx != 0 || lower_idx == 0)
2874                return false;
2875          } else {
2876             if (upper_idx == 0 || lower_idx != 0)
2877                return false;
2878          }
2879 
2880          ctx.uses[instr->operands[swap].tempId()]--;
2881          create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
2882 
2883          return true;
2884       }
2885    }
2886 
2887    return false;
2888 }
2889 
2890 void
2891 apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2892 {
2893    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
2894                      instr->opcode == aco_opcode::v_lshlrev_b64 ||
2895                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
2896                      instr->opcode == aco_opcode::v_ashrrev_i64;
2897 
2898    /* find candidates and create the set of sgprs already read */
2899    unsigned sgpr_ids[2] = {0, 0};
2900    uint32_t operand_mask = 0;
2901    bool has_literal = false;
2902    for (unsigned i = 0; i < instr->operands.size(); i++) {
2903       if (instr->operands[i].isLiteral())
2904          has_literal = true;
2905       if (!instr->operands[i].isTemp())
2906          continue;
2907       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
2908          if (instr->operands[i].tempId() != sgpr_ids[0])
2909             sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
2910       }
2911       ssa_info& info = ctx.info[instr->operands[i].tempId()];
2912       if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr)
2913          operand_mask |= 1u << i;
2914       if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
2915          operand_mask |= 1u << i;
2916    }
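        /* at most one scalar (SGPR/literal) operand is allowed before GFX10, two on GFX10+
         * (except for 64-bit shifts); an already-present literal takes up one of those slots */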
2917    unsigned max_sgprs = 1;
2918    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
2919       max_sgprs = 2;
2920    if (has_literal)
2921       max_sgprs--;
2922 
2923    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
2924 
2925    /* keep on applying sgprs until there is nothing left to be done */
2926    while (operand_mask) {
2927       uint32_t sgpr_idx = 0;
2928       uint32_t sgpr_info_id = 0;
2929       uint32_t mask = operand_mask;
2930       /* choose a sgpr */
2931       while (mask) {
2932          unsigned i = u_bit_scan(&mask);
2933          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
2934          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
2935             sgpr_idx = i;
2936             sgpr_info_id = instr->operands[i].tempId();
2937          }
2938       }
2939       operand_mask &= ~(1u << sgpr_idx);
2940 
2941       ssa_info& info = ctx.info[sgpr_info_id];
2942 
2943    /* Applying two sgprs requires making it VOP3, so don't do it unless it's
2944        * definitively beneficial.
2945        * TODO: this is too conservative because later the use count could be reduced to 1 */
2946       if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
2947           !instr->isSDWA() && instr->format != Format::VOP3P)
2948          break;
2949 
2950       Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
2951       bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
2952       if (new_sgpr && num_sgprs >= max_sgprs)
2953          continue;
2954 
2955       if (sgpr_idx == 0)
2956          instr->format = withoutDPP(instr->format);
2957 
2958       if (sgpr_idx == 1 && instr->isDPP())
2959          continue;
2960 
2961       if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
2962           info.is_extract()) {
2963          /* can_apply_extract() checks SGPR encoding restrictions */
2964          if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
2965             apply_extract(ctx, instr, sgpr_idx, info);
2966          else if (info.is_extract())
2967             continue;
2968          instr->operands[sgpr_idx] = Operand(sgpr);
2969       } else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) {
2970          instr->operands[sgpr_idx] = instr->operands[0];
2971          instr->operands[0] = Operand(sgpr);
2972          instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]);
2973          /* swap bits using a 4-entry LUT */
2974          uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
2975          operand_mask = (operand_mask & ~0x3) | swapped;
2976       } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
2977          instr->format = asVOP3(instr->format);
2978          instr->operands[sgpr_idx] = Operand(sgpr);
2979       } else {
2980          continue;
2981       }
2982 
2983       if (new_sgpr)
2984          sgpr_ids[num_sgprs++] = sgpr.id();
2985       ctx.uses[sgpr_info_id]--;
2986       ctx.uses[sgpr.id()]++;
2987 
2988       /* TODO: handle when it's a VGPR */
2989       if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
2990           ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
2991          operand_mask |= 1u << sgpr_idx;
2992    }
2993 }
2994 
2995 bool
2996 interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2997 {
2998    if (instr->opcode != aco_opcode::v_interp_p2_f32_inreg)
2999       return false;
3000 
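        /* temporarily pretend the instruction is a v_fma_f32 (VOP3) just to query can_use_DPP(),
         * then restore the original opcode and format */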
3001    instr->opcode = aco_opcode::v_fma_f32;
3002    instr->format = Format::VOP3;
3003    bool dpp_allowed = can_use_DPP(ctx.program->gfx_level, instr, false);
3004    instr->opcode = aco_opcode::v_interp_p2_f32_inreg;
3005    instr->format = Format::VINTERP_INREG;
3006 
3007    return dpp_allowed;
3008 }
3009 
3010 void
3011 interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction>& instr)
3012 {
3013    static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction),
3014                  "Invalid instr cast.");
3015    instr->format = asVOP3(Format::DPP16);
3016    instr->opcode = aco_opcode::v_fma_f32;
3017    instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2);
3018    instr->dpp16().row_mask = 0xf;
3019    instr->dpp16().bank_mask = 0xf;
3020    instr->dpp16().bound_ctrl = 0;
3021    instr->dpp16().fetch_inactive = 1;
3022 }
3023 
3024 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
3025 bool
3026 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3027 {
3028    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3029        !instr_info.can_use_output_modifiers[(int)instr->opcode])
3030       return false;
3031 
3032    bool can_vop3 = can_use_VOP3(ctx, instr);
3033    bool is_mad_mix =
3034       instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3035    bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix;
3036    if (needs_vop3 && !can_vop3)
3037       return false;
3038 
3039    /* SDWA omod is GFX9+. */
3040    bool can_use_omod = (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() &&
3041                        (!instr->isVINTERP_INREG() || interp_can_become_fma(ctx, instr));
3042 
3043    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3044 
3045    uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3046    if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3047       return false;
3048    /* if the omod/clamp instruction is dead, then the single user of this
3049     * instruction is a different instruction */
3050    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3051       return false;
3052 
3053    if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3054       return false;
3055 
3056    /* MADs/FMAs are created later, so we don't have to update the original add */
3057    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3058 
3059    if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
3060       return false;
3061 
3062    if (needs_vop3)
3063       instr->format = asVOP3(instr->format);
3064 
3065    if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
3066       interp_p2_f32_inreg_to_fma_dpp(instr);
3067 
3068    if (def_info.is_omod2())
3069       instr->valu().omod = 1;
3070    else if (def_info.is_omod4())
3071       instr->valu().omod = 2;
3072    else if (def_info.is_omod5())
3073       instr->valu().omod = 3;
3074    else if (def_info.is_clamp())
3075       instr->valu().clamp = true;
3076 
3077    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3078    ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3079    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3080 
3081    return true;
3082 }
3083 
3084 /* Combine a p_insert (or p_extract, in some cases) instruction with instr.
3085  * p_insert(instr(...)) -> instr_insert().
3086  */
3087 bool
3088 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3089 {
3090    if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3091       return false;
3092 
3093    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3094    if (!def_info.is_insert())
3095       return false;
3096    /* if the insert instruction is dead, then the single user of this
3097     * instruction is a different instruction */
3098    if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3099       return false;
3100 
3101    /* MADs/FMAs are created later, so we don't have to update the original add */
3102    assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3103 
3104    SubdwordSel sel = parse_insert(def_info.instr);
3105    assert(sel);
3106 
3107    if (!can_use_SDWA(ctx.program->gfx_level, instr, true))
3108       return false;
3109 
3110    convert_to_SDWA(ctx.program->gfx_level, instr);
3111    if (instr->sdwa().dst_sel.size() != 4)
3112       return false;
3113    instr->sdwa().dst_sel = sel;
3114 
3115    instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3116    ctx.info[instr->definitions[0].tempId()].label = 0;
3117    ctx.uses[def_info.instr->definitions[0].tempId()]--;
3118 
3119    return true;
3120 }
3121 
3122 /* Remove superfluous extract after ds_read like so:
3123  * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3124  */
3125 bool
3126 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3127 {
3128    /* Check if p_extract has a usedef operand and is the only user. */
3129    if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3130        ctx.uses[extract->operands[0].tempId()] > 1)
3131       return false;
3132 
3133    /* Check if the usedef is a DS instruction. */
3134    Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3135    if (ds->format != Format::DS)
3136       return false;
3137 
3138    unsigned extract_idx = extract->operands[1].constantValue();
3139    unsigned bits_extracted = extract->operands[2].constantValue();
3140    unsigned sign_ext = extract->operands[3].constantValue();
3141    unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3142 
3143    /* TODO: These are doable, but probably don't occur too often. */
3144    if (extract_idx || sign_ext || dst_bitsize != 32)
3145       return false;
3146 
3147    unsigned bits_loaded = 0;
3148    if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3149       bits_loaded = 8;
3150    else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3151       bits_loaded = 16;
3152    else
3153       return false;
3154 
3155    /* Shrink the DS load if the extracted bit size is smaller. */
3156    bits_loaded = MIN2(bits_loaded, bits_extracted);
3157 
3158    /* Change the DS opcode so it writes the full register. */
3159    if (bits_loaded == 8)
3160       ds->opcode = aco_opcode::ds_read_u8;
3161    else if (bits_loaded == 16)
3162       ds->opcode = aco_opcode::ds_read_u16;
3163    else
3164       unreachable("Forgot to add DS opcode above.");
3165 
3166    /* The DS now produces the exact same thing as the extract, so remove the extract. */
3167    std::swap(ds->definitions[0], extract->definitions[0]);
3168    ctx.uses[extract->definitions[0].tempId()] = 0;
3169    ctx.info[ds->definitions[0].tempId()].label = 0;
3170    return true;
3171 }
3172 
3173 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
3174 bool
3175 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3176 {
3177    if (instr->usesModifiers())
3178       return false;
3179 
3180    for (unsigned i = 0; i < 2; i++) {
3181       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3182       if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3183           op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3184           !op_instr->usesModifiers()) {
3185 
3186          aco_ptr<Instruction> new_instr;
3187          if (instr->operands[!i].isTemp() &&
3188              instr->operands[!i].getTemp().type() == RegType::vgpr) {
3189             new_instr.reset(create_instruction(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3190          } else if (ctx.program->gfx_level >= GFX10 ||
3191                     (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3192             new_instr.reset(
3193                create_instruction(aco_opcode::v_cndmask_b32, asVOP3(Format::VOP2), 3, 1));
3194          } else {
3195             return false;
3196          }
3197 
3198          new_instr->operands[0] = Operand::zero();
3199          new_instr->operands[1] = instr->operands[!i];
3200          new_instr->operands[2] = copy_operand(ctx, op_instr->operands[2]);
3201          new_instr->definitions[0] = instr->definitions[0];
3202          new_instr->pass_flags = instr->pass_flags;
3203          instr = std::move(new_instr);
3204          decrease_uses(ctx, op_instr);
3205          ctx.info[instr->definitions[0].tempId()].label = 0;
3206          return true;
3207       }
3208    }
3209 
3210    return false;
3211 }
3212 
3213 /* v_and(a, not(b)) -> v_bfi_b32(b, 0, a)
3214  * v_or(a, not(b)) -> v_bfi_b32(b, a, -1)
3215  */
3216 bool
3217 combine_v_andor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3218 {
3219    if (instr->usesModifiers())
3220       return false;
3221 
3222    for (unsigned i = 0; i < 2; i++) {
3223       Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3224       if (op_instr && !op_instr->usesModifiers() &&
3225           (op_instr->opcode == aco_opcode::v_not_b32 ||
3226            op_instr->opcode == aco_opcode::s_not_b32)) {
3227 
3228          Operand ops[3] = {
3229             op_instr->operands[0],
3230             Operand::zero(),
3231             instr->operands[!i],
3232          };
3233          if (instr->opcode == aco_opcode::v_or_b32) {
3234             ops[1] = instr->operands[!i];
3235             ops[2] = Operand::c32(-1);
3236          }
3237          if (!check_vop3_operands(ctx, 3, ops))
3238             continue;
3239 
3240          Instruction* new_instr = create_instruction(aco_opcode::v_bfi_b32, Format::VOP3, 3, 1);
3241 
3242          if (op_instr->operands[0].isTemp())
3243             ctx.uses[op_instr->operands[0].tempId()]++;
3244          for (unsigned j = 0; j < 3; j++)
3245             new_instr->operands[j] = ops[j];
3246          new_instr->definitions[0] = instr->definitions[0];
3247          new_instr->pass_flags = instr->pass_flags;
3248          instr.reset(new_instr);
3249          decrease_uses(ctx, op_instr);
3250          ctx.info[instr->definitions[0].tempId()].label = 0;
3251          return true;
3252       }
3253    }
3254 
3255    return false;
3256 }
3257 
3258 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3259  * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3260  * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3261  * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3262  */
3263 bool
3264 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3265 {
3266    if (instr->usesModifiers())
3267       return false;
3268 
3269    /* Subtractions: start at operand 1 to avoid a mix-up such as
3270     * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3271     */
3272    unsigned start_op_idx = is_sub ? 1 : 0;
3273 
3274    /* Don't allow 24-bit operands on subtraction because
3275     * v_mad_i32_i24 applies a sign extension.
3276     */
3277    bool allow_24bit = !is_sub;
3278 
3279    for (unsigned i = start_op_idx; i < 2; i++) {
3280       Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3281       if (!op_instr)
3282          continue;
3283 
3284       if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3285           op_instr->opcode != aco_opcode::v_lshlrev_b32)
3286          continue;
3287 
3288       int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3289 
3290       if (op_instr->operands[shift_op_idx].isConstant() &&
3291           ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3292            op_instr->operands[!shift_op_idx].is16bit())) {
3293          uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3294          if (is_sub)
3295             multiplier = -multiplier;
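              /* v_mad_u32_u24/v_mad_i32_i24 only read 24 bits of the multiplier, so the
               * (possibly negated) constant has to fit into u24/i24 */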
3296          if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3297             continue;
3298 
3299          Operand ops[3] = {
3300             op_instr->operands[!shift_op_idx],
3301             Operand::c32(multiplier),
3302             instr->operands[!i],
3303          };
3304          if (!check_vop3_operands(ctx, 3, ops))
3305             return false;
3306 
3307          ctx.uses[instr->operands[i].tempId()]--;
3308 
3309          aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3310          aco_ptr<Instruction> new_instr{create_instruction(mad_op, Format::VOP3, 3, 1)};
3311          for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3312             new_instr->operands[op_idx] = ops[op_idx];
3313          new_instr->definitions[0] = instr->definitions[0];
3314          new_instr->pass_flags = instr->pass_flags;
3315          instr = std::move(new_instr);
3316          ctx.info[instr->definitions[0].tempId()].label = 0;
3317          return true;
3318       }
3319    }
3320 
3321    return false;
3322 }
3323 
3324 void
3325 propagate_swizzles(VALU_instruction* instr, bool opsel_lo, bool opsel_hi)
3326 {
3327    /* propagate swizzles which apply to a result down to the instruction's operands:
3328     * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3329    uint8_t tmp_lo = instr->opsel_lo;
3330    uint8_t tmp_hi = instr->opsel_hi;
3331    uint8_t neg_lo = instr->neg_lo;
3332    uint8_t neg_hi = instr->neg_hi;
3333    if (opsel_lo == 1) {
3334       instr->opsel_lo = tmp_hi;
3335       instr->neg_lo = neg_hi;
3336    }
3337    if (opsel_hi == 0) {
3338       instr->opsel_hi = tmp_lo;
3339       instr->neg_hi = neg_lo;
3340    }
3341 }
3342 
3343 void
3344 combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3345 {
3346    VALU_instruction* vop3p = &instr->valu();
3347 
3348    /* apply clamp */
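        /* v_pk_mul_f16(x, 1.0) with clamp is just a packed clamp of x: if x comes from another
         * VOP3P instruction that allows output modifiers, set clamp there and make this multiply dead */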
3349    if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
3350        vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1 &&
3351        !vop3p->opsel_lo[1] && !vop3p->opsel_hi[1]) {
3352 
3353       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3354       if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
3355          VALU_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->valu();
3356          candidate->clamp = true;
3357          propagate_swizzles(candidate, vop3p->opsel_lo[0], vop3p->opsel_hi[0]);
3358          instr->definitions[0].swapTemp(candidate->definitions[0]);
3359          ctx.info[candidate->definitions[0].tempId()].instr = candidate;
3360          ctx.uses[instr->definitions[0].tempId()]--;
3361          return;
3362       }
3363    }
3364 
3365    /* check for fneg modifiers */
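        /* a v_pk_mul_f16 by 1.0 acts as a packed fneg/swizzle: fold its neg/opsel bits into this
         * instruction's input modifiers instead of keeping the multiply */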
3366    for (unsigned i = 0; i < instr->operands.size(); i++) {
3367       if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
3368          continue;
3369       Operand& op = instr->operands[i];
3370       if (!op.isTemp())
3371          continue;
3372 
3373       ssa_info& info = ctx.info[op.tempId()];
3374       if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
3375           (info.instr->operands[0].constantEquals(0x3C00) ||
3376            info.instr->operands[1].constantEquals(0x3C00))) {
3377 
3378          VALU_instruction* fneg = &info.instr->valu();
3379 
3380          unsigned fneg_src = fneg->operands[0].constantEquals(0x3C00);
3381 
3382          if (fneg->opsel_lo[1 - fneg_src] || fneg->opsel_hi[1 - fneg_src])
3383             continue;
3384 
3385          Operand ops[3];
3386          for (unsigned j = 0; j < instr->operands.size(); j++)
3387             ops[j] = instr->operands[j];
3388          ops[i] = fneg->operands[fneg_src];
3389          if (!check_vop3_operands(ctx, instr->operands.size(), ops))
3390             continue;
3391 
3392          if (fneg->clamp)
3393             continue;
3394          instr->operands[i] = fneg->operands[fneg_src];
3395 
3396          /* opsel_lo/hi is either 0 or 1:
3397           * if 0 - pick selection from fneg->lo
3398           * if 1 - pick selection from fneg->hi
3399           */
3400          bool opsel_lo = vop3p->opsel_lo[i];
3401          bool opsel_hi = vop3p->opsel_hi[i];
3402          bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
3403          bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
3404          vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
3405          vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
3406          vop3p->opsel_lo[i] ^= opsel_lo ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
3407          vop3p->opsel_hi[i] ^= opsel_hi ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
3408 
3409          if (--ctx.uses[fneg->definitions[0].tempId()])
3410             ctx.uses[fneg->operands[fneg_src].tempId()]++;
3411       }
3412    }
3413 
3414    if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
3415       bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
3416       if (fadd && instr->definitions[0].isPrecise())
3417          return;
3418       if (!fadd && instr->valu().clamp)
3419          return;
3420 
3421       Instruction* mul_instr = nullptr;
3422       unsigned add_op_idx = 0;
3423       bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
3424       uint32_t uses = UINT32_MAX;
3425 
3426       /* find the 'best' mul instruction to combine with the add */
3427       for (unsigned i = 0; i < 2; i++) {
3428          Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3429          if (!op_instr)
3430             continue;
3431 
3432          if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
3433             if (fadd) {
3434                if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
3435                    op_instr->definitions[0].isPrecise())
3436                   continue;
3437             } else {
3438                if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
3439                   continue;
3440             }
3441 
3442             Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
3443             if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3444                continue;
3445 
3446             /* no clamp allowed between mul and add */
3447             if (op_instr->valu().clamp)
3448                continue;
3449 
3450             mul_instr = op_instr;
3451             add_op_idx = 1 - i;
3452             uses = ctx.uses[instr->operands[i].tempId()];
3453             mul_neg_lo = mul_instr->valu().neg_lo;
3454             mul_neg_hi = mul_instr->valu().neg_hi;
3455             mul_opsel_lo = mul_instr->valu().opsel_lo;
3456             mul_opsel_hi = mul_instr->valu().opsel_hi;
3457          } else if (instr->operands[i].bytes() == 2) {
3458             if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
3459                           op_instr->definitions[0].isPrecise())) ||
3460                 (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
3461                  op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
3462                continue;
3463 
3464             if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
3465                continue;
3466 
3467             if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
3468                                                              op_instr->sdwa().sel[1].size() < 2)))
3469                continue;
3470 
3471             Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
3472             if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
3473                continue;
3474 
3475             mul_instr = op_instr;
3476             add_op_idx = 1 - i;
3477             uses = ctx.uses[instr->operands[i].tempId()];
3478             mul_neg_lo = mul_instr->valu().neg;
3479             mul_neg_hi = mul_instr->valu().neg;
3480             if (mul_instr->isSDWA()) {
3481                for (unsigned j = 0; j < 2; j++)
3482                   mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
3483             } else {
3484                mul_opsel_lo = mul_instr->valu().opsel;
3485             }
3486             mul_opsel_hi = mul_opsel_lo;
3487          }
3488       }
3489 
3490       if (!mul_instr)
3491          return;
3492 
3493       /* turn mul + packed add into v_pk_fma_f16 */
3494       aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
3495       aco_ptr<Instruction> fma{create_instruction(mad, Format::VOP3P, 3, 1)};
3496       fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
3497       fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
3498       fma->operands[2] = instr->operands[add_op_idx];
3499       fma->valu().clamp = vop3p->clamp;
3500       fma->valu().neg_lo = mul_neg_lo;
3501       fma->valu().neg_hi = mul_neg_hi;
3502       fma->valu().opsel_lo = mul_opsel_lo;
3503       fma->valu().opsel_hi = mul_opsel_hi;
3504       propagate_swizzles(&fma->valu(), vop3p->opsel_lo[1 - add_op_idx],
3505                          vop3p->opsel_hi[1 - add_op_idx]);
3506       fma->valu().opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
3507       fma->valu().opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
3508       fma->valu().neg_lo[2] = vop3p->neg_lo[add_op_idx];
3509       fma->valu().neg_hi[2] = vop3p->neg_hi[add_op_idx];
3510       fma->valu().neg_lo[1] = fma->valu().neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
3511       fma->valu().neg_hi[1] = fma->valu().neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
3512       fma->definitions[0] = instr->definitions[0];
3513       fma->pass_flags = instr->pass_flags;
3514       instr = std::move(fma);
3515       ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
3516       decrease_uses(ctx, mul_instr);
3517       return;
3518    }
3519 }
3520 
3521 bool
3522 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3523 {
3524    if (ctx.program->gfx_level < GFX9)
3525       return false;
3526 
3527    /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */
3528    if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
3529       return false;
3530 
3531    if (instr->valu().omod)
3532       return false;
3533 
3534    switch (instr->opcode) {
3535    case aco_opcode::v_add_f32:
3536    case aco_opcode::v_sub_f32:
3537    case aco_opcode::v_subrev_f32:
3538    case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
3539    case aco_opcode::v_fma_f32:
3540       return ctx.program->dev.fused_mad_mix || !instr->definitions[0].isPrecise();
3541    case aco_opcode::v_fma_mix_f32:
3542    case aco_opcode::v_fma_mixlo_f16: return true;
3543    default: return false;
3544    }
3545 }
3546 
3547 void
3548 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3549 {
3550    ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3551 
3552    if (instr->opcode == aco_opcode::v_fma_f32) {
3553       instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P));
3554       instr->opcode = aco_opcode::v_fma_mix_f32;
3555       return;
3556    }
3557 
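        /* v_fma_mix_f32 computes src0 * src1 + src2: adds/subs become 1.0 * a +/- b, and a
         * multiply becomes a * b + (-0.0) (adding -0.0 instead of +0.0 so a negative-zero
         * result is not turned into +0.0) */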
3558    bool is_add = instr->opcode != aco_opcode::v_mul_f32;
3559 
3560    aco_ptr<Instruction> vop3p{create_instruction(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3561 
3562    for (unsigned i = 0; i < instr->operands.size(); i++) {
3563       vop3p->operands[is_add + i] = instr->operands[i];
3564       vop3p->valu().neg_lo[is_add + i] = instr->valu().neg[i];
3565       vop3p->valu().neg_hi[is_add + i] = instr->valu().abs[i];
3566    }
3567    if (instr->opcode == aco_opcode::v_mul_f32) {
3568       vop3p->operands[2] = Operand::zero();
3569       vop3p->valu().neg_lo[2] = true;
3570    } else if (is_add) {
3571       vop3p->operands[0] = Operand::c32(0x3f800000);
3572       if (instr->opcode == aco_opcode::v_sub_f32)
3573          vop3p->valu().neg_lo[2] ^= true;
3574       else if (instr->opcode == aco_opcode::v_subrev_f32)
3575          vop3p->valu().neg_lo[1] ^= true;
3576    }
3577    vop3p->definitions[0] = instr->definitions[0];
3578    vop3p->valu().clamp = instr->valu().clamp;
3579    vop3p->pass_flags = instr->pass_flags;
3580    instr = std::move(vop3p);
3581 
3582    if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
3583       ctx.info[instr->definitions[0].tempId()].instr = instr.get();
3584 }
3585 
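     /* If the only user of this instruction's result is a float->f16 conversion, fold the
      * conversion in by turning the instruction into v_fma_mixlo_f16. */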
3586 bool
3587 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3588 {
3589    ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3590    if (!def_info.is_f2f16())
3591       return false;
3592    Instruction* conv = def_info.instr;
3593 
3594    if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1)
3595       return false;
3596 
3597    if (conv->usesModifiers())
3598       return false;
3599 
3600    if (interp_can_become_fma(ctx, instr))
3601       interp_p2_f32_inreg_to_fma_dpp(instr);
3602 
3603    if (!can_use_mad_mix(ctx, instr))
3604       return false;
3605 
3606    if (!instr->isVOP3P())
3607       to_mad_mix(ctx, instr);
3608 
3609    instr->opcode = aco_opcode::v_fma_mixlo_f16;
3610    instr->definitions[0].swapTemp(conv->definitions[0]);
3611    if (conv->definitions[0].isPrecise())
3612       instr->definitions[0].setPrecise(true);
3613    ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
3614    ctx.uses[conv->definitions[0].tempId()]--;
3615 
3616    return true;
3617 }
3618 
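     /* Fold f16->f32 conversions of the operands into the instruction itself by converting it
      * to v_fma_mix_f32 and selecting the 16-bit sources via opsel_hi/opsel_lo. */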
3619 void
3620 combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3621 {
3622    if (!can_use_mad_mix(ctx, instr))
3623       return;
3624 
3625    for (unsigned i = 0; i < instr->operands.size(); i++) {
3626       if (!instr->operands[i].isTemp())
3627          continue;
3628       Temp tmp = instr->operands[i].getTemp();
3629       if (!ctx.info[tmp.id()].is_f2f32())
3630          continue;
3631 
3632       Instruction* conv = ctx.info[tmp.id()].instr;
3633       if (conv->valu().clamp || conv->valu().omod) {
3634          continue;
3635       } else if (conv->isSDWA() &&
3636                  (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2)) {
3637          continue;
3638       } else if (conv->isDPP()) {
3639          continue;
3640       }
3641 
3642       if (get_operand_size(instr, i) != 32)
3643          continue;
3644 
3645       /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
3646        * check_vop3_operands(). */
3647       Operand op[3];
3648       for (unsigned j = 0; j < instr->operands.size(); j++)
3649          op[j] = instr->operands[j];
3650       op[i] = conv->operands[0];
3651       if (!check_vop3_operands(ctx, instr->operands.size(), op))
3652          continue;
3653       if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP())
3654          continue;
3655 
3656       if (!instr->isVOP3P()) {
3657          bool is_add =
3658             instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
3659          to_mad_mix(ctx, instr);
3660          i += is_add;
3661       }
3662 
3663       if (--ctx.uses[tmp.id()])
3664          ctx.uses[conv->operands[0].tempId()]++;
3665       instr->operands[i].setTemp(conv->operands[0].getTemp());
3666       if (conv->definitions[0].isPrecise())
3667          instr->definitions[0].setPrecise(true);
3668       instr->valu().opsel_hi[i] = true;
3669       if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
3670          instr->valu().opsel_lo[i] = true;
3671       else
3672          instr->valu().opsel_lo[i] = conv->valu().opsel[0];
3673       bool neg = conv->valu().neg[0];
3674       bool abs = conv->valu().abs[0];
3675       if (!instr->valu().abs[i]) {
3676          instr->valu().neg[i] ^= neg;
3677          instr->valu().abs[i] = abs;
3678       }
3679    }
3680 }
3681 
3682 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
3683 // this would mean that we'd have to fix the instruction uses during value propagation
3684 
3685 /* also returns true for inf */
3686 bool
3687 is_pow_of_two(opt_ctx& ctx, Operand op)
3688 {
3689    if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
3690       return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
3691    else if (!op.isConstant())
3692       return false;
3693 
3694    uint64_t val = op.constantValue64();
3695 
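        /* a zero mantissa with a biased exponent of at least the bias means the value is 2^n
         * with n >= 0 (or inf); powers of two below 1.0 are not accepted here (see the denormal
         * note at the call site in combine_instruction) */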
3696    if (op.bytes() == 4) {
3697       uint32_t exponent = (val & 0x7f800000) >> 23;
3698       uint32_t fraction = val & 0x007fffff;
3699       return (exponent >= 127) && (fraction == 0);
3700    } else if (op.bytes() == 2) {
3701       uint32_t exponent = (val & 0x7c00) >> 10;
3702       uint32_t fraction = val & 0x03ff;
3703       return (exponent >= 15) && (fraction == 0);
3704    } else {
3705       assert(op.bytes() == 8);
3706       uint64_t exponent = (val & UINT64_C(0x7ff0000000000000)) >> 52;
3707       uint64_t fraction = val & UINT64_C(0x000fffffffffffff);
3708       return (exponent >= 1023) && (fraction == 0);
3709    }
3710 }
3711 
3712 void
3713 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3714 {
3715    if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
3716       return;
3717 
3718    if (instr->isVALU() || instr->isSALU()) {
3719       /* Apply SDWA. Do this after label_instruction() so it can remove
3720        * label_extract if not all instructions can take SDWA. */
3721       for (unsigned i = 0; i < instr->operands.size(); i++) {
3722          Operand& op = instr->operands[i];
3723          if (!op.isTemp())
3724             continue;
3725          ssa_info& info = ctx.info[op.tempId()];
3726          if (!info.is_extract())
3727             continue;
3728          /* if there are that many uses, there are likely better combinations */
3729          // TODO: delay applying extract to a point where we know better
3730          if (ctx.uses[op.tempId()] > 4) {
3731             info.label &= ~label_extract;
3732             continue;
3733          }
3734          if (info.is_extract() &&
3735              (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
3736               instr->operands[i].getTemp().type() == RegType::sgpr) &&
3737              can_apply_extract(ctx, instr, i, info)) {
3738             /* Increase use count of the extract's operand if the extract still has uses. */
3739             apply_extract(ctx, instr, i, info);
3740             if (--ctx.uses[instr->operands[i].tempId()])
3741                ctx.uses[info.instr->operands[0].tempId()]++;
3742             instr->operands[i].setTemp(info.instr->operands[0].getTemp());
3743          }
3744       }
3745    }
3746 
3747    if (instr->isVALU()) {
3748       if (can_apply_sgprs(ctx, instr))
3749          apply_sgprs(ctx, instr);
3750       combine_mad_mix(ctx, instr);
3751       while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr))
3752          ;
3753       apply_insert(ctx, instr);
3754    }
3755 
3756    if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
3757        instr->opcode != aco_opcode::v_fma_mixlo_f16)
3758       return combine_vop3p(ctx, instr);
3759 
3760    if (instr->isSDWA() || instr->isDPP())
3761       return;
3762 
3763    if (instr->opcode == aco_opcode::p_extract) {
3764       ssa_info& info = ctx.info[instr->operands[0].tempId()];
3765       if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
3766          apply_extract(ctx, instr, 0, info);
3767          if (--ctx.uses[instr->operands[0].tempId()])
3768             ctx.uses[info.instr->operands[0].tempId()]++;
3769          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
3770       }
3771 
3772       apply_ds_extract(ctx, instr);
3773    }
3774 
3775    /* TODO: There are still some peephole optimizations that could be done:
3776     * - abs(a - b) -> s_absdiff_i32
3777     * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
3778     * - patterns for v_alignbit_b32 and v_alignbyte_b32
3779     * These probably aren't too interesting though.
3780     * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
3781     * probably more useful than the previously mentioned optimizations.
3782     * The various comparison optimizations also currently only work with 32-bit
3783     * floats. */
3784 
3785    /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
3786    if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
3787        ctx.uses[instr->operands[1].tempId()] == 1) {
3788       Temp val = ctx.info[instr->definitions[0].tempId()].temp;
3789 
3790       if (!ctx.info[val.id()].is_mul())
3791          return;
3792 
3793       Instruction* mul_instr = ctx.info[val.id()].instr;
3794 
3795       if (mul_instr->operands[0].isLiteral())
3796          return;
3797       if (mul_instr->valu().clamp)
3798          return;
3799       if (mul_instr->isSDWA() || mul_instr->isDPP())
3800          return;
3801       if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
3802           mul_instr->definitions[0].isSZPreserve())
3803          return;
3804       if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
3805          return;
3806 
3807       /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
3808       ctx.uses[mul_instr->definitions[0].tempId()]--;
3809       Definition def = instr->definitions[0];
3810       bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
3811       bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
3812       uint32_t pass_flags = instr->pass_flags;
3813       Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
3814       instr.reset(create_instruction(mul_instr->opcode, format, mul_instr->operands.size(), 1));
3815       std::copy(mul_instr->operands.cbegin(), mul_instr->operands.cend(), instr->operands.begin());
3816       instr->pass_flags = pass_flags;
3817       instr->definitions[0] = def;
3818       VALU_instruction& new_mul = instr->valu();
3819       VALU_instruction& mul = mul_instr->valu();
3820       new_mul.neg = mul.neg;
3821       new_mul.abs = mul.abs;
3822       new_mul.omod = mul.omod;
3823       new_mul.opsel = mul.opsel;
3824       new_mul.opsel_lo = mul.opsel_lo;
3825       new_mul.opsel_hi = mul.opsel_hi;
3826       if (is_abs) {
3827          new_mul.neg[0] = new_mul.neg[1] = false;
3828          new_mul.abs[0] = new_mul.abs[1] = true;
3829       }
3830       new_mul.neg[0] ^= is_neg;
3831       new_mul.clamp = false;
3832 
3833       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
3834       return;
3835    }
3836 
3837    /* combine mul+add -> mad */
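   /* Illustrative sketch (pseudo-IR, not exact ACO syntax), assuming the precision,
    * denormal-mode and operand checks below pass:
    *    %t = v_mul_f32 %a, %b
    *    %r = v_add_f32 %t, %c
    * is fused into
    *    %r = v_mad_f32 %a, %b, %c      (or v_fma_f32, depending on has_mad/has_fma)
    * The original add is kept in ctx.mad_infos so select_instruction() can still undo
    * the fusion if the multiply turns out to keep other uses.
    */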
3838    bool is_add_mix =
3839       (instr->opcode == aco_opcode::v_fma_mix_f32 ||
3840        instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
3841       !instr->valu().neg_lo[0] &&
3842       ((instr->operands[0].constantEquals(0x3f800000) && !instr->valu().opsel_hi[0]) ||
3843        (instr->operands[0].constantEquals(0x3C00) && instr->valu().opsel_hi[0] &&
3844         !instr->valu().opsel_lo[0]));
3845    bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
3846                 instr->opcode == aco_opcode::v_subrev_f32;
3847    bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
3848                 instr->opcode == aco_opcode::v_subrev_f16;
3849    bool mad64 =
3850       instr->opcode == aco_opcode::v_add_f64_e64 || instr->opcode == aco_opcode::v_add_f64;
3851    if (is_add_mix || mad16 || mad32 || mad64) {
3852       Instruction* mul_instr = nullptr;
3853       unsigned add_op_idx = 0;
3854       uint32_t uses = UINT32_MAX;
3855       bool emit_fma = false;
3856       /* find the 'best' mul instruction to combine with the add */
3857       for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
3858          if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
3859             continue;
3860          ssa_info& info = ctx.info[instr->operands[i].tempId()];
3861 
3862          /* no clamp/omod allowed between mul and add */
3863          if (info.instr->isVOP3() && (info.instr->valu().clamp || info.instr->valu().omod))
3864             continue;
3865          if (info.instr->isVOP3P() && info.instr->valu().clamp)
3866             continue;
3867          /* v_fma_mix_f32/etc can't do omod */
3868          if (info.instr->isVOP3P() && instr->isVOP3() && instr->valu().omod)
3869             continue;
3870          /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
3871          if (is_add_mix && info.instr->definitions[0].bytes() == 2)
3872             continue;
3873 
3874          if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
3875             continue;
3876 
3877          bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
3878          bool mad_mix = is_add_mix || info.instr->isVOP3P();
3879 
3880          /* Multiplication by power-of-two should never need rounding. 1/power-of-two also works,
3881           * but using fma removes denormal flushing (0xfffffe * 0.5 + 0x810001a2).
3882           */
3883          bool is_fma_precise = is_pow_of_two(ctx, info.instr->operands[0]) ||
3884                                is_pow_of_two(ctx, info.instr->operands[1]);
3885 
3886          bool has_fma = mad16 || mad64 || (legacy && ctx.program->gfx_level >= GFX10_3) ||
3887                         (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
3888                         (mad_mix && ctx.program->dev.fused_mad_mix);
3889          bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
3890                                 : ((mad32 && ctx.program->gfx_level < GFX10_3) ||
3891                                    (mad16 && ctx.program->gfx_level <= GFX9));
3892          bool can_use_fma =
3893             has_fma &&
3894             (!(info.instr->definitions[0].isPrecise() || instr->definitions[0].isPrecise()) ||
3895              is_fma_precise);
3896          bool can_use_mad =
3897             has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
3898          if (mad_mix && legacy)
3899             continue;
3900          if (!can_use_fma && !can_use_mad)
3901             continue;
3902 
3903          unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
3904          Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
3905                           instr->operands[candidate_add_op_idx]};
3906          if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
3907              ctx.uses[instr->operands[i].tempId()] > uses)
3908             continue;
3909 
3910          if (ctx.uses[instr->operands[i].tempId()] == uses) {
3911             unsigned cur_idx = mul_instr->definitions[0].tempId();
3912             unsigned new_idx = info.instr->definitions[0].tempId();
3913             if (cur_idx > new_idx)
3914                continue;
3915          }
3916 
3917          mul_instr = info.instr;
3918          add_op_idx = candidate_add_op_idx;
3919          uses = ctx.uses[instr->operands[i].tempId()];
3920          emit_fma = !can_use_mad;
3921       }
3922 
3923       if (mul_instr) {
3924          /* turn mul+add into v_mad/v_fma */
3925          Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
3926                           instr->operands[add_op_idx]};
3927          ctx.uses[mul_instr->definitions[0].tempId()]--;
3928          if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3929             if (op[0].isTemp())
3930                ctx.uses[op[0].tempId()]++;
3931             if (op[1].isTemp())
3932                ctx.uses[op[1].tempId()]++;
3933          }
3934 
3935          bool neg[3] = {false, false, false};
3936          bool abs[3] = {false, false, false};
3937          unsigned omod = 0;
3938          bool clamp = false;
3939          bitarray8 opsel_lo = 0;
3940          bitarray8 opsel_hi = 0;
3941          bitarray8 opsel = 0;
3942          unsigned mul_op_idx = (instr->isVOP3P() ? 3 : 1) - add_op_idx;
3943 
3944          VALU_instruction& valu_mul = mul_instr->valu();
3945          neg[0] = valu_mul.neg[0];
3946          neg[1] = valu_mul.neg[1];
3947          abs[0] = valu_mul.abs[0];
3948          abs[1] = valu_mul.abs[1];
3949          opsel_lo = valu_mul.opsel_lo & 0x3;
3950          opsel_hi = valu_mul.opsel_hi & 0x3;
3951          opsel = valu_mul.opsel & 0x3;
3952 
3953          VALU_instruction& valu = instr->valu();
3954          neg[2] = valu.neg[add_op_idx];
3955          abs[2] = valu.abs[add_op_idx];
3956          opsel_lo[2] = valu.opsel_lo[add_op_idx];
3957          opsel_hi[2] = valu.opsel_hi[add_op_idx];
3958          opsel[2] = valu.opsel[add_op_idx];
3959          opsel[3] = valu.opsel[3];
3960          omod = valu.omod;
3961          clamp = valu.clamp;
3962          /* abs of the multiplication result */
3963          if (valu.abs[mul_op_idx]) {
3964             neg[0] = false;
3965             neg[1] = false;
3966             abs[0] = true;
3967             abs[1] = true;
3968          }
3969          /* neg of the multiplication result */
3970          neg[1] ^= valu.neg[mul_op_idx];
3971 
3972          if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
3973             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
3974          else if (instr->opcode == aco_opcode::v_subrev_f32 ||
3975                   instr->opcode == aco_opcode::v_subrev_f16)
3976             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
3977 
3978          aco_ptr<Instruction> add_instr = std::move(instr);
3979          aco_ptr<Instruction> mad;
3980          if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
3981             assert(!omod);
3982             assert(!opsel);
3983 
3984             aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
3985                                                                        : aco_opcode::v_fma_mix_f32;
3986             mad.reset(create_instruction(mad_op, Format::VOP3P, 3, 1));
3987          } else {
3988             assert(!opsel_lo);
3989             assert(!opsel_hi);
3990 
3991             aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
3992             if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
3993                assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
3994                mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
3995             } else if (mad16) {
3996                mad_op = emit_fma ? (ctx.program->gfx_level == GFX8 ? aco_opcode::v_fma_legacy_f16
3997                                                                    : aco_opcode::v_fma_f16)
3998                                  : (ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_f16
3999                                                                    : aco_opcode::v_mad_f16);
4000             } else if (mad64) {
4001                mad_op = aco_opcode::v_fma_f64;
4002             }
4003 
4004             mad.reset(create_instruction(mad_op, Format::VOP3, 3, 1));
4005          }
4006 
4007          for (unsigned i = 0; i < 3; i++) {
4008             mad->operands[i] = op[i];
4009             mad->valu().neg[i] = neg[i];
4010             mad->valu().abs[i] = abs[i];
4011          }
4012          mad->valu().omod = omod;
4013          mad->valu().clamp = clamp;
4014          mad->valu().opsel_lo = opsel_lo;
4015          mad->valu().opsel_hi = opsel_hi;
4016          mad->valu().opsel = opsel;
4017          mad->definitions[0] = add_instr->definitions[0];
4018          mad->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
4019                                         mul_instr->definitions[0].isPrecise());
4020          mad->pass_flags = add_instr->pass_flags;
4021 
4022          instr = std::move(mad);
4023 
4024          /* mark this ssa_def to be re-checked for profitability and literals */
4025          ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
4026          ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4027          return;
4028       }
4029    }
4030    /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
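   /* Illustrative sketch (pseudo-IR, not exact ACO syntax): with %b labeled as
    * b2f(%cond), i.e. %b = v_cndmask_b32 0, 1.0, %cond, the multiply
    *    %r = v_mul_f32 %b, %a
    * is replaced below by
    *    %r = v_cndmask_b32 0, %a, %cond
    */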
4031    else if (((instr->opcode == aco_opcode::v_mul_f32 && !instr->definitions[0].isNaNPreserve() &&
4032               !instr->definitions[0].isInfPreserve()) ||
4033              (instr->opcode == aco_opcode::v_mul_legacy_f32 &&
4034               !instr->definitions[0].isSZPreserve())) &&
4035             !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
4036       for (unsigned i = 0; i < 2; i++) {
4037          if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
4038              ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
4039              instr->operands[!i].getTemp().type() == RegType::vgpr) {
4040             ctx.uses[instr->operands[i].tempId()]--;
4041             ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
4042 
4043             aco_ptr<Instruction> new_instr{
4044                create_instruction(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
4045             new_instr->operands[0] = Operand::zero();
4046             new_instr->operands[1] = instr->operands[!i];
4047             new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
4048             new_instr->definitions[0] = instr->definitions[0];
4049             new_instr->pass_flags = instr->pass_flags;
4050             instr = std::move(new_instr);
4051             ctx.info[instr->definitions[0].tempId()].label = 0;
4052             return;
4053          }
4054       }
4055    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4056       if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4057                                 1 | 2)) {
4058       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4059                                        "012", 1 | 2)) {
4060       } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4061       } else if (combine_v_andor_not(ctx, instr)) {
4062       }
4063    } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
4064       if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
4065                                 1 | 2)) {
4066       } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
4067                                        "012", 1 | 2)) {
4068       } else if (combine_xor_not(ctx, instr)) {
4069       }
4070    } else if (instr->opcode == aco_opcode::v_not_b32 && ctx.program->gfx_level >= GFX10) {
4071       combine_not_xor(ctx, instr);
4072    } else if (instr->opcode == aco_opcode::v_add_u16 && !instr->valu().clamp) {
4073       combine_three_valu_op(
4074          ctx, instr, aco_opcode::v_mul_lo_u16,
4075          ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
4076          "120", 1 | 2);
4077    } else if (instr->opcode == aco_opcode::v_add_u16_e64 && !instr->valu().clamp) {
4078       combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
4079                             1 | 2);
4080    } else if (instr->opcode == aco_opcode::v_add_u32 && !instr->usesModifiers()) {
4081       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4082       } else if (combine_add_bcnt(ctx, instr)) {
4083       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4084                                        aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4085       } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_i32_i24,
4086                                        aco_opcode::v_mad_i32_i24, "120", 1 | 2)) {
4087       } else if (ctx.program->gfx_level >= GFX9) {
4088          if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
4089                                    1 | 2)) {
4090          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
4091                                           "120", 1 | 2)) {
4092          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
4093                                           "012", 1 | 2)) {
4094          } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4095                                           "012", 1 | 2)) {
4096          } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4097                                           "012", 1 | 2)) {
4098          } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4099          }
4100       }
4101    } else if ((instr->opcode == aco_opcode::v_add_co_u32 ||
4102                instr->opcode == aco_opcode::v_add_co_u32_e64) &&
4103               !instr->usesModifiers()) {
4104       bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4105       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4106       } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4107       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4108                                                      aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4109       } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_i32_i24,
4110                                                      aco_opcode::v_mad_i32_i24, "120", 1 | 2)) {
4111       } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4112       }
4113    } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4114               instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4115       bool carry_out =
4116          instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4117       if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4118       } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4119       }
4120    } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4121               instr->opcode == aco_opcode::v_subrev_co_u32 ||
4122               instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4123       combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4124    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->gfx_level >= GFX9) {
4125       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4126                             2);
4127    } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4128               ctx.program->gfx_level >= GFX9) {
4129       combine_salu_lshl_add(ctx, instr);
4130    } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4131       if (!combine_salu_not_bitwise(ctx, instr))
4132          combine_inverse_comparison(ctx, instr);
4133    } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4134               instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4135       combine_salu_n2(ctx, instr);
4136    } else if (instr->opcode == aco_opcode::s_abs_i32) {
4137       combine_sabsdiff(ctx, instr);
4138    } else if (instr->opcode == aco_opcode::v_and_b32) {
4139       if (combine_and_subbrev(ctx, instr)) {
4140       } else if (combine_v_andor_not(ctx, instr)) {
4141       }
4142    } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4143       /* set existing v_fma_f32 with label_mad so we can create v_fmamk_f32/v_fmaak_f32.
4144        * since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
4145        * select_instruction() using mad_info::add_instr.
4146        */
4147       ctx.mad_infos.emplace_back(nullptr, 0);
4148       ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4149    } else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) {
4150       /* Optimize v_med3 to v_add so that it can be dual issued on GFX11. We start with v_med3 in
4151        * case omod can be applied.
4152        */
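      /* Illustrative sketch (pseudo-IR, not exact ACO syntax): when detect_clamp()
       * recognizes e.g. %r = v_med3_f32 0, 1.0, %x, it is rewritten as
       *    %r = v_add_f32 %x, 0 clamp
       * which keeps the clamp semantics and can be dual issued on GFX11.
       */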
4153       unsigned idx;
4154       if (detect_clamp(instr.get(), &idx)) {
4155          instr->format = asVOP3(Format::VOP2);
4156          instr->operands[0] = instr->operands[idx];
4157          instr->operands[1] = Operand::zero();
4158          instr->opcode =
4159             instr->opcode == aco_opcode::v_med3_f32 ? aco_opcode::v_add_f32 : aco_opcode::v_add_f16;
4160          instr->valu().clamp = true;
4161          instr->valu().abs = (uint8_t)instr->valu().abs[idx];
4162          instr->valu().neg = (uint8_t)instr->valu().neg[idx];
4163          instr->operands.pop_back();
4164       }
4165    } else {
4166       aco_opcode min, max, min3, max3, med3, minmax;
4167       bool some_gfx9_only;
4168       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
4169                           &some_gfx9_only) &&
4170           (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
4171          if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4172                             instr->opcode == min ? min3 : max3, minmax)) {
4173          } else {
4174             combine_clamp(ctx, instr, min, max, med3);
4175          }
4176       }
4177    }
4178 }
4179 
4180 struct remat_entry {
4181    Instruction* instr;
4182    uint32_t block;
4183 };
4184 
4185 inline bool
4186 is_constant(Instruction* instr)
4187 {
4188    if (instr->opcode != aco_opcode::p_parallelcopy || instr->operands.size() != 1)
4189       return false;
4190 
4191    return instr->operands[0].isConstant() && instr->definitions[0].isTemp();
4192 }
4193 
4194 void
4195 remat_constants_instr(opt_ctx& ctx, aco::map<Temp, remat_entry>& constants, Instruction* instr,
4196                       uint32_t block_idx)
4197 {
4198    for (Operand& op : instr->operands) {
4199       if (!op.isTemp())
4200          continue;
4201 
4202       auto it = constants.find(op.getTemp());
4203       if (it == constants.end())
4204          continue;
4205 
4206       /* Check if we already emitted the same constant in this block. */
4207       if (it->second.block != block_idx) {
4208          /* Rematerialize the constant. */
4209          Builder bld(ctx.program, &ctx.instructions);
4210          Operand const_op = it->second.instr->operands[0];
4211          it->second.instr = bld.copy(bld.def(op.regClass()), const_op);
4212          it->second.block = block_idx;
4213          ctx.uses.push_back(0);
4214          ctx.info.push_back(ctx.info[op.tempId()]);
4215       }
4216 
4217       /* Use the rematerialized constant and update information about latest use. */
4218       if (op.getTemp() != it->second.instr->definitions[0].getTemp()) {
4219          ctx.uses[op.tempId()]--;
4220          op.setTemp(it->second.instr->definitions[0].getTemp());
4221          ctx.uses[op.tempId()]++;
4222       }
4223    }
4224 }
4225 
4226 /**
4227  * This pass implements a simple constant rematerialization.
4228  * As common subexpression elimination (CSE) might increase the live-ranges
4229  * of loaded constants over large distances, this pass splits the live-ranges
4230  * again by re-emitting constants in every basic block.
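 *
 * Illustrative example (pseudo-IR, not exact ACO syntax): if block 0 contains
 *    %c = p_parallelcopy 0x12345678
 * and %c is also used in block 5, a fresh
 *    %c2 = p_parallelcopy 0x12345678
 * is emitted in block 5 and that use is rewritten to %c2, so %c no longer has to
 * stay live across the intervening blocks.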
4231  */
4232 void
4233 rematerialize_constants(opt_ctx& ctx)
4234 {
4235    aco::monotonic_buffer_resource memory(1024);
4236    aco::map<Temp, remat_entry> constants(memory);
4237 
4238    for (Block& block : ctx.program->blocks) {
4239       if (block.logical_idom == -1)
4240          continue;
4241 
4242       if (block.logical_idom == (int)block.index)
4243          constants.clear();
4244 
4245       ctx.instructions.reserve(block.instructions.size());
4246 
4247       for (aco_ptr<Instruction>& instr : block.instructions) {
4248          if (is_dead(ctx.uses, instr.get()))
4249             continue;
4250 
4251          if (is_constant(instr.get())) {
4252             Temp tmp = instr->definitions[0].getTemp();
4253             constants[tmp] = {instr.get(), block.index};
4254          } else if (!is_phi(instr)) {
4255             remat_constants_instr(ctx, constants, instr.get(), block.index);
4256          }
4257 
4258          ctx.instructions.emplace_back(instr.release());
4259       }
4260 
4261       block.instructions = std::move(ctx.instructions);
4262    }
4263 }
4264 
4265 bool
4266 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4267 {
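   /* Illustrative sketch (pseudo-IR, not exact ACO syntax): a bitwise boolean op like
    *    s2: %r, s1: %scc = s_and_b64 %a, %b
    * whose operands are known uniform booleans is rewritten into the 32-bit form
    *    s1: %r, s1: %scc = s_and_b32 %a32, %b32
    * where %a32/%b32 are the uniform boolean temps (or the SCC definitions of the
    * producing instructions) found via ctx.info.
    */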
4268    /* Check every operand to make sure they are suitable. */
4269    for (Operand& op : instr->operands) {
4270       if (!op.isTemp())
4271          return false;
4272       if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4273          return false;
4274    }
4275 
4276    switch (instr->opcode) {
4277    case aco_opcode::s_and_b32:
4278    case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4279    case aco_opcode::s_or_b32:
4280    case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4281    case aco_opcode::s_xor_b32:
4282    case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
4283    default:
4284       /* Don't transform other instructions. They are very unlikely to appear here. */
4285       return false;
4286    }
4287 
4288    for (Operand& op : instr->operands) {
4289       ctx.uses[op.tempId()]--;
4290 
4291       if (ctx.info[op.tempId()].is_uniform_bool()) {
4292          /* Just use the uniform boolean temp. */
4293          op.setTemp(ctx.info[op.tempId()].temp);
4294       } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4295          /* Use the SCC definition of the predecessor instruction.
4296           * This allows the predecessor to get picked up by the same optimization (if it has no
4297           * divergent users), and it also makes sure that the current instruction will keep working
4298           * even if the predecessor won't be transformed.
4299           */
4300          Instruction* pred_instr = ctx.info[op.tempId()].instr;
4301          assert(pred_instr->definitions.size() >= 2);
4302          assert(pred_instr->definitions[1].isFixed() &&
4303                 pred_instr->definitions[1].physReg() == scc);
4304          op.setTemp(pred_instr->definitions[1].getTemp());
4305       } else {
4306          unreachable("Invalid operand on uniform bitwise instruction.");
4307       }
4308 
4309       ctx.uses[op.tempId()]++;
4310    }
4311 
4312    instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4313    ctx.program->temp_rc[instr->definitions[0].tempId()] = s1;
4314    assert(instr->operands[0].regClass() == s1);
4315    assert(instr->operands[1].regClass() == s1);
4316    return true;
4317 }
4318 
4319 void
4320 select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4321 {
4322    const uint32_t threshold = 4;
4323 
4324    if (is_dead(ctx.uses, instr.get())) {
4325       instr.reset();
4326       return;
4327    }
4328 
4329    /* convert split_vector into a copy or extract_vector if only one definition is ever used */
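   /* Illustrative sketch (pseudo-IR, not exact ACO syntax), assuming only %hi is
    * still used:
    *    %lo, %hi = p_split_vector %v        ; %v is 2 dwords
    * becomes
    *    %hi = p_extract_vector %v, 1
    * or, if %v comes from a single-use p_create_vector, a plain p_parallelcopy of the
    * matching operand.
    */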
4330    if (instr->opcode == aco_opcode::p_split_vector) {
4331       unsigned num_used = 0;
4332       unsigned idx = 0;
4333       unsigned split_offset = 0;
4334       for (unsigned i = 0, offset = 0; i < instr->definitions.size();
4335            offset += instr->definitions[i++].bytes()) {
4336          if (ctx.uses[instr->definitions[i].tempId()]) {
4337             num_used++;
4338             idx = i;
4339             split_offset = offset;
4340          }
4341       }
4342       bool done = false;
4343       if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
4344           ctx.uses[instr->operands[0].tempId()] == 1) {
4345          Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
4346 
4347          unsigned off = 0;
4348          Operand op;
4349          for (Operand& vec_op : vec->operands) {
4350             if (off == split_offset) {
4351                op = vec_op;
4352                break;
4353             }
4354             off += vec_op.bytes();
4355          }
4356          if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
4357             ctx.uses[instr->operands[0].tempId()]--;
4358             for (Operand& vec_op : vec->operands) {
4359                if (vec_op.isTemp())
4360                   ctx.uses[vec_op.tempId()]--;
4361             }
4362             if (op.isTemp())
4363                ctx.uses[op.tempId()]++;
4364 
4365             aco_ptr<Instruction> copy{
4366                create_instruction(aco_opcode::p_parallelcopy, Format::PSEUDO, 1, 1)};
4367             copy->operands[0] = op;
4368             copy->definitions[0] = instr->definitions[idx];
4369             instr = std::move(copy);
4370 
4371             done = true;
4372          }
4373       }
4374 
4375       if (!done && num_used == 1 &&
4376           instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
4377           split_offset % instr->definitions[idx].bytes() == 0) {
4378          aco_ptr<Instruction> extract{
4379             create_instruction(aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
4380          extract->operands[0] = instr->operands[0];
4381          extract->operands[1] =
4382             Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
4383          extract->definitions[0] = instr->definitions[idx];
4384          instr = std::move(extract);
4385       }
4386    }
4387 
4388    mad_info* mad_info = NULL;
4389    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4390       mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
4391       /* re-check mad instructions */
4392       if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
4393          ctx.uses[mad_info->mul_temp_id]++;
4394          if (instr->operands[0].isTemp())
4395             ctx.uses[instr->operands[0].tempId()]--;
4396          if (instr->operands[1].isTemp())
4397             ctx.uses[instr->operands[1].tempId()]--;
4398          instr.swap(mad_info->add_instr);
4399          mad_info = NULL;
4400       }
4401       /* check literals */
4402       else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 &&
4403                instr->opcode != aco_opcode::v_mad_legacy_f32 &&
4404                instr->opcode != aco_opcode::v_fma_legacy_f32) {
4405          /* FMA can only take literals on GFX10+ */
4406          if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
4407              ctx.program->gfx_level < GFX10)
4408             return;
4409          /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16, and on chips where VOP3 can take
4410           * literals (GFX10+), v_fma_legacy_f16 doesn't exist.
4411           */
4412          if (instr->opcode == aco_opcode::v_fma_legacy_f16)
4413             return;
4414 
4415          uint32_t literal_mask = 0;
4416          uint32_t fp16_mask = 0;
4417          uint32_t sgpr_mask = 0;
4418          uint32_t vgpr_mask = 0;
4419          uint32_t literal_uses = UINT32_MAX;
4420          uint32_t literal_value = 0;
4421 
4422          /* Iterate in reverse to prefer v_madak/v_fmaak. */
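         /* (A literal in src2 becomes the v_madak/v_fmaak addend, while a literal in
          * src0/src1 maps to v_madmk/v_fmamk, so scanning src2 first makes literal_mask
          * prefer the madak/fmaak form when both are possible.) */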
4423          for (int i = 2; i >= 0; i--) {
4424             Operand& op = instr->operands[i];
4425             if (!op.isTemp())
4426                continue;
4427             if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) {
4428                uint32_t new_literal = ctx.info[op.tempId()].val;
4429                float value = uif(new_literal);
4430                uint16_t fp16_val = _mesa_float_to_half(value);
4431                bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4432                if (_mesa_half_to_float(fp16_val) == value &&
4433                    (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4434                   fp16_mask |= 1 << i;
4435 
4436                if (!literal_mask || literal_value == new_literal) {
4437                   literal_value = new_literal;
4438                   literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]);
4439                   literal_mask |= 1 << i;
4440                   continue;
4441                }
4442             }
4443             sgpr_mask |= op.isOfType(RegType::sgpr) << i;
4444             vgpr_mask |= op.isOfType(RegType::vgpr) << i;
4445          }
4446 
4447          /* The constant bus limitations before GFX10 disallow SGPRs. */
4448          if (sgpr_mask && ctx.program->gfx_level < GFX10)
4449             literal_mask = 0;
4450 
4451          /* Encoding needs a vgpr. */
4452          if (!vgpr_mask)
4453             literal_mask = 0;
4454 
4455          /* v_madmk/v_fmamk needs a vgpr in the third source. */
4456          if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100))
4457             literal_mask = 0;
4458 
4459          /* On GFX11+, opsel is the only modifier supported by fmamk/fmaak. */
4460          if (instr->valu().abs || instr->valu().neg || instr->valu().omod || instr->valu().clamp ||
4461              (instr->valu().opsel && ctx.program->gfx_level < GFX11))
4462             literal_mask = 0;
4463 
4464          if (instr->valu().opsel & ~vgpr_mask)
4465             literal_mask = 0;
4466 
4467          /* We can't use three unique fp16 literals */
4468          if (fp16_mask == 0b111)
4469             fp16_mask = 0b11;
4470 
4471          if ((instr->opcode == aco_opcode::v_fma_f32 ||
4472               (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) &&
4473              !instr->valu().omod && ctx.program->gfx_level >= GFX10 &&
4474              util_bitcount(fp16_mask) > std::max<uint32_t>(util_bitcount(literal_mask), 1)) {
4475             assert(ctx.program->dev.fused_mad_mix);
4476             u_foreach_bit (i, fp16_mask)
4477                ctx.uses[instr->operands[i].tempId()]--;
4478             mad_info->fp16_mask = fp16_mask;
4479             return;
4480          }
4481 
4482          /* Limit the number of literals applied so as not to increase the code
4483           * size too much, but always apply literals for v_mad->v_madak
4484           * because both instructions are 64-bit and this doesn't increase
4485           * code size.
4486           * TODO: try to apply the literals earlier to lower the number of
4487           * uses below threshold
4488           */
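         /* (Illustrative: v_mad_f32 is a 64-bit VOP3 encoding, while v_madak_f32 is a
          * 32-bit VOP2 encoding followed by a 32-bit literal, so folding the addend
          * literal never grows the code.) */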
4489          if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) {
4490             u_foreach_bit (i, literal_mask)
4491                ctx.uses[instr->operands[i].tempId()]--;
4492             mad_info->literal_mask = literal_mask;
4493             return;
4494          }
4495       }
4496    }
4497 
4498    /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
4499     * when it isn't beneficial */
4500    if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
4501        instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
4502       ctx.info[instr->operands[0].tempId()].set_scc_needed();
4503       return;
4504    } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
4505                instr->opcode == aco_opcode::s_cselect_b32) &&
4506               instr->operands[2].isTemp()) {
4507       ctx.info[instr->operands[2].tempId()].set_scc_needed();
4508    }
4509 
4510    /* check for literals */
4511    if (!instr->isSALU() && !instr->isVALU())
4512       return;
4513 
4514    /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
4515    if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
4516        ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
4517       bool transform_done = to_uniform_bool_instr(ctx, instr);
4518 
4519       if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
4520          /* Swap the two definition IDs in order to avoid overusing the SCC.
4521           * This reduces extra moves generated by RA. */
4522          uint32_t def0_id = instr->definitions[0].getTemp().id();
4523          uint32_t def1_id = instr->definitions[1].getTemp().id();
4524          instr->definitions[0].setTemp(Temp(def1_id, s1));
4525          instr->definitions[1].setTemp(Temp(def0_id, s1));
4526       }
4527 
4528       return;
4529    }
4530 
4531    /* This optimization is done late in order to be able to apply otherwise
4532     * unsafe optimizations such as the inverse comparison optimization.
4533     */
4534    if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
4535       if (instr->operands[0].isTemp() && fixed_to_exec(instr->operands[1]) &&
4536           ctx.uses[instr->operands[0].tempId()] == 1 &&
4537           ctx.uses[instr->definitions[1].tempId()] == 0 &&
4538           can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), instr->pass_flags)) {
4539          ctx.uses[instr->operands[0].tempId()]--;
4540          ctx.info[instr->operands[0].tempId()].instr->definitions[0].setTemp(
4541             instr->definitions[0].getTemp());
4542          instr.reset();
4543          return;
4544       }
4545    }
4546 
4547    /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
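   /* Illustrative sketch (pseudo-IR, not exact ACO syntax): a single-use DPP copy
    *    %t = v_mov_b32 %a row_shl:1
    *    %r = v_add_f32 %t, %b
    * is folded below into
    *    %r = v_add_f32 %a, %b row_shl:1
    * after making sure the swizzled value can be moved into operand 0 and the
    * instruction can encode DPP.
    */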
4548    if (instr->isVALU() && !instr->isDPP()) {
4549       for (unsigned i = 0; i < instr->operands.size(); i++) {
4550          if (!instr->operands[i].isTemp())
4551             continue;
4552          ssa_info info = ctx.info[instr->operands[i].tempId()];
4553 
4554          if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
4555             continue;
4556 
4557          /* We won't eliminate the DPP mov if the operand is used twice */
4558          bool op_used_twice = false;
4559          for (unsigned j = 0; j < instr->operands.size(); j++)
4560             op_used_twice |= i != j && instr->operands[i] == instr->operands[j];
4561          if (op_used_twice)
4562             continue;
4563 
4564          if (i != 0) {
4565             if (!can_swap_operands(instr, &instr->opcode, 0, i))
4566                continue;
4567             instr->valu().swapOperands(0, i);
4568          }
4569 
4570          if (!can_use_DPP(ctx.program->gfx_level, instr, info.is_dpp8()))
4571             continue;
4572 
4573          bool dpp8 = info.is_dpp8();
4574          bool input_mods = can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, 0) &&
4575                            get_operand_size(instr, 0) == 32;
4576          bool mov_uses_mods = info.instr->valu().neg[0] || info.instr->valu().abs[0];
4577          if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
4578             continue;
4579 
4580          convert_to_DPP(ctx.program->gfx_level, instr, dpp8);
4581 
4582          if (dpp8) {
4583             DPP8_instruction* dpp = &instr->dpp8();
4584             dpp->lane_sel = info.instr->dpp8().lane_sel;
4585             dpp->fetch_inactive = info.instr->dpp8().fetch_inactive;
4586             if (mov_uses_mods)
4587                instr->format = asVOP3(instr->format);
4588          } else {
4589             DPP16_instruction* dpp = &instr->dpp16();
4590             dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
4591             dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
4592             dpp->fetch_inactive = info.instr->dpp16().fetch_inactive;
4593          }
4594 
4595          instr->valu().neg[0] ^= info.instr->valu().neg[0] && !instr->valu().abs[0];
4596          instr->valu().abs[0] |= info.instr->valu().abs[0];
4597 
4598          if (--ctx.uses[info.instr->definitions[0].tempId()])
4599             ctx.uses[info.instr->operands[0].tempId()]++;
4600          instr->operands[0].setTemp(info.instr->operands[0].getTemp());
4601          break;
4602       }
4603    }
4604 
4605    /* Use v_fma_mix for f2f32/f2f16 if it has higher throughput.
4606     * Do this late to not disturb other optimizations.
4607     */
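   /* Illustrative sketch (pseudo-IR, not exact ACO syntax):
    *    %r = v_cvt_f32_f16 %a
    * becomes
    *    %r = v_fma_mix_f32 %a, 1.0, -0.0
    * i.e. the conversion is expressed as a multiply-add by 1.0 with a negated zero
    * accumulator, which is the faster form for the GFX11 wave64 case checked below.
    */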
4608    if ((instr->opcode == aco_opcode::v_cvt_f32_f16 || instr->opcode == aco_opcode::v_cvt_f16_f32) &&
4609        ctx.program->gfx_level >= GFX11 && ctx.program->wave_size == 64 && !instr->valu().omod &&
4610        !instr->isDPP()) {
4611       bool is_f2f16 = instr->opcode == aco_opcode::v_cvt_f16_f32;
4612       Instruction* fma = create_instruction(
4613          is_f2f16 ? aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1);
4614       fma->definitions[0] = instr->definitions[0];
4615       fma->operands[0] = instr->operands[0];
4616       fma->valu().opsel_hi[0] = !is_f2f16;
4617       fma->valu().opsel_lo[0] = instr->valu().opsel[0];
4618       fma->valu().clamp = instr->valu().clamp;
4619       fma->valu().abs[0] = instr->valu().abs[0];
4620       fma->valu().neg[0] = instr->valu().neg[0];
4621       fma->operands[1] = Operand::c32(fui(1.0f));
4622       fma->operands[2] = Operand::zero();
4623       fma->valu().neg[2] = true;
4624       instr.reset(fma);
4625       ctx.info[instr->definitions[0].tempId()].label = 0;
4626    }
4627 
4628    if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) ||
4629        (instr->isVOP3P() && ctx.program->gfx_level < GFX10))
4630       return; /* some encodings can't ever take literals */
4631 
4632    /* we do not apply the literals yet as we don't know if it is profitable */
4633    Operand current_literal(s1);
4634 
4635    unsigned literal_id = 0;
4636    unsigned literal_uses = UINT32_MAX;
4637    Operand literal(s1);
4638    unsigned num_operands = 1;
4639    if (instr->isSALU() || (ctx.program->gfx_level >= GFX10 &&
4640                            (can_use_VOP3(ctx, instr) || instr->isVOP3P()) && !instr->isDPP()))
4641       num_operands = instr->operands.size();
4642    /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
4643    else if (instr->isVALU() && instr->operands.size() >= 3)
4644       return;
4645 
4646    unsigned sgpr_ids[2] = {0, 0};
4647    bool is_literal_sgpr = false;
4648    uint32_t mask = 0;
4649 
4650    /* choose a literal to apply */
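   /* Illustrative example: if operand 0 and operand 2 are the same temp labeled with
    * literal 0x42, and that temp has the fewest uses of all candidates, the loop below
    * ends with literal = 0x42 and mask = 0b101. */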
4651    for (unsigned i = 0; i < num_operands; i++) {
4652       Operand op = instr->operands[i];
4653       unsigned bits = get_operand_size(instr, i);
4654 
4655       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
4656           op.tempId() != sgpr_ids[0])
4657          sgpr_ids[!!sgpr_ids[0]] = op.tempId();
4658 
4659       if (op.isLiteral()) {
4660          current_literal = op;
4661          continue;
4662       } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
4663          continue;
4664       }
4665 
4666       if (!alu_can_accept_constant(instr, i))
4667          continue;
4668 
4669       if (ctx.uses[op.tempId()] < literal_uses) {
4670          is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
4671          mask = 0;
4672          literal = Operand::c32(ctx.info[op.tempId()].val);
4673          literal_uses = ctx.uses[op.tempId()];
4674          literal_id = op.tempId();
4675       }
4676 
4677       mask |= (op.tempId() == literal_id) << i;
4678    }
4679 
4680    /* don't go over the constant bus limit */
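   /* Illustrative example: on GFX10+ a VALU instruction can read at most two scalar
    * values (SGPRs/literal), so with two distinct SGPR operands a literal is only
    * applied if it replaces one of those SGPRs (is_literal_sgpr); 64-bit shifts keep
    * the stricter limit of one. */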
4681    bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
4682                      instr->opcode == aco_opcode::v_lshlrev_b64 ||
4683                      instr->opcode == aco_opcode::v_lshrrev_b64 ||
4684                      instr->opcode == aco_opcode::v_ashrrev_i64;
4685    unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
4686    if (ctx.program->gfx_level >= GFX10 && !is_shift64)
4687       const_bus_limit = 2;
4688 
4689    unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
4690    if (num_sgprs == const_bus_limit && !is_literal_sgpr)
4691       return;
4692 
4693    if (literal_id && literal_uses < threshold &&
4694        (current_literal.isUndefined() ||
4695         (current_literal.size() == literal.size() &&
4696          current_literal.constantValue() == literal.constantValue()))) {
4697       /* mark the literal to be applied */
4698       while (mask) {
4699          unsigned i = u_bit_scan(&mask);
4700          if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
4701             ctx.uses[instr->operands[i].tempId()]--;
4702       }
4703    }
4704 }
4705 
4706 static aco_opcode
4707 sopk_opcode_for_sopc(aco_opcode opcode)
4708 {
4709 #define CTOK(op)                                                                                   \
4710    case aco_opcode::s_cmp_##op##_i32: return aco_opcode::s_cmpk_##op##_i32;                        \
4711    case aco_opcode::s_cmp_##op##_u32: return aco_opcode::s_cmpk_##op##_u32;
4712    switch (opcode) {
4713       CTOK(eq)
4714       CTOK(lg)
4715       CTOK(gt)
4716       CTOK(ge)
4717       CTOK(lt)
4718       CTOK(le)
4719    default: return aco_opcode::num_opcodes;
4720    }
4721 #undef CTOK
4722 }
4723 
4724 static bool
4725 sopc_is_signed(aco_opcode opcode)
4726 {
4727 #define SOPC(op)                                                                                   \
4728    case aco_opcode::s_cmp_##op##_i32: return true;                                                 \
4729    case aco_opcode::s_cmp_##op##_u32: return false;
4730    switch (opcode) {
4731       SOPC(eq)
4732       SOPC(lg)
4733       SOPC(gt)
4734       SOPC(ge)
4735       SOPC(lt)
4736       SOPC(le)
4737    default: unreachable("Not a valid SOPC instruction.");
4738    }
4739 #undef SOPC
4740 }
4741 
4742 static aco_opcode
4743 sopc_32_swapped(aco_opcode opcode)
4744 {
4745 #define SOPC(op1, op2)                                                                             \
4746    case aco_opcode::s_cmp_##op1##_i32: return aco_opcode::s_cmp_##op2##_i32;                       \
4747    case aco_opcode::s_cmp_##op1##_u32: return aco_opcode::s_cmp_##op2##_u32;
4748    switch (opcode) {
4749       SOPC(eq, eq)
4750       SOPC(lg, lg)
4751       SOPC(gt, lt)
4752       SOPC(ge, le)
4753       SOPC(lt, gt)
4754       SOPC(le, ge)
4755    default: return aco_opcode::num_opcodes;
4756    }
4757 #undef SOPC
4758 }
4759 
4760 static void
4761 try_convert_sopc_to_sopk(aco_ptr<Instruction>& instr)
4762 {
4763    if (sopk_opcode_for_sopc(instr->opcode) == aco_opcode::num_opcodes)
4764       return;
4765 
4766    if (instr->operands[0].isLiteral()) {
4767       std::swap(instr->operands[0], instr->operands[1]);
4768       instr->opcode = sopc_32_swapped(instr->opcode);
4769    }
4770 
4771    if (!instr->operands[1].isLiteral())
4772       return;
4773 
4774    if (instr->operands[0].isFixed() && instr->operands[0].physReg() >= 128)
4775       return;
4776 
4777    uint32_t value = instr->operands[1].constantValue();
4778 
4779    const uint32_t i16_mask = 0xffff8000u;
4780 
4781    bool value_is_i16 = (value & i16_mask) == 0 || (value & i16_mask) == i16_mask;
4782    bool value_is_u16 = !(value & 0xffff0000u);
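   /* Illustrative example: value 0xffff8123 satisfies (value & 0xffff8000) == 0xffff8000,
    * so it fits the sign-extended 16-bit SOPK immediate, while 0x00012345 matches
    * neither check and the conversion is skipped. */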
4783 
4784    if (!value_is_i16 && !value_is_u16)
4785       return;
4786 
4787    if (!value_is_i16 && sopc_is_signed(instr->opcode)) {
4788       if (instr->opcode == aco_opcode::s_cmp_lg_i32)
4789          instr->opcode = aco_opcode::s_cmp_lg_u32;
4790       else if (instr->opcode == aco_opcode::s_cmp_eq_i32)
4791          instr->opcode = aco_opcode::s_cmp_eq_u32;
4792       else
4793          return;
4794    } else if (!value_is_u16 && !sopc_is_signed(instr->opcode)) {
4795       if (instr->opcode == aco_opcode::s_cmp_lg_u32)
4796          instr->opcode = aco_opcode::s_cmp_lg_i32;
4797       else if (instr->opcode == aco_opcode::s_cmp_eq_u32)
4798          instr->opcode = aco_opcode::s_cmp_eq_i32;
4799       else
4800          return;
4801    }
4802 
4803    instr->format = Format::SOPK;
4804    SALU_instruction* instr_sopk = &instr->salu();
4805 
4806    instr_sopk->imm = instr_sopk->operands[1].constantValue() & 0xffff;
4807    instr_sopk->opcode = sopk_opcode_for_sopc(instr_sopk->opcode);
4808    instr_sopk->operands.pop_back();
4809 }
4810 
4811 static void
4812 opt_fma_mix_acc(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4813 {
4814    /* fma_mix is only dual issued on gfx11 if dst and acc type match */
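   /* Illustrative sketch: for v_fma_mixlo_f16 where the accumulator (src2) is tagged
    * as fp32 but one multiplicand is the constant 1.0 and the other is fp16, the
    * operands are reordered below so the fp16 value becomes src2; dst and acc types
    * then match and the instruction can dual issue. */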
4815    bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16;
4816 
4817    if (instr->valu().opsel_hi[2] == f2f16 || instr->isDPP())
4818       return;
4819 
4820    bool is_add = false;
4821    for (unsigned i = 0; i < 2; i++) {
4822       uint32_t one = instr->valu().opsel_hi[i] ? 0x3800 : 0x3f800000;
4823       is_add = instr->operands[i].constantEquals(one) && !instr->valu().neg[i] &&
4824                !instr->valu().opsel_lo[i];
4825       if (is_add) {
4826          instr->valu().swapOperands(0, i);
4827          break;
4828       }
4829    }
4830 
4831    if (is_add && instr->valu().opsel_hi[1] == f2f16) {
4832       instr->valu().swapOperands(1, 2);
4833       return;
4834    }
4835 
4836    unsigned literal_count = instr->operands[0].isLiteral() + instr->operands[1].isLiteral() +
4837                             instr->operands[2].isLiteral();
4838 
4839    if (!f2f16 || literal_count > 1)
4840       return;
4841 
4842    /* try to convert constant operand to fp16 */
4843    for (unsigned i = 2 - is_add; i < 3; i++) {
4844       if (!instr->operands[i].isConstant())
4845          continue;
4846 
4847       float value = uif(instr->operands[i].constantValue());
4848       uint16_t fp16_val = _mesa_float_to_half(value);
4849       bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4850 
4851       if (_mesa_half_to_float(fp16_val) != value ||
4852           (is_denorm && !(ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4853          continue;
4854 
4855       instr->valu().swapOperands(i, 2);
4856 
4857       Operand op16 = Operand::c16(fp16_val);
4858       assert(!op16.isLiteral() || instr->operands[2].isLiteral());
4859 
4860       instr->operands[2] = op16;
4861       instr->valu().opsel_lo[2] = false;
4862       instr->valu().opsel_hi[2] = true;
4863       return;
4864    }
4865 }
4866 
4867 void
4868 apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4869 {
4870    /* Cleanup Dead Instructions */
4871    if (!instr)
4872       return;
4873 
4874    /* apply literals on MAD */
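   /* Illustrative sketch (pseudo-IR, not exact ACO syntax): a mad whose addend literal
    * temp ended up with zero uses, e.g.
    *    %r = v_mad_f32 %a, %b, %k      ; %k labeled as literal 0x3fc00000 (1.5)
    * is rewritten below as
    *    %r = v_madak_f32 %a, %b, 0x3fc00000
    * (or the v_fmaak/v_fmamk variants, depending on opcode and literal position).
    */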
4875    if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
4876       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
4877       const bool madak = (info->literal_mask & 0b100);
4878       bool has_dead_literal = false;
4879       u_foreach_bit (i, info->literal_mask | info->fp16_mask)
4880          has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0;
4881 
4882       if (has_dead_literal && info->fp16_mask) {
4883          instr->format = Format::VOP3P;
4884          instr->opcode = aco_opcode::v_fma_mix_f32;
4885 
4886          uint32_t literal = 0;
4887          bool second = false;
4888          u_foreach_bit (i, info->fp16_mask) {
4889             float value = uif(ctx.info[instr->operands[i].tempId()].val);
4890             literal |= _mesa_float_to_half(value) << (second * 16);
4891             instr->valu().opsel_lo[i] = second;
4892             instr->valu().opsel_hi[i] = true;
4893             second = true;
4894          }
4895 
4896          for (unsigned i = 0; i < 3; i++) {
4897             if (info->fp16_mask & (1 << i))
4898                instr->operands[i] = Operand::literal32(literal);
4899          }
4900 
4901          ctx.instructions.emplace_back(std::move(instr));
4902          return;
4903       }
4904 
4905       if (has_dead_literal || madak) {
4906          aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
4907          if (instr->opcode == aco_opcode::v_fma_f32)
4908             new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
4909          else if (instr->opcode == aco_opcode::v_mad_f16 ||
4910                   instr->opcode == aco_opcode::v_mad_legacy_f16)
4911             new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
4912          else if (instr->opcode == aco_opcode::v_fma_f16)
4913             new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
4914 
4915          uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val;
4916          instr->format = Format::VOP2;
4917          instr->opcode = new_op;
4918          for (unsigned i = 0; i < 3; i++) {
4919             if (info->literal_mask & (1 << i))
4920                instr->operands[i] = Operand::literal32(literal);
4921          }
4922          if (madak) { /* add literal -> madak */
4923             if (!instr->operands[1].isOfType(RegType::vgpr))
4924                instr->valu().swapOperands(0, 1);
4925          } else { /* mul literal -> madmk */
4926             if (!(info->literal_mask & 0b10))
4927                instr->valu().swapOperands(0, 1);
4928             instr->valu().swapOperands(1, 2);
4929          }
4930          ctx.instructions.emplace_back(std::move(instr));
4931          return;
4932       }
4933    }
4934 
4935    /* apply literals on other SALU/VALU */
4936    if (instr->isSALU() || instr->isVALU()) {
4937       for (unsigned i = 0; i < instr->operands.size(); i++) {
4938          Operand op = instr->operands[i];
4939          unsigned bits = get_operand_size(instr, i);
4940          if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
4941             Operand literal = Operand::literal32(ctx.info[op.tempId()].val);
4942             instr->format = withoutDPP(instr->format);
4943             if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
4944                instr->format = asVOP3(instr->format);
4945             instr->operands[i] = literal;
4946          }
4947       }
4948    }
4949 
4950    if (instr->isSOPC() && ctx.program->gfx_level < GFX12)
4951       try_convert_sopc_to_sopk(instr);
4952 
4953    if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mix_f32)
4954       opt_fma_mix_acc(ctx, instr);
4955 
4956    ctx.instructions.emplace_back(std::move(instr));
4957 }
4958 
4959 } /* end namespace */
4960 
4961 void
4962 optimize(Program* program)
4963 {
4964    opt_ctx ctx;
4965    ctx.program = program;
4966    ctx.info = std::vector<ssa_info>(program->peekAllocationId());
4967 
4968    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
4969    for (Block& block : program->blocks) {
4970       ctx.fp_mode = block.fp_mode;
4971       for (aco_ptr<Instruction>& instr : block.instructions)
4972          label_instruction(ctx, instr);
4973    }
4974 
4975    ctx.uses = dead_code_analysis(program);
4976 
4977    /* 2. Rematerialize constants in every block. */
4978    rematerialize_constants(ctx);
4979 
4980    /* 3. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
4981    for (Block& block : program->blocks) {
4982       ctx.fp_mode = block.fp_mode;
4983       for (aco_ptr<Instruction>& instr : block.instructions)
4984          combine_instruction(ctx, instr);
4985    }
4986 
4987    /* 4. Top-Down DAG pass (backward) to select instructions (includes DCE) */
4988    for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
4989         ++block_rit) {
4990       Block* block = &(*block_rit);
4991       ctx.fp_mode = block->fp_mode;
4992       for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
4993            ++instr_rit)
4994          select_instruction(ctx, *instr_rit);
4995    }
4996 
4997    /* 5. Add literals to instructions */
4998    for (Block& block : program->blocks) {
4999       ctx.instructions.reserve(block.instructions.size());
5000       ctx.fp_mode = block.fp_mode;
5001       for (aco_ptr<Instruction>& instr : block.instructions)
5002          apply_literals(ctx, instr);
5003       block.instructions = std::move(ctx.instructions);
5004    }
5005 }
5006 
5007 } // namespace aco
5008