1 /*
2 * Copyright © 2018 Valve Corporation
3 *
4 * SPDX-License-Identifier: MIT
5 */
6
7 #include "aco_builder.h"
8 #include "aco_ir.h"
9
10 #include "util/half_float.h"
11 #include "util/memstream.h"
12
13 #include <algorithm>
14 #include <array>
15 #include <vector>
16
17 namespace aco {
18
19 namespace {
20 /**
21 * The optimizer works in 4 phases:
22 * (1) The first pass collects information for each ssa-def,
23 * propagates reg->reg operands of the same type, inline constants
24 * and neg/abs input modifiers.
25 * (2) The second pass combines instructions like mad, omod, clamp and
 *     propagates SGPRs on VALU instructions.
27 * This pass depends on information collected in the first pass.
28 * (3) The third pass goes backwards, and selects instructions,
29 * i.e. decides if a mad instruction is profitable and eliminates dead code.
30 * (4) The fourth pass cleans up the sequence: literals get applied and dead
31 * instructions are removed from the sequence.
32 */
33
34 struct mad_info {
35 aco_ptr<Instruction> add_instr;
36 uint32_t mul_temp_id;
37 uint16_t literal_mask;
38 uint16_t fp16_mask;
39
mad_infoaco::__anon9e387afb0111::mad_info40 mad_info(aco_ptr<Instruction> instr, uint32_t id)
41 : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0), fp16_mask(0)
42 {}
43 };
44
/* Facts the optimizer can attach to an SSA definition (bits of
 * ssa_info::label). Which ssa_info payload (instr, temp or val) a label uses
 * is decided by the instr_labels / temp_labels / val_labels masks.
 * Bits 11, 13, 14, 15 and 25 are currently unused.
 * Every value is written as a 64-bit shift so that the labels above bit 31
 * cannot be sign-extended or truncated. */
enum Label {
   label_vec = 1ull << 0,
   label_constant_32bit = 1ull << 1,
   /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
    * 32-bit operations but this shouldn't cause any issues because we don't
    * look through any conversions */
   label_abs = 1ull << 2,
   label_neg = 1ull << 3,
   label_mul = 1ull << 4,
   label_temp = 1ull << 5,
   label_literal = 1ull << 6,
   label_mad = 1ull << 7,
   label_omod2 = 1ull << 8,
   label_omod4 = 1ull << 9,
   label_omod5 = 1ull << 10,
   label_clamp = 1ull << 12,
   label_b2f = 1ull << 16,
   label_add_sub = 1ull << 17,
   label_bitwise = 1ull << 18,
   label_minmax = 1ull << 19,
   label_vopc = 1ull << 20,
   label_uniform_bool = 1ull << 21,
   label_constant_64bit = 1ull << 22,
   label_uniform_bitwise = 1ull << 23,
   label_scc_invert = 1ull << 24,
   label_scc_needed = 1ull << 26,
   label_b2i = 1ull << 27,
   label_fcanonicalize = 1ull << 28,
   label_constant_16bit = 1ull << 29,
   label_usedef = 1ull << 30, /* generic label */
   label_vop3p = 1ull << 31,
   label_canonicalized = 1ull << 32,
   label_extract = 1ull << 33,
   label_insert = 1ull << 34,
   label_dpp16 = 1ull << 35,
   label_dpp8 = 1ull << 36,
   label_f2f32 = 1ull << 37,
   label_f2f16 = 1ull << 38,
   label_split = 1ull << 39,
};
85
86 static constexpr uint64_t instr_usedef_labels =
87 label_vec | label_mul | label_add_sub | label_vop3p | label_bitwise | label_uniform_bitwise |
88 label_minmax | label_vopc | label_usedef | label_extract | label_dpp16 | label_dpp8 |
89 label_f2f32;
90 static constexpr uint64_t instr_mod_labels =
91 label_omod2 | label_omod4 | label_omod5 | label_clamp | label_insert | label_f2f16;
92
93 static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels | label_split;
94 static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_b2f |
95 label_uniform_bool | label_scc_invert | label_b2i |
96 label_fcanonicalize;
97 static constexpr uint32_t val_labels =
98 label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
99
100 static_assert((instr_labels & temp_labels) == 0, "labels cannot intersect");
101 static_assert((instr_labels & val_labels) == 0, "labels cannot intersect");
102 static_assert((temp_labels & val_labels) == 0, "labels cannot intersect");
103
104 struct ssa_info {
105 uint64_t label;
106 union {
107 uint32_t val;
108 Temp temp;
109 Instruction* instr;
110 };
111
ssa_infoaco::__anon9e387afb0111::ssa_info112 ssa_info() : label(0) {}
113
add_labelaco::__anon9e387afb0111::ssa_info114 void add_label(Label new_label)
115 {
116 /* Since all the instr_usedef_labels use instr for the same thing
117 * (indicating the defining instruction), there is usually no need to
118 * clear any other instr labels. */
119 if (new_label & instr_usedef_labels)
120 label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */
121
122 if (new_label & instr_mod_labels) {
123 label &= ~instr_labels;
124 label &= ~(temp_labels | val_labels); /* instr, temp and val alias */
125 }
126
127 if (new_label & temp_labels) {
128 label &= ~temp_labels;
129 label &= ~(instr_labels | val_labels); /* instr, temp and val alias */
130 }
131
132 uint32_t const_labels =
133 label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
134 if (new_label & const_labels) {
135 label &= ~val_labels | const_labels;
136 label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
137 } else if (new_label & val_labels) {
138 label &= ~val_labels;
139 label &= ~(instr_labels | temp_labels); /* instr, temp and val alias */
140 }
141
142 label |= new_label;
143 }
144
set_vecaco::__anon9e387afb0111::ssa_info145 void set_vec(Instruction* vec)
146 {
147 add_label(label_vec);
148 instr = vec;
149 }
150
is_vecaco::__anon9e387afb0111::ssa_info151 bool is_vec() { return label & label_vec; }
152
set_constantaco::__anon9e387afb0111::ssa_info153 void set_constant(amd_gfx_level gfx_level, uint64_t constant)
154 {
155 Operand op16 = Operand::c16(constant);
156 Operand op32 = Operand::get_const(gfx_level, constant, 4);
157 add_label(label_literal);
158 val = constant;
159
160 /* check that no upper bits are lost in case of packed 16bit constants */
161 if (gfx_level >= GFX8 && !op16.isLiteral() &&
162 op16.constantValue16(true) == ((constant >> 16) & 0xffff))
163 add_label(label_constant_16bit);
164
165 if (!op32.isLiteral())
166 add_label(label_constant_32bit);
167
168 if (Operand::is_constant_representable(constant, 8))
169 add_label(label_constant_64bit);
170
171 if (label & label_constant_64bit) {
172 val = Operand::c64(constant).constantValue();
173 if (val != constant)
174 label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
175 }
176 }
177
is_constantaco::__anon9e387afb0111::ssa_info178 bool is_constant(unsigned bits)
179 {
180 switch (bits) {
181 case 8: return label & label_literal;
182 case 16: return label & label_constant_16bit;
183 case 32: return label & label_constant_32bit;
184 case 64: return label & label_constant_64bit;
185 }
186 return false;
187 }
188
is_literalaco::__anon9e387afb0111::ssa_info189 bool is_literal(unsigned bits)
190 {
191 bool is_lit = label & label_literal;
192 switch (bits) {
193 case 8: return false;
194 case 16: return is_lit && ~(label & label_constant_16bit);
195 case 32: return is_lit && ~(label & label_constant_32bit);
196 case 64: return false;
197 }
198 return false;
199 }
200
is_constant_or_literalaco::__anon9e387afb0111::ssa_info201 bool is_constant_or_literal(unsigned bits)
202 {
203 if (bits == 64)
204 return label & label_constant_64bit;
205 else
206 return label & label_literal;
207 }
208
set_absaco::__anon9e387afb0111::ssa_info209 void set_abs(Temp abs_temp)
210 {
211 add_label(label_abs);
212 temp = abs_temp;
213 }
214
is_absaco::__anon9e387afb0111::ssa_info215 bool is_abs() { return label & label_abs; }
216
set_negaco::__anon9e387afb0111::ssa_info217 void set_neg(Temp neg_temp)
218 {
219 add_label(label_neg);
220 temp = neg_temp;
221 }
222
is_negaco::__anon9e387afb0111::ssa_info223 bool is_neg() { return label & label_neg; }
224
set_neg_absaco::__anon9e387afb0111::ssa_info225 void set_neg_abs(Temp neg_abs_temp)
226 {
227 add_label((Label)((uint32_t)label_abs | (uint32_t)label_neg));
228 temp = neg_abs_temp;
229 }
230
set_mulaco::__anon9e387afb0111::ssa_info231 void set_mul(Instruction* mul)
232 {
233 add_label(label_mul);
234 instr = mul;
235 }
236
is_mulaco::__anon9e387afb0111::ssa_info237 bool is_mul() { return label & label_mul; }
238
set_tempaco::__anon9e387afb0111::ssa_info239 void set_temp(Temp tmp)
240 {
241 add_label(label_temp);
242 temp = tmp;
243 }
244
is_tempaco::__anon9e387afb0111::ssa_info245 bool is_temp() { return label & label_temp; }
246
set_madaco::__anon9e387afb0111::ssa_info247 void set_mad(uint32_t mad_info_idx)
248 {
249 add_label(label_mad);
250 val = mad_info_idx;
251 }
252
is_madaco::__anon9e387afb0111::ssa_info253 bool is_mad() { return label & label_mad; }
254
set_omod2aco::__anon9e387afb0111::ssa_info255 void set_omod2(Instruction* mul)
256 {
257 if (label & temp_labels)
258 return;
259 add_label(label_omod2);
260 instr = mul;
261 }
262
is_omod2aco::__anon9e387afb0111::ssa_info263 bool is_omod2() { return label & label_omod2; }
264
set_omod4aco::__anon9e387afb0111::ssa_info265 void set_omod4(Instruction* mul)
266 {
267 if (label & temp_labels)
268 return;
269 add_label(label_omod4);
270 instr = mul;
271 }
272
is_omod4aco::__anon9e387afb0111::ssa_info273 bool is_omod4() { return label & label_omod4; }
274
set_omod5aco::__anon9e387afb0111::ssa_info275 void set_omod5(Instruction* mul)
276 {
277 if (label & temp_labels)
278 return;
279 add_label(label_omod5);
280 instr = mul;
281 }
282
is_omod5aco::__anon9e387afb0111::ssa_info283 bool is_omod5() { return label & label_omod5; }
284
set_clampaco::__anon9e387afb0111::ssa_info285 void set_clamp(Instruction* med3)
286 {
287 if (label & temp_labels)
288 return;
289 add_label(label_clamp);
290 instr = med3;
291 }
292
is_clampaco::__anon9e387afb0111::ssa_info293 bool is_clamp() { return label & label_clamp; }
294
set_f2f16aco::__anon9e387afb0111::ssa_info295 void set_f2f16(Instruction* conv)
296 {
297 if (label & temp_labels)
298 return;
299 add_label(label_f2f16);
300 instr = conv;
301 }
302
is_f2f16aco::__anon9e387afb0111::ssa_info303 bool is_f2f16() { return label & label_f2f16; }
304
set_b2faco::__anon9e387afb0111::ssa_info305 void set_b2f(Temp b2f_val)
306 {
307 add_label(label_b2f);
308 temp = b2f_val;
309 }
310
is_b2faco::__anon9e387afb0111::ssa_info311 bool is_b2f() { return label & label_b2f; }
312
set_add_subaco::__anon9e387afb0111::ssa_info313 void set_add_sub(Instruction* add_sub_instr)
314 {
315 add_label(label_add_sub);
316 instr = add_sub_instr;
317 }
318
is_add_subaco::__anon9e387afb0111::ssa_info319 bool is_add_sub() { return label & label_add_sub; }
320
set_bitwiseaco::__anon9e387afb0111::ssa_info321 void set_bitwise(Instruction* bitwise_instr)
322 {
323 add_label(label_bitwise);
324 instr = bitwise_instr;
325 }
326
is_bitwiseaco::__anon9e387afb0111::ssa_info327 bool is_bitwise() { return label & label_bitwise; }
328
set_uniform_bitwiseaco::__anon9e387afb0111::ssa_info329 void set_uniform_bitwise() { add_label(label_uniform_bitwise); }
330
is_uniform_bitwiseaco::__anon9e387afb0111::ssa_info331 bool is_uniform_bitwise() { return label & label_uniform_bitwise; }
332
set_minmaxaco::__anon9e387afb0111::ssa_info333 void set_minmax(Instruction* minmax_instr)
334 {
335 add_label(label_minmax);
336 instr = minmax_instr;
337 }
338
is_minmaxaco::__anon9e387afb0111::ssa_info339 bool is_minmax() { return label & label_minmax; }
340
set_vopcaco::__anon9e387afb0111::ssa_info341 void set_vopc(Instruction* vopc_instr)
342 {
343 add_label(label_vopc);
344 instr = vopc_instr;
345 }
346
is_vopcaco::__anon9e387afb0111::ssa_info347 bool is_vopc() { return label & label_vopc; }
348
set_scc_neededaco::__anon9e387afb0111::ssa_info349 void set_scc_needed() { add_label(label_scc_needed); }
350
is_scc_neededaco::__anon9e387afb0111::ssa_info351 bool is_scc_needed() { return label & label_scc_needed; }
352
set_scc_invertaco::__anon9e387afb0111::ssa_info353 void set_scc_invert(Temp scc_inv)
354 {
355 add_label(label_scc_invert);
356 temp = scc_inv;
357 }
358
is_scc_invertaco::__anon9e387afb0111::ssa_info359 bool is_scc_invert() { return label & label_scc_invert; }
360
set_uniform_boolaco::__anon9e387afb0111::ssa_info361 void set_uniform_bool(Temp uniform_bool)
362 {
363 add_label(label_uniform_bool);
364 temp = uniform_bool;
365 }
366
is_uniform_boolaco::__anon9e387afb0111::ssa_info367 bool is_uniform_bool() { return label & label_uniform_bool; }
368
set_b2iaco::__anon9e387afb0111::ssa_info369 void set_b2i(Temp b2i_val)
370 {
371 add_label(label_b2i);
372 temp = b2i_val;
373 }
374
is_b2iaco::__anon9e387afb0111::ssa_info375 bool is_b2i() { return label & label_b2i; }
376
set_usedefaco::__anon9e387afb0111::ssa_info377 void set_usedef(Instruction* label_instr)
378 {
379 add_label(label_usedef);
380 instr = label_instr;
381 }
382
is_usedefaco::__anon9e387afb0111::ssa_info383 bool is_usedef() { return label & label_usedef; }
384
set_vop3paco::__anon9e387afb0111::ssa_info385 void set_vop3p(Instruction* vop3p_instr)
386 {
387 add_label(label_vop3p);
388 instr = vop3p_instr;
389 }
390
is_vop3paco::__anon9e387afb0111::ssa_info391 bool is_vop3p() { return label & label_vop3p; }
392
set_fcanonicalizeaco::__anon9e387afb0111::ssa_info393 void set_fcanonicalize(Temp tmp)
394 {
395 add_label(label_fcanonicalize);
396 temp = tmp;
397 }
398
is_fcanonicalizeaco::__anon9e387afb0111::ssa_info399 bool is_fcanonicalize() { return label & label_fcanonicalize; }
400
set_canonicalizedaco::__anon9e387afb0111::ssa_info401 void set_canonicalized() { add_label(label_canonicalized); }
402
is_canonicalizedaco::__anon9e387afb0111::ssa_info403 bool is_canonicalized() { return label & label_canonicalized; }
404
set_f2f32aco::__anon9e387afb0111::ssa_info405 void set_f2f32(Instruction* cvt)
406 {
407 add_label(label_f2f32);
408 instr = cvt;
409 }
410
is_f2f32aco::__anon9e387afb0111::ssa_info411 bool is_f2f32() { return label & label_f2f32; }
412
set_extractaco::__anon9e387afb0111::ssa_info413 void set_extract(Instruction* extract)
414 {
415 add_label(label_extract);
416 instr = extract;
417 }
418
is_extractaco::__anon9e387afb0111::ssa_info419 bool is_extract() { return label & label_extract; }
420
set_insertaco::__anon9e387afb0111::ssa_info421 void set_insert(Instruction* insert)
422 {
423 if (label & temp_labels)
424 return;
425 add_label(label_insert);
426 instr = insert;
427 }
428
is_insertaco::__anon9e387afb0111::ssa_info429 bool is_insert() { return label & label_insert; }
430
set_dpp16aco::__anon9e387afb0111::ssa_info431 void set_dpp16(Instruction* mov)
432 {
433 add_label(label_dpp16);
434 instr = mov;
435 }
436
set_dpp8aco::__anon9e387afb0111::ssa_info437 void set_dpp8(Instruction* mov)
438 {
439 add_label(label_dpp8);
440 instr = mov;
441 }
442
is_dppaco::__anon9e387afb0111::ssa_info443 bool is_dpp() { return label & (label_dpp16 | label_dpp8); }
is_dpp16aco::__anon9e387afb0111::ssa_info444 bool is_dpp16() { return label & label_dpp16; }
is_dpp8aco::__anon9e387afb0111::ssa_info445 bool is_dpp8() { return label & label_dpp8; }
446
set_splitaco::__anon9e387afb0111::ssa_info447 void set_split(Instruction* split)
448 {
449 add_label(label_split);
450 instr = split;
451 }
452
is_splitaco::__anon9e387afb0111::ssa_info453 bool is_split() { return label & label_split; }
454 };
455
/* State shared by the optimizer's passes over one program. */
struct opt_ctx {
   /* The program being optimized (gfx level etc. are read from it). */
   Program* program;
   /* NOTE(review): presumably the floating-point mode in effect for the block
    * being processed — confirm at the use sites. */
   float_mode fp_mode;
   /* NOTE(review): presumably the instruction list being rebuilt for the
    * current block — confirm at the use sites. */
   std::vector<aco_ptr<Instruction>> instructions;
   /* Knowledge about each SSA definition, indexed by temporary id. */
   std::vector<ssa_info> info;
   /* NOTE(review): usage is outside this view; looks like a cache of the last
    * materialized literal value and its temporary. */
   std::pair<uint32_t, Temp> last_literal;
   /* Multiply-add bookkeeping; ssa_info::set_mad() stores an index that
    * presumably refers into this vector. */
   std::vector<mad_info> mad_infos;
   /* NOTE(review): presumably the use count per temporary id — confirm. */
   std::vector<uint16_t> uses;
};
465
466 bool
can_use_VOP3(opt_ctx & ctx,const aco_ptr<Instruction> & instr)467 can_use_VOP3(opt_ctx& ctx, const aco_ptr<Instruction>& instr)
468 {
469 if (instr->isVOP3())
470 return true;
471
472 if (instr->isVOP3P() || instr->isVINTERP_INREG())
473 return false;
474
475 if (instr->operands.size() && instr->operands[0].isLiteral() && ctx.program->gfx_level < GFX10)
476 return false;
477
478 if (instr->isSDWA())
479 return false;
480
481 if (instr->isDPP() && ctx.program->gfx_level < GFX11)
482 return false;
483
484 return instr->opcode != aco_opcode::v_madmk_f32 && instr->opcode != aco_opcode::v_madak_f32 &&
485 instr->opcode != aco_opcode::v_madmk_f16 && instr->opcode != aco_opcode::v_madak_f16 &&
486 instr->opcode != aco_opcode::v_fmamk_f32 && instr->opcode != aco_opcode::v_fmaak_f32 &&
487 instr->opcode != aco_opcode::v_fmamk_f16 && instr->opcode != aco_opcode::v_fmaak_f16 &&
488 instr->opcode != aco_opcode::v_permlane64_b32 &&
489 instr->opcode != aco_opcode::v_readlane_b32 &&
490 instr->opcode != aco_opcode::v_writelane_b32 &&
491 instr->opcode != aco_opcode::v_readfirstlane_b32;
492 }
493
494 bool
pseudo_propagate_temp(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp temp,unsigned index)495 pseudo_propagate_temp(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp temp, unsigned index)
496 {
497 if (instr->definitions.empty())
498 return false;
499
500 const bool vgpr =
501 instr->opcode == aco_opcode::p_as_uniform ||
502 std::all_of(instr->definitions.begin(), instr->definitions.end(),
503 [](const Definition& def) { return def.regClass().type() == RegType::vgpr; });
504
505 /* don't propagate VGPRs into SGPR instructions */
506 if (temp.type() == RegType::vgpr && !vgpr)
507 return false;
508
509 bool can_accept_sgpr =
510 ctx.program->gfx_level >= GFX9 ||
511 std::none_of(instr->definitions.begin(), instr->definitions.end(),
512 [](const Definition& def) { return def.regClass().is_subdword(); });
513
514 switch (instr->opcode) {
515 case aco_opcode::p_phi:
516 case aco_opcode::p_linear_phi:
517 case aco_opcode::p_parallelcopy:
518 case aco_opcode::p_create_vector:
519 case aco_opcode::p_start_linear_vgpr:
520 if (temp.bytes() != instr->operands[index].bytes())
521 return false;
522 break;
523 case aco_opcode::p_extract_vector:
524 case aco_opcode::p_extract:
525 if (temp.type() == RegType::sgpr && !can_accept_sgpr)
526 return false;
527 break;
528 case aco_opcode::p_split_vector: {
529 if (temp.type() == RegType::sgpr && !can_accept_sgpr)
530 return false;
531 /* don't increase the vector size */
532 if (temp.bytes() > instr->operands[index].bytes())
533 return false;
534 /* We can decrease the vector size as smaller temporaries are only
535 * propagated by p_as_uniform instructions.
536 * If this propagation leads to invalid IR or hits the assertion below,
537 * it means that some undefined bytes within a dword are begin accessed
538 * and a bug in instruction_selection is likely. */
539 int decrease = instr->operands[index].bytes() - temp.bytes();
540 while (decrease > 0) {
541 decrease -= instr->definitions.back().bytes();
542 instr->definitions.pop_back();
543 }
544 assert(decrease == 0);
545 break;
546 }
547 case aco_opcode::p_as_uniform:
548 if (temp.regClass() == instr->definitions[0].regClass())
549 instr->opcode = aco_opcode::p_parallelcopy;
550 break;
551 default: return false;
552 }
553
554 instr->operands[index].setTemp(temp);
555 return true;
556 }
557
558 /* This expects the DPP modifier to be removed. */
559 bool
can_apply_sgprs(opt_ctx & ctx,aco_ptr<Instruction> & instr)560 can_apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
561 {
562 assert(instr->isVALU());
563 if (instr->isSDWA() && ctx.program->gfx_level < GFX9)
564 return false;
565 return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
566 instr->opcode != aco_opcode::v_readlane_b32 &&
567 instr->opcode != aco_opcode::v_readlane_b32_e64 &&
568 instr->opcode != aco_opcode::v_writelane_b32 &&
569 instr->opcode != aco_opcode::v_writelane_b32_e64 &&
570 instr->opcode != aco_opcode::v_permlane16_b32 &&
571 instr->opcode != aco_opcode::v_permlanex16_b32 &&
572 instr->opcode != aco_opcode::v_permlane64_b32 &&
573 instr->opcode != aco_opcode::v_interp_p1_f32 &&
574 instr->opcode != aco_opcode::v_interp_p2_f32 &&
575 instr->opcode != aco_opcode::v_interp_mov_f32 &&
576 instr->opcode != aco_opcode::v_interp_p1ll_f16 &&
577 instr->opcode != aco_opcode::v_interp_p1lv_f16 &&
578 instr->opcode != aco_opcode::v_interp_p2_legacy_f16 &&
579 instr->opcode != aco_opcode::v_interp_p2_f16 &&
580 instr->opcode != aco_opcode::v_interp_p2_hi_f16 &&
581 instr->opcode != aco_opcode::v_interp_p10_f32_inreg &&
582 instr->opcode != aco_opcode::v_interp_p2_f32_inreg &&
583 instr->opcode != aco_opcode::v_interp_p10_f16_f32_inreg &&
584 instr->opcode != aco_opcode::v_interp_p2_f16_f32_inreg &&
585 instr->opcode != aco_opcode::v_interp_p10_rtz_f16_f32_inreg &&
586 instr->opcode != aco_opcode::v_interp_p2_rtz_f16_f32_inreg &&
587 instr->opcode != aco_opcode::v_wmma_f32_16x16x16_f16 &&
588 instr->opcode != aco_opcode::v_wmma_f32_16x16x16_bf16 &&
589 instr->opcode != aco_opcode::v_wmma_f16_16x16x16_f16 &&
590 instr->opcode != aco_opcode::v_wmma_bf16_16x16x16_bf16 &&
591 instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu8 &&
592 instr->opcode != aco_opcode::v_wmma_i32_16x16x16_iu4;
593 }
594
595 /* only covers special cases */
596 bool
alu_can_accept_constant(const aco_ptr<Instruction> & instr,unsigned operand)597 alu_can_accept_constant(const aco_ptr<Instruction>& instr, unsigned operand)
598 {
599 /* Fixed operands can't accept constants because we need them
600 * to be in their fixed register.
601 */
602 assert(instr->operands.size() > operand);
603 if (instr->operands[operand].isFixed())
604 return false;
605
606 /* SOPP instructions can't use constants. */
607 if (instr->isSOPP())
608 return false;
609
610 switch (instr->opcode) {
611 case aco_opcode::s_fmac_f16:
612 case aco_opcode::s_fmac_f32:
613 case aco_opcode::v_mac_f32:
614 case aco_opcode::v_writelane_b32:
615 case aco_opcode::v_writelane_b32_e64:
616 case aco_opcode::v_cndmask_b32: return operand != 2;
617 case aco_opcode::s_addk_i32:
618 case aco_opcode::s_mulk_i32:
619 case aco_opcode::p_extract_vector:
620 case aco_opcode::p_split_vector:
621 case aco_opcode::v_readlane_b32:
622 case aco_opcode::v_readlane_b32_e64:
623 case aco_opcode::v_readfirstlane_b32:
624 case aco_opcode::p_extract:
625 case aco_opcode::p_insert: return operand != 0;
626 case aco_opcode::p_bpermute_readlane:
627 case aco_opcode::p_bpermute_shared_vgpr:
628 case aco_opcode::p_bpermute_permlane:
629 case aco_opcode::p_interp_gfx11:
630 case aco_opcode::p_dual_src_export_gfx11:
631 case aco_opcode::v_interp_p1_f32:
632 case aco_opcode::v_interp_p2_f32:
633 case aco_opcode::v_interp_mov_f32:
634 case aco_opcode::v_interp_p1ll_f16:
635 case aco_opcode::v_interp_p1lv_f16:
636 case aco_opcode::v_interp_p2_legacy_f16:
637 case aco_opcode::v_interp_p10_f32_inreg:
638 case aco_opcode::v_interp_p2_f32_inreg:
639 case aco_opcode::v_interp_p10_f16_f32_inreg:
640 case aco_opcode::v_interp_p2_f16_f32_inreg:
641 case aco_opcode::v_interp_p10_rtz_f16_f32_inreg:
642 case aco_opcode::v_interp_p2_rtz_f16_f32_inreg:
643 case aco_opcode::v_wmma_f32_16x16x16_f16:
644 case aco_opcode::v_wmma_f32_16x16x16_bf16:
645 case aco_opcode::v_wmma_f16_16x16x16_f16:
646 case aco_opcode::v_wmma_bf16_16x16x16_bf16:
647 case aco_opcode::v_wmma_i32_16x16x16_iu8:
648 case aco_opcode::v_wmma_i32_16x16x16_iu4: return false;
649 default: return true;
650 }
651 }
652
653 bool
valu_can_accept_vgpr(aco_ptr<Instruction> & instr,unsigned operand)654 valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
655 {
656 if (instr->opcode == aco_opcode::v_writelane_b32 ||
657 instr->opcode == aco_opcode::v_writelane_b32_e64)
658 return operand == 2;
659 if (instr->opcode == aco_opcode::v_permlane16_b32 ||
660 instr->opcode == aco_opcode::v_permlanex16_b32 ||
661 instr->opcode == aco_opcode::v_readlane_b32 ||
662 instr->opcode == aco_opcode::v_readlane_b32_e64)
663 return operand == 0;
664 return instr_info.classes[(int)instr->opcode] != instr_class::valu_pseudo_scalar_trans;
665 }
666
667 /* check constant bus and literal limitations */
668 bool
check_vop3_operands(opt_ctx & ctx,unsigned num_operands,Operand * operands)669 check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand* operands)
670 {
671 int limit = ctx.program->gfx_level >= GFX10 ? 2 : 1;
672 Operand literal32(s1);
673 Operand literal64(s2);
674 unsigned num_sgprs = 0;
675 unsigned sgpr[] = {0, 0};
676
677 for (unsigned i = 0; i < num_operands; i++) {
678 Operand op = operands[i];
679
680 if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
681 /* two reads of the same SGPR count as 1 to the limit */
682 if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
683 if (num_sgprs < 2)
684 sgpr[num_sgprs++] = op.tempId();
685 limit--;
686 if (limit < 0)
687 return false;
688 }
689 } else if (op.isLiteral()) {
690 if (ctx.program->gfx_level < GFX10)
691 return false;
692
693 if (!literal32.isUndefined() && literal32.constantValue() != op.constantValue())
694 return false;
695 if (!literal64.isUndefined() && literal64.constantValue() != op.constantValue())
696 return false;
697
698 /* Any number of 32-bit literals counts as only 1 to the limit. Same
699 * (but separately) for 64-bit literals. */
700 if (op.size() == 1 && literal32.isUndefined()) {
701 limit--;
702 literal32 = op;
703 } else if (op.size() == 2 && literal64.isUndefined()) {
704 limit--;
705 literal64 = op;
706 }
707
708 if (limit < 0)
709 return false;
710 }
711 }
712
713 return true;
714 }
715
716 bool
parse_base_offset(opt_ctx & ctx,Instruction * instr,unsigned op_index,Temp * base,uint32_t * offset,bool prevent_overflow)717 parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* base, uint32_t* offset,
718 bool prevent_overflow)
719 {
720 Operand op = instr->operands[op_index];
721
722 if (!op.isTemp())
723 return false;
724 Temp tmp = op.getTemp();
725 if (!ctx.info[tmp.id()].is_add_sub())
726 return false;
727
728 Instruction* add_instr = ctx.info[tmp.id()].instr;
729
730 unsigned mask = 0x3;
731 bool is_sub = false;
732 switch (add_instr->opcode) {
733 case aco_opcode::v_add_u32:
734 case aco_opcode::v_add_co_u32:
735 case aco_opcode::v_add_co_u32_e64:
736 case aco_opcode::s_add_i32:
737 case aco_opcode::s_add_u32: break;
738 case aco_opcode::v_sub_u32:
739 case aco_opcode::v_sub_i32:
740 case aco_opcode::v_sub_co_u32:
741 case aco_opcode::v_sub_co_u32_e64:
742 case aco_opcode::s_sub_u32:
743 case aco_opcode::s_sub_i32:
744 mask = 0x2;
745 is_sub = true;
746 break;
747 case aco_opcode::v_subrev_u32:
748 case aco_opcode::v_subrev_co_u32:
749 case aco_opcode::v_subrev_co_u32_e64:
750 mask = 0x1;
751 is_sub = true;
752 break;
753 default: return false;
754 }
755 if (prevent_overflow && !add_instr->definitions[0].isNUW())
756 return false;
757
758 if (add_instr->usesModifiers())
759 return false;
760
761 u_foreach_bit (i, mask) {
762 if (add_instr->operands[i].isConstant()) {
763 *offset = add_instr->operands[i].constantValue() * (uint32_t)(is_sub ? -1 : 1);
764 } else if (add_instr->operands[i].isTemp() &&
765 ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
766 *offset = ctx.info[add_instr->operands[i].tempId()].val * (uint32_t)(is_sub ? -1 : 1);
767 } else {
768 continue;
769 }
770 if (!add_instr->operands[!i].isTemp())
771 continue;
772
773 uint32_t offset2 = 0;
774 if (parse_base_offset(ctx, add_instr, !i, base, &offset2, prevent_overflow)) {
775 *offset += offset2;
776 } else {
777 *base = add_instr->operands[!i].getTemp();
778 }
779 return true;
780 }
781
782 return false;
783 }
784
785 void
skip_smem_offset_align(opt_ctx & ctx,SMEM_instruction * smem)786 skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem)
787 {
788 bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
789 if (soe && !smem->operands[1].isConstant())
790 return;
791 /* We don't need to check the constant offset because the address seems to be calculated with
792 * (offset&-4 + const_offset&-4), not (offset+const_offset)&-4.
793 */
794
795 Operand& op = smem->operands[soe ? smem->operands.size() - 1 : 1];
796 if (!op.isTemp() || !ctx.info[op.tempId()].is_bitwise())
797 return;
798
799 Instruction* bitwise_instr = ctx.info[op.tempId()].instr;
800 if (bitwise_instr->opcode != aco_opcode::s_and_b32)
801 return;
802
803 if (bitwise_instr->operands[0].constantEquals(-4) &&
804 bitwise_instr->operands[1].isOfType(op.regClass().type()))
805 op.setTemp(bitwise_instr->operands[1].getTemp());
806 else if (bitwise_instr->operands[1].constantEquals(-4) &&
807 bitwise_instr->operands[0].isOfType(op.regClass().type()))
808 op.setTemp(bitwise_instr->operands[0].getTemp());
809 }
810
811 void
smem_combine(opt_ctx & ctx,aco_ptr<Instruction> & instr)812 smem_combine(opt_ctx& ctx, aco_ptr<Instruction>& instr)
813 {
814 /* skip &-4 before offset additions: load((a + 16) & -4, 0) */
815 if (!instr->operands.empty())
816 skip_smem_offset_align(ctx, &instr->smem());
817
818 /* propagate constants and combine additions */
819 if (!instr->operands.empty() && instr->operands[1].isTemp()) {
820 SMEM_instruction& smem = instr->smem();
821 ssa_info info = ctx.info[instr->operands[1].tempId()];
822
823 Temp base;
824 uint32_t offset;
825 if (info.is_constant_or_literal(32) &&
826 ((ctx.program->gfx_level == GFX6 && info.val <= 0x3FF) ||
827 (ctx.program->gfx_level == GFX7 && info.val <= 0xFFFFFFFF) ||
828 (ctx.program->gfx_level >= GFX8 && info.val <= 0xFFFFF))) {
829 instr->operands[1] = Operand::c32(info.val);
830 } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) &&
831 base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->gfx_level >= GFX9 &&
832 offset % 4u == 0) {
833 bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4);
834 if (soe) {
835 if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) &&
836 ctx.info[smem.operands.back().tempId()].val == 0) {
837 smem.operands[1] = Operand::c32(offset);
838 smem.operands.back() = Operand(base);
839 }
840 } else {
841 Instruction* new_instr = create_instruction(
842 smem.opcode, Format::SMEM, smem.operands.size() + 1, smem.definitions.size());
843 new_instr->operands[0] = smem.operands[0];
844 new_instr->operands[1] = Operand::c32(offset);
845 if (smem.definitions.empty())
846 new_instr->operands[2] = smem.operands[2];
847 new_instr->operands.back() = Operand(base);
848 if (!smem.definitions.empty())
849 new_instr->definitions[0] = smem.definitions[0];
850 new_instr->smem().sync = smem.sync;
851 new_instr->smem().cache = smem.cache;
852 instr.reset(new_instr);
853 }
854 }
855 }
856
857 /* skip &-4 after offset additions: load(a & -4, 16) */
858 if (!instr->operands.empty())
859 skip_smem_offset_align(ctx, &instr->smem());
860 }
861
862 Operand
get_constant_op(opt_ctx & ctx,ssa_info info,uint32_t bits)863 get_constant_op(opt_ctx& ctx, ssa_info info, uint32_t bits)
864 {
865 if (bits == 64)
866 return Operand::c32_or_c64(info.val, true);
867 return Operand::get_const(ctx.program->gfx_level, info.val, bits / 8u);
868 }
869
/* Try to replace packed-math (VOP3P) operand `i` with an inline constant built
 * from the 32-bit constant recorded for its temporary in `info`.
 *
 * - If the whole operand-sized value is an inline constant, it is used as-is.
 * - Otherwise each 16-bit half actually read via opsel_lo/opsel_hi is encoded on
 *   its own as an inline constant. For 32-bit operands it is tried as-is, then
 *   with the upper 16 bits set to ones, then shifted into the high half (read
 *   back through opsel). If any needed half needs a literal, the operand is left
 *   untouched.
 * - The per-half constants are then merged into a single inline constant where
 *   possible, rewriting opsel_lo/opsel_hi (and neg_lo/neg_hi for the 16-bit
 *   "one half is the negation of the other" case).
 *
 * If no merge strategy applies, operand `i` stays the temporary and opsel_lo /
 * opsel_hi are written back unchanged.
 */
void
propagate_constants_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned i)
{
   if (!info.is_constant_or_literal(32))
      return;

   assert(instr->operands[i].isTemp());
   unsigned bits = get_operand_size(instr, i);
   /* Fast path: the full value already fits an inline constant of the operand size. */
   if (info.is_constant(bits)) {
      instr->operands[i] = get_constant_op(ctx, info, bits);
      return;
   }

   /* The accumulation operand of dot product instructions ignores opsel. */
   bool cannot_use_opsel =
      (instr->opcode == aco_opcode::v_dot4_i32_i8 || instr->opcode == aco_opcode::v_dot2_i32_i16 ||
       instr->opcode == aco_opcode::v_dot4_i32_iu8 || instr->opcode == aco_opcode::v_dot4_u32_u8 ||
       instr->opcode == aco_opcode::v_dot2_u32_u16) &&
      i == 2;
   if (cannot_use_opsel)
      return;

   /* try to fold inline constants */
   VALU_instruction* vop3p = &instr->valu();
   bool opsel_lo = vop3p->opsel_lo[i];
   bool opsel_hi = vop3p->opsel_hi[i];

   /* const_op[j] encodes half j (0 = low 16 bits, 1 = high 16 bits) of info.val;
    * const_opsel[j] is set when that half must be read from the constant's high half. */
   Operand const_op[2];
   bool const_opsel[2] = {false, false};
   for (unsigned j = 0; j < 2; j++) {
      if ((unsigned)opsel_lo != j && (unsigned)opsel_hi != j)
         continue; /* this half is unused */

      uint16_t val = info.val >> (j ? 16 : 0);
      Operand op = Operand::get_const(ctx.program->gfx_level, val, bits / 8u);
      if (bits == 32 && op.isLiteral()) /* try sign extension */
         op = Operand::get_const(ctx.program->gfx_level, val | 0xffff0000, 4);
      if (bits == 32 && op.isLiteral()) { /* try shifting left */
         op = Operand::get_const(ctx.program->gfx_level, val << 16, 4);
         const_opsel[j] = true;
      }
      /* A needed half cannot be an inline constant: give up without modifying instr. */
      if (op.isLiteral())
         return;
      const_op[j] = op;
   }

   Operand const_lo = const_op[0];
   Operand const_hi = const_op[1];
   bool const_lo_opsel = const_opsel[0];
   bool const_hi_opsel = const_opsel[1];

   if (opsel_lo == opsel_hi) {
      /* use the single 16bit value */
      instr->operands[i] = opsel_lo ? const_hi : const_lo;

      /* opsel must point the same for both halves */
      opsel_lo = opsel_lo ? const_hi_opsel : const_lo_opsel;
      opsel_hi = opsel_lo;
   } else if (const_lo == const_hi) {
      /* both constants are the same */
      instr->operands[i] = const_lo;

      /* opsel must point the same for both halves */
      opsel_lo = const_lo_opsel;
      opsel_hi = const_lo_opsel;
   } else if (const_lo.constantValue16(const_lo_opsel) ==
              const_hi.constantValue16(!const_hi_opsel)) {
      /* const_hi's unselected half holds the low value: use const_hi for both halves. */
      instr->operands[i] = const_hi;

      /* redirect opsel selection */
      opsel_lo = opsel_lo ? const_hi_opsel : !const_hi_opsel;
      opsel_hi = opsel_hi ? const_hi_opsel : !const_hi_opsel;
   } else if (const_hi.constantValue16(const_hi_opsel) ==
              const_lo.constantValue16(!const_lo_opsel)) {
      /* const_lo's unselected half holds the high value: use const_lo for both halves. */
      instr->operands[i] = const_lo;

      /* redirect opsel selection */
      opsel_lo = opsel_lo ? !const_lo_opsel : const_lo_opsel;
      opsel_hi = opsel_hi ? !const_lo_opsel : const_lo_opsel;
   } else if (bits == 16 && const_lo.constantValue() == (const_hi.constantValue() ^ (1 << 15))) {
      assert(const_lo_opsel == false && const_hi_opsel == false);

      /* const_lo == -const_hi */
      if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
         return;

      /* Encode the shared magnitude and toggle neg_lo/neg_hi so each half gets its
       * original sign back. */
      instr->operands[i] = Operand::c16(const_lo.constantValue() & 0x7FFF);
      bool neg_lo = const_lo.constantValue() & (1 << 15);
      vop3p->neg_lo[i] ^= opsel_lo ^ neg_lo;
      vop3p->neg_hi[i] ^= opsel_hi ^ neg_lo;

      /* opsel must point to lo for both operands */
      opsel_lo = false;
      opsel_hi = false;
   }

   vop3p->opsel_lo[i] = opsel_lo;
   vop3p->opsel_hi[i] = opsel_hi;
}
969
970 bool
fixed_to_exec(Operand op)971 fixed_to_exec(Operand op)
972 {
973 return op.isFixed() && op.physReg() == exec;
974 }
975
976 SubdwordSel
parse_extract(Instruction * instr)977 parse_extract(Instruction* instr)
978 {
979 if (instr->opcode == aco_opcode::p_extract) {
980 unsigned size = instr->operands[2].constantValue() / 8;
981 unsigned offset = instr->operands[1].constantValue() * size;
982 bool sext = instr->operands[3].constantEquals(1);
983 return SubdwordSel(size, offset, sext);
984 } else if (instr->opcode == aco_opcode::p_insert && instr->operands[1].constantEquals(0)) {
985 return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
986 } else if (instr->opcode == aco_opcode::p_extract_vector) {
987 unsigned size = instr->definitions[0].bytes();
988 unsigned offset = instr->operands[1].constantValue() * size;
989 if (size <= 2)
990 return SubdwordSel(size, offset, false);
991 } else if (instr->opcode == aco_opcode::p_split_vector) {
992 assert(instr->operands[0].bytes() == 4 && instr->definitions[1].bytes() == 2);
993 return SubdwordSel(2, 2, false);
994 }
995
996 return SubdwordSel();
997 }
998
999 SubdwordSel
parse_insert(Instruction * instr)1000 parse_insert(Instruction* instr)
1001 {
1002 if (instr->opcode == aco_opcode::p_extract && instr->operands[3].constantEquals(0) &&
1003 instr->operands[1].constantEquals(0)) {
1004 return instr->operands[2].constantEquals(8) ? SubdwordSel::ubyte : SubdwordSel::uword;
1005 } else if (instr->opcode == aco_opcode::p_insert) {
1006 unsigned size = instr->operands[2].constantValue() / 8;
1007 unsigned offset = instr->operands[1].constantValue() * size;
1008 return SubdwordSel(size, offset, false);
1009 } else {
1010 return SubdwordSel();
1011 }
1012 }
1013
1014 bool
can_apply_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr,unsigned idx,ssa_info & info)1015 can_apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
1016 {
1017 Temp tmp = info.instr->operands[0].getTemp();
1018 SubdwordSel sel = parse_extract(info.instr);
1019
1020 if (!sel) {
1021 return false;
1022 } else if (sel.size() == 4) {
1023 return true;
1024 } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
1025 instr->opcode == aco_opcode::v_cvt_f32_i32) &&
1026 sel.size() == 1 && !sel.sign_extend()) {
1027 return true;
1028 } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
1029 sel.offset() == 0 &&
1030 ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
1031 (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
1032 return true;
1033 } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
1034 !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
1035 (instr->operands[!idx].is16bit() ||
1036 (instr->operands[!idx].isConstant() &&
1037 instr->operands[!idx].constantValue() <= UINT16_MAX))) {
1038 return true;
1039 } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) &&
1040 (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
1041 if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword)
1042 return false;
1043 return true;
1044 } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] &&
1045 can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) {
1046 return true;
1047 } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16 && sel.size() == 2 &&
1048 (idx == 1 || ctx.program->gfx_level >= GFX11 || !sel.offset())) {
1049 return true;
1050 } else if (sel.size() == 2 &&
1051 ((instr->opcode == aco_opcode::s_pack_lh_b32_b16 && idx == 0) ||
1052 (instr->opcode == aco_opcode::s_pack_hl_b32_b16 && idx == 1))) {
1053 return true;
1054 } else if (instr->opcode == aco_opcode::p_extract) {
1055 SubdwordSel instrSel = parse_extract(instr.get());
1056
1057 /* the outer offset must be within extracted range */
1058 if (instrSel.offset() >= sel.size())
1059 return false;
1060
1061 /* don't remove the sign-extension when increasing the size further */
1062 if (instrSel.size() > sel.size() && !instrSel.sign_extend() && sel.sign_extend())
1063 return false;
1064
1065 return true;
1066 }
1067
1068 return false;
1069 }
1070
/* Combine a p_extract (or p_insert, in some cases) instruction with instr.
 * instr(p_extract(...)) -> instr()
 *
 * Folds the sub-dword selection of info.instr (the extract/insert defining the
 * temporary read by instr->operands[idx]) into `instr`. The handled forms mirror
 * can_apply_extract(); NOTE(review): presumably callers only get here after that
 * check succeeded. Confirm at the call sites.
 *
 * This function does not rewrite instr->operands[idx] to the extract's source
 * temporary; presumably the caller does that afterwards. `instr` itself may be
 * replaced (v_mul_u32_u24 -> v_mad_u32_u16).
 */
void
apply_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned idx, ssa_info& info)
{
   Temp tmp = info.instr->operands[0].getTemp();
   SubdwordSel sel = parse_extract(info.instr);
   assert(sel);

   /* Clear the 16-bit/24-bit hints on this operand; presumably they described the
    * narrowed extract result, not the full-width source. */
   instr->operands[idx].set16bit(false);
   instr->operands[idx].set24bit(false);

   /* Drop label_insert from the extract's source temporary. */
   ctx.info[tmp.id()].label &= ~label_insert;

   if (sel.size() == 4) {
      /* full dword selection */
   } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 ||
               instr->opcode == aco_opcode::v_cvt_f32_i32) &&
              sel.size() == 1 && !sel.sign_extend()) {
      /* Convert a zero-extended byte directly with the matching v_cvt_f32_ubyteN. */
      switch (sel.offset()) {
      case 0: instr->opcode = aco_opcode::v_cvt_f32_ubyte0; break;
      case 1: instr->opcode = aco_opcode::v_cvt_f32_ubyte1; break;
      case 2: instr->opcode = aco_opcode::v_cvt_f32_ubyte2; break;
      case 3: instr->opcode = aco_opcode::v_cvt_f32_ubyte3; break;
      }
   } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && instr->operands[0].isConstant() &&
              sel.offset() == 0 &&
              ((sel.size() == 2 && instr->operands[0].constantValue() >= 16u) ||
               (sel.size() == 1 && instr->operands[0].constantValue() >= 24u))) {
      /* The undesirable upper bits are already shifted out. */
      /* Returns before the definition-label filtering at the end of this function. */
      return;
   } else if (instr->opcode == aco_opcode::v_mul_u32_u24 && ctx.program->gfx_level >= GFX10 &&
              !instr->usesModifiers() && sel.size() == 2 && !sel.sign_extend() &&
              (instr->operands[!idx].is16bit() ||
               instr->operands[!idx].constantValue() <= UINT16_MAX)) {
      /* Rewrite as v_mad_u32_u16(a, b, 0), using opsel to pick the 16-bit half. */
      Instruction* mad = create_instruction(aco_opcode::v_mad_u32_u16, Format::VOP3, 3, 1);
      mad->definitions[0] = instr->definitions[0];
      mad->operands[0] = instr->operands[0];
      mad->operands[1] = instr->operands[1];
      mad->operands[2] = Operand::zero();
      mad->valu().opsel[idx] = sel.offset();
      mad->pass_flags = instr->pass_flags;
      instr.reset(mad);
   } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) &&
              (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) {
      /* Express the selection as an SDWA operand select. */
      convert_to_SDWA(ctx.program->gfx_level, instr);
      instr->sdwa().sel[idx] = sel;
   } else if (instr->isVALU()) {
      /* Select the high 16-bit half through opsel when the offset is nonzero. */
      if (sel.offset()) {
         instr->valu().opsel[idx] = true;

         /* VOP12C cannot use opsel with SGPRs. */
         if (!instr->isVOP3() && !instr->isVINTERP_INREG() &&
             !info.instr->operands[0].isOfType(RegType::vgpr))
            instr->format = asVOP3(instr->format);
      }
   } else if (instr->opcode == aco_opcode::s_pack_ll_b32_b16) {
      /* Reading the high half of operand idx: switch that operand's half to "h". */
      if (sel.offset())
         instr->opcode = idx ? aco_opcode::s_pack_lh_b32_b16 : aco_opcode::s_pack_hl_b32_b16;
   } else if (instr->opcode == aco_opcode::s_pack_lh_b32_b16 ||
              instr->opcode == aco_opcode::s_pack_hl_b32_b16) {
      if (sel.offset())
         instr->opcode = aco_opcode::s_pack_hh_b32_b16;
   } else if (instr->opcode == aco_opcode::p_extract) {
      /* Compose the two extracts into a single p_extract of the original source. */
      SubdwordSel instrSel = parse_extract(instr.get());

      unsigned size = std::min(sel.size(), instrSel.size());
      unsigned offset = sel.offset() + instrSel.offset();
      unsigned sign_extend =
         instrSel.sign_extend() && (sel.sign_extend() || instrSel.size() <= sel.size());

      instr->operands[1] = Operand::c32(offset / size);
      instr->operands[2] = Operand::c32(size * 8u);
      instr->operands[3] = Operand::c32(sign_extend);
      /* Returns before the definition-label filtering below. */
      return;
   }

   /* These are the only labels worth keeping at the moment. Labels that carry an
    * instruction pointer are repointed at the (possibly replaced) instruction. */
   for (Definition& def : instr->definitions) {
      ctx.info[def.tempId()].label &=
         (label_mul | label_minmax | label_usedef | label_vopc | label_f2f32 | instr_mod_labels);
      if (ctx.info[def.tempId()].label & instr_usedef_labels)
         ctx.info[def.tempId()].instr = instr.get();
   }
}
1157
1158 void
check_sdwa_extract(opt_ctx & ctx,aco_ptr<Instruction> & instr)1159 check_sdwa_extract(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1160 {
1161 for (unsigned i = 0; i < instr->operands.size(); i++) {
1162 Operand op = instr->operands[i];
1163 if (!op.isTemp())
1164 continue;
1165 ssa_info& info = ctx.info[op.tempId()];
1166 if (info.is_extract() && (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
1167 op.getTemp().type() == RegType::sgpr)) {
1168 if (!can_apply_extract(ctx, instr, i, info))
1169 info.label &= ~label_extract;
1170 }
1171 }
1172 }
1173
1174 bool
does_fp_op_flush_denorms(opt_ctx & ctx,aco_opcode op)1175 does_fp_op_flush_denorms(opt_ctx& ctx, aco_opcode op)
1176 {
1177 switch (op) {
1178 case aco_opcode::v_min_f32:
1179 case aco_opcode::v_max_f32:
1180 case aco_opcode::v_med3_f32:
1181 case aco_opcode::v_min3_f32:
1182 case aco_opcode::v_max3_f32:
1183 case aco_opcode::v_min_f16:
1184 case aco_opcode::v_max_f16: return ctx.program->gfx_level > GFX8;
1185 case aco_opcode::v_cndmask_b32:
1186 case aco_opcode::v_cndmask_b16:
1187 case aco_opcode::v_mov_b32:
1188 case aco_opcode::v_mov_b16: return false;
1189 default: return true;
1190 }
1191 }
1192
1193 bool
can_eliminate_fcanonicalize(opt_ctx & ctx,aco_ptr<Instruction> & instr,Temp tmp,unsigned idx)1194 can_eliminate_fcanonicalize(opt_ctx& ctx, aco_ptr<Instruction>& instr, Temp tmp, unsigned idx)
1195 {
1196 float_mode* fp = &ctx.fp_mode;
1197 if (ctx.info[tmp.id()].is_canonicalized() ||
1198 (tmp.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1199 return true;
1200
1201 aco_opcode op = instr->opcode;
1202 return can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, idx) &&
1203 does_fp_op_flush_denorms(ctx, op);
1204 }
1205
1206 bool
can_eliminate_and_exec(opt_ctx & ctx,Temp tmp,unsigned pass_flags)1207 can_eliminate_and_exec(opt_ctx& ctx, Temp tmp, unsigned pass_flags)
1208 {
1209 if (ctx.info[tmp.id()].is_vopc()) {
1210 Instruction* vopc_instr = ctx.info[tmp.id()].instr;
1211 /* Remove superfluous s_and when the VOPC instruction uses the same exec and thus
1212 * already produces the same result */
1213 return vopc_instr->pass_flags == pass_flags;
1214 }
1215 if (ctx.info[tmp.id()].is_bitwise()) {
1216 Instruction* instr = ctx.info[tmp.id()].instr;
1217 if (instr->operands.size() != 2 || instr->pass_flags != pass_flags)
1218 return false;
1219 if (!(instr->operands[0].isTemp() && instr->operands[1].isTemp()))
1220 return false;
1221 if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
1222 return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) ||
1223 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1224 } else {
1225 return can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), pass_flags) &&
1226 can_eliminate_and_exec(ctx, instr->operands[1].getTemp(), pass_flags);
1227 }
1228 }
1229 return false;
1230 }
1231
1232 bool
is_copy_label(opt_ctx & ctx,aco_ptr<Instruction> & instr,ssa_info & info,unsigned idx)1233 is_copy_label(opt_ctx& ctx, aco_ptr<Instruction>& instr, ssa_info& info, unsigned idx)
1234 {
1235 return info.is_temp() ||
1236 (info.is_fcanonicalize() && can_eliminate_fcanonicalize(ctx, instr, info.temp, idx));
1237 }
1238
1239 bool
is_op_canonicalized(opt_ctx & ctx,Operand op)1240 is_op_canonicalized(opt_ctx& ctx, Operand op)
1241 {
1242 float_mode* fp = &ctx.fp_mode;
1243 if ((op.isTemp() && ctx.info[op.tempId()].is_canonicalized()) ||
1244 (op.bytes() == 4 ? fp->denorm32 : fp->denorm16_64) == fp_denorm_keep)
1245 return true;
1246
1247 if (op.isConstant() || (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(32))) {
1248 uint32_t val = op.isTemp() ? ctx.info[op.tempId()].val : op.constantValue();
1249 if (op.bytes() == 2)
1250 return (val & 0x7fff) == 0 || (val & 0x7fff) > 0x3ff;
1251 else if (op.bytes() == 4)
1252 return (val & 0x7fffffff) == 0 || (val & 0x7fffffff) > 0x7fffff;
1253 }
1254 return false;
1255 }
1256
1257 bool
is_scratch_offset_valid(opt_ctx & ctx,Instruction * instr,int64_t offset0,int64_t offset1)1258 is_scratch_offset_valid(opt_ctx& ctx, Instruction* instr, int64_t offset0, int64_t offset1)
1259 {
1260 bool negative_unaligned_scratch_offset_bug = ctx.program->gfx_level == GFX10;
1261 int32_t min = ctx.program->dev.scratch_global_offset_min;
1262 int32_t max = ctx.program->dev.scratch_global_offset_max;
1263
1264 int64_t offset = offset0 + offset1;
1265
1266 bool has_vgpr_offset = instr && !instr->operands[0].isUndefined();
1267 if (negative_unaligned_scratch_offset_bug && has_vgpr_offset && offset < 0 && offset % 4)
1268 return false;
1269
1270 return offset >= min && offset <= max;
1271 }
1272
1273 bool
detect_clamp(Instruction * instr,unsigned * clamped_idx)1274 detect_clamp(Instruction* instr, unsigned* clamped_idx)
1275 {
1276 VALU_instruction& valu = instr->valu();
1277 if (valu.omod != 0 || valu.opsel != 0)
1278 return false;
1279
1280 unsigned idx = 0;
1281 bool found_zero = false, found_one = false;
1282 bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16;
1283 for (unsigned i = 0; i < 3; i++) {
1284 if (!valu.neg[i] && instr->operands[i].constantEquals(0))
1285 found_zero = true;
1286 else if (!valu.neg[i] &&
1287 instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */
1288 found_one = true;
1289 else
1290 idx = i;
1291 }
1292 if (found_zero && found_one && instr->operands[idx].isTemp()) {
1293 *clamped_idx = idx;
1294 return true;
1295 } else {
1296 return false;
1297 }
1298 }
1299
1300 void
label_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)1301 label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
1302 {
1303 if (instr->isSMEM())
1304 smem_combine(ctx, instr);
1305
1306 for (unsigned i = 0; i < instr->operands.size(); i++) {
1307 if (!instr->operands[i].isTemp())
1308 continue;
1309
1310 ssa_info info = ctx.info[instr->operands[i].tempId()];
1311 /* propagate reg->reg of same type */
1312 while (info.is_temp() && info.temp.regClass() == instr->operands[i].getTemp().regClass()) {
1313 instr->operands[i].setTemp(ctx.info[instr->operands[i].tempId()].temp);
1314 info = ctx.info[info.temp.id()];
1315 }
1316
1317 /* PSEUDO: propagate temporaries */
1318 if (instr->isPseudo()) {
1319 while (info.is_temp()) {
1320 pseudo_propagate_temp(ctx, instr, info.temp, i);
1321 info = ctx.info[info.temp.id()];
1322 }
1323 }
1324
1325 /* SALU / PSEUDO: propagate inline constants */
1326 if (instr->isSALU() || instr->isPseudo()) {
1327 unsigned bits = get_operand_size(instr, i);
1328 if ((info.is_constant(bits) || (info.is_literal(bits) && instr->isPseudo())) &&
1329 alu_can_accept_constant(instr, i)) {
1330 instr->operands[i] = get_constant_op(ctx, info, bits);
1331 continue;
1332 }
1333 }
1334
1335 /* VALU: propagate neg, abs & inline constants */
1336 else if (instr->isVALU()) {
1337 if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::vgpr &&
1338 valu_can_accept_vgpr(instr, i)) {
1339 instr->operands[i].setTemp(info.temp);
1340 info = ctx.info[info.temp.id()];
1341 }
1342 /* applying SGPRs to VOP1 doesn't increase code size and DCE is helped by doing it earlier */
1343 if (info.is_temp() && info.temp.type() == RegType::sgpr && can_apply_sgprs(ctx, instr) &&
1344 instr->operands.size() == 1) {
1345 instr->format = withoutDPP(instr->format);
1346 instr->operands[i].setTemp(info.temp);
1347 info = ctx.info[info.temp.id()];
1348 }
1349
1350 /* for instructions other than v_cndmask_b32, the size of the instruction should match the
1351 * operand size */
1352 bool can_use_mod =
1353 instr->opcode != aco_opcode::v_cndmask_b32 || instr->operands[i].getTemp().bytes() == 4;
1354 can_use_mod &= can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i);
1355
1356 bool packed_math = instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
1357 instr->opcode != aco_opcode::v_fma_mixlo_f16 &&
1358 instr->opcode != aco_opcode::v_fma_mixhi_f16;
1359
1360 if (instr->isSDWA())
1361 can_use_mod &= instr->sdwa().sel[i].size() == 4;
1362 else if (instr->isVOP3P())
1363 can_use_mod &= !packed_math || !info.is_abs();
1364 else if (instr->isVINTERP_INREG())
1365 can_use_mod &= !info.is_abs();
1366 else
1367 can_use_mod &= instr->isDPP16() || can_use_VOP3(ctx, instr);
1368
1369 unsigned bits = get_operand_size(instr, i);
1370 can_use_mod &= instr->operands[i].bytes() * 8 == bits;
1371
1372 if (info.is_neg() && can_use_mod &&
1373 can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1374 instr->operands[i].setTemp(info.temp);
1375 if (!packed_math && instr->valu().abs[i]) {
1376 /* fabs(fneg(a)) -> fabs(a) */
1377 } else if (instr->opcode == aco_opcode::v_add_f32) {
1378 instr->opcode = i ? aco_opcode::v_sub_f32 : aco_opcode::v_subrev_f32;
1379 } else if (instr->opcode == aco_opcode::v_add_f16) {
1380 instr->opcode = i ? aco_opcode::v_sub_f16 : aco_opcode::v_subrev_f16;
1381 } else if (packed_math) {
1382 /* Bit size compat should ensure this. */
1383 assert(!instr->valu().opsel_lo[i] && !instr->valu().opsel_hi[i]);
1384 instr->valu().neg_lo[i] ^= true;
1385 instr->valu().neg_hi[i] ^= true;
1386 } else {
1387 if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1388 instr->format = asVOP3(instr->format);
1389 instr->valu().neg[i] ^= true;
1390 }
1391 }
1392 if (info.is_abs() && can_use_mod &&
1393 can_eliminate_fcanonicalize(ctx, instr, info.temp, i)) {
1394 if (!instr->isDPP16() && can_use_VOP3(ctx, instr))
1395 instr->format = asVOP3(instr->format);
1396 instr->operands[i] = Operand(info.temp);
1397 instr->valu().abs[i] = true;
1398 continue;
1399 }
1400
1401 if (instr->isVOP3P()) {
1402 propagate_constants_vop3p(ctx, instr, info, i);
1403 continue;
1404 }
1405
1406 if (info.is_constant(bits) && alu_can_accept_constant(instr, i) &&
1407 (!instr->isSDWA() || ctx.program->gfx_level >= GFX9) && (!instr->isDPP() || i != 1)) {
1408 Operand op = get_constant_op(ctx, info, bits);
1409 if (i == 0 || instr->isSDWA() || instr->opcode == aco_opcode::v_readlane_b32 ||
1410 instr->opcode == aco_opcode::v_writelane_b32) {
1411 instr->format = withoutDPP(instr->format);
1412 instr->operands[i] = op;
1413 continue;
1414 } else if (!instr->isVOP3() && can_swap_operands(instr, &instr->opcode)) {
1415 instr->operands[i] = op;
1416 instr->valu().swapOperands(0, i);
1417 continue;
1418 } else if (can_use_VOP3(ctx, instr)) {
1419 instr->format = asVOP3(instr->format);
1420 instr->operands[i] = op;
1421 continue;
1422 }
1423 }
1424 }
1425
1426 /* MUBUF: propagate constants and combine additions */
1427 else if (instr->isMUBUF()) {
1428 MUBUF_instruction& mubuf = instr->mubuf();
1429 Temp base;
1430 uint32_t offset;
1431 while (info.is_temp())
1432 info = ctx.info[info.temp.id()];
1433
1434 bool swizzled = ctx.program->gfx_level >= GFX12 ? mubuf.cache.gfx12.swizzled
1435 : (mubuf.cache.value & ac_swizzled);
1436 /* According to AMDGPUDAGToDAGISel::SelectMUBUFScratchOffen(), vaddr
1437 * overflow for scratch accesses works only on GFX9+ and saddr overflow
1438 * never works. Since swizzling is the only thing that separates
1439 * scratch accesses and other accesses and swizzling changing how
1440 * addressing works significantly, this probably applies to swizzled
1441 * MUBUF accesses. */
1442 bool vaddr_prevent_overflow = swizzled && ctx.program->gfx_level < GFX9;
1443
1444 if (mubuf.offen && mubuf.idxen && i == 1 && info.is_vec() &&
1445 info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1446 info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1447 mubuf.offset + info.instr->operands[1].constantValue() < 4096) {
1448 instr->operands[1] = info.instr->operands[0];
1449 mubuf.offset += info.instr->operands[1].constantValue();
1450 mubuf.offen = false;
1451 continue;
1452 } else if (mubuf.offen && i == 1 && info.is_constant_or_literal(32) &&
1453 mubuf.offset + info.val < 4096) {
1454 assert(!mubuf.idxen);
1455 instr->operands[1] = Operand(v1);
1456 mubuf.offset += info.val;
1457 mubuf.offen = false;
1458 continue;
1459 } else if (i == 2 && info.is_constant_or_literal(32) && mubuf.offset + info.val < 4096) {
1460 instr->operands[2] = Operand::c32(0);
1461 mubuf.offset += info.val;
1462 continue;
1463 } else if (mubuf.offen && i == 1 &&
1464 parse_base_offset(ctx, instr.get(), i, &base, &offset,
1465 vaddr_prevent_overflow) &&
1466 base.regClass() == v1 && mubuf.offset + offset < 4096) {
1467 assert(!mubuf.idxen);
1468 instr->operands[1].setTemp(base);
1469 mubuf.offset += offset;
1470 continue;
1471 } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1472 base.regClass() == s1 && mubuf.offset + offset < 4096 && !swizzled) {
1473 instr->operands[i].setTemp(base);
1474 mubuf.offset += offset;
1475 continue;
1476 }
1477 }
1478
1479 else if (instr->isMTBUF()) {
1480 MTBUF_instruction& mtbuf = instr->mtbuf();
1481 while (info.is_temp())
1482 info = ctx.info[info.temp.id()];
1483
1484 if (mtbuf.offen && mtbuf.idxen && i == 1 && info.is_vec() &&
1485 info.instr->operands.size() == 2 && info.instr->operands[0].isTemp() &&
1486 info.instr->operands[0].regClass() == v1 && info.instr->operands[1].isConstant() &&
1487 mtbuf.offset + info.instr->operands[1].constantValue() < 4096) {
1488 instr->operands[1] = info.instr->operands[0];
1489 mtbuf.offset += info.instr->operands[1].constantValue();
1490 mtbuf.offen = false;
1491 continue;
1492 }
1493 }
1494
1495 /* SCRATCH: propagate constants and combine additions */
1496 else if (instr->isScratch()) {
1497 FLAT_instruction& scratch = instr->scratch();
1498 Temp base;
1499 uint32_t offset;
1500 while (info.is_temp())
1501 info = ctx.info[info.temp.id()];
1502
1503 /* The hardware probably does: 'scratch_base + u2u64(saddr) + i2i64(offset)'. This means
1504 * we can't combine the addition if the unsigned addition overflows and offset is
1505 * positive. In theory, there is also issues if
1506 * 'ilt(offset, 0) && ige(saddr, 0) && ilt(saddr + offset, 0)', but that just
1507 * replaces an already out-of-bounds access with a larger one since 'saddr + offset'
1508 * would be larger than INT32_MAX.
1509 */
1510 if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, true) &&
1511 base.regClass() == instr->operands[i].regClass() &&
1512 is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1513 instr->operands[i].setTemp(base);
1514 scratch.offset += (int32_t)offset;
1515 continue;
1516 } else if (i <= 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1517 base.regClass() == instr->operands[i].regClass() && (int32_t)offset < 0 &&
1518 is_scratch_offset_valid(ctx, instr.get(), scratch.offset, (int32_t)offset)) {
1519 instr->operands[i].setTemp(base);
1520 scratch.offset += (int32_t)offset;
1521 continue;
1522 } else if (i <= 1 && info.is_constant_or_literal(32) &&
1523 ctx.program->gfx_level >= GFX10_3 &&
1524 is_scratch_offset_valid(ctx, NULL, scratch.offset, (int32_t)info.val)) {
1525 /* GFX10.3+ can disable both SADDR and ADDR. */
1526 instr->operands[i] = Operand(instr->operands[i].regClass());
1527 scratch.offset += (int32_t)info.val;
1528 continue;
1529 }
1530 }
1531
1532 /* DS: combine additions */
1533 else if (instr->isDS()) {
1534
1535 DS_instruction& ds = instr->ds();
1536 Temp base;
1537 uint32_t offset;
1538 bool has_usable_ds_offset = ctx.program->gfx_level >= GFX7;
1539 if (has_usable_ds_offset && i == 0 &&
1540 parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
1541 base.regClass() == instr->operands[i].regClass() &&
1542 instr->opcode != aco_opcode::ds_swizzle_b32) {
1543 if (instr->opcode == aco_opcode::ds_write2_b32 ||
1544 instr->opcode == aco_opcode::ds_read2_b32 ||
1545 instr->opcode == aco_opcode::ds_write2_b64 ||
1546 instr->opcode == aco_opcode::ds_read2_b64 ||
1547 instr->opcode == aco_opcode::ds_write2st64_b32 ||
1548 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1549 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1550 instr->opcode == aco_opcode::ds_read2st64_b64) {
1551 bool is64bit = instr->opcode == aco_opcode::ds_write2_b64 ||
1552 instr->opcode == aco_opcode::ds_read2_b64 ||
1553 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1554 instr->opcode == aco_opcode::ds_read2st64_b64;
1555 bool st64 = instr->opcode == aco_opcode::ds_write2st64_b32 ||
1556 instr->opcode == aco_opcode::ds_read2st64_b32 ||
1557 instr->opcode == aco_opcode::ds_write2st64_b64 ||
1558 instr->opcode == aco_opcode::ds_read2st64_b64;
1559 unsigned shifts = (is64bit ? 3 : 2) + (st64 ? 6 : 0);
1560 unsigned mask = BITFIELD_MASK(shifts);
1561
1562 if ((offset & mask) == 0 && ds.offset0 + (offset >> shifts) <= 255 &&
1563 ds.offset1 + (offset >> shifts) <= 255) {
1564 instr->operands[i].setTemp(base);
1565 ds.offset0 += offset >> shifts;
1566 ds.offset1 += offset >> shifts;
1567 }
1568 } else {
1569 if (ds.offset0 + offset <= 65535) {
1570 instr->operands[i].setTemp(base);
1571 ds.offset0 += offset;
1572 }
1573 }
1574 }
1575 }
1576
1577 else if (instr->isBranch()) {
1578 if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
1579 /* Flip the branch instruction to get rid of the scc_invert instruction */
1580 instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz
1581 : aco_opcode::p_cbranch_z;
1582 instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
1583 }
1584 }
1585 }
1586
1587 /* if this instruction doesn't define anything, return */
1588 if (instr->definitions.empty()) {
1589 check_sdwa_extract(ctx, instr);
1590 return;
1591 }
1592
1593 if (instr->isVALU() || instr->isVINTRP()) {
1594 if (instr_info.can_use_output_modifiers[(int)instr->opcode] || instr->isVINTRP() ||
1595 instr->opcode == aco_opcode::v_cndmask_b32) {
1596 bool canonicalized = true;
1597 if (!does_fp_op_flush_denorms(ctx, instr->opcode)) {
1598 unsigned ops = instr->opcode == aco_opcode::v_cndmask_b32 ? 2 : instr->operands.size();
1599 for (unsigned i = 0; canonicalized && (i < ops); i++)
1600 canonicalized = is_op_canonicalized(ctx, instr->operands[i]);
1601 }
1602 if (canonicalized)
1603 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1604 }
1605
1606 if (instr->isVOPC()) {
1607 ctx.info[instr->definitions[0].tempId()].set_vopc(instr.get());
1608 check_sdwa_extract(ctx, instr);
1609 return;
1610 }
1611 if (instr->isVOP3P()) {
1612 ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
1613 return;
1614 }
1615 }
1616
1617 switch (instr->opcode) {
1618 case aco_opcode::p_create_vector: {
1619 bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
1620 instr->operands[0].regClass() == instr->definitions[0].regClass();
1621 if (copy_prop) {
1622 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1623 break;
1624 }
1625
1626 /* expand vector operands */
1627 std::vector<Operand> ops;
1628 unsigned offset = 0;
1629 for (const Operand& op : instr->operands) {
1630 /* ensure that any expanded operands are properly aligned */
1631 bool aligned = offset % 4 == 0 || op.bytes() < 4;
1632 offset += op.bytes();
1633 if (aligned && op.isTemp() && ctx.info[op.tempId()].is_vec()) {
1634 Instruction* vec = ctx.info[op.tempId()].instr;
1635 for (const Operand& vec_op : vec->operands)
1636 ops.emplace_back(vec_op);
1637 } else {
1638 ops.emplace_back(op);
1639 }
1640 }
1641
1642 /* combine expanded operands to new vector */
1643 if (ops.size() != instr->operands.size()) {
1644 assert(ops.size() > instr->operands.size());
1645 Definition def = instr->definitions[0];
1646 instr.reset(
1647 create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, ops.size(), 1));
1648 for (unsigned i = 0; i < ops.size(); i++) {
1649 if (ops[i].isTemp() && ctx.info[ops[i].tempId()].is_temp() &&
1650 ops[i].regClass() == ctx.info[ops[i].tempId()].temp.regClass())
1651 ops[i].setTemp(ctx.info[ops[i].tempId()].temp);
1652 instr->operands[i] = ops[i];
1653 }
1654 instr->definitions[0] = def;
1655 } else {
1656 for (unsigned i = 0; i < ops.size(); i++) {
1657 assert(instr->operands[i] == ops[i]);
1658 }
1659 }
1660 ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1661
1662 if (instr->operands.size() == 2) {
1663 /* check if this is created from split_vector */
1664 if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_split()) {
1665 Instruction* split = ctx.info[instr->operands[1].tempId()].instr;
1666 if (instr->operands[0].isTemp() &&
1667 instr->operands[0].getTemp() == split->definitions[0].getTemp())
1668 ctx.info[instr->definitions[0].tempId()].set_temp(split->operands[0].getTemp());
1669 }
1670 }
1671 break;
1672 }
1673 case aco_opcode::p_split_vector: {
1674 ssa_info& info = ctx.info[instr->operands[0].tempId()];
1675
1676 if (info.is_constant_or_literal(32)) {
1677 uint64_t val = info.val;
1678 for (Definition def : instr->definitions) {
1679 uint32_t mask = u_bit_consecutive(0, def.bytes() * 8u);
1680 ctx.info[def.tempId()].set_constant(ctx.program->gfx_level, val & mask);
1681 val >>= def.bytes() * 8u;
1682 }
1683 break;
1684 } else if (!info.is_vec()) {
1685 if (instr->definitions.size() == 2 && instr->operands[0].isTemp() &&
1686 instr->definitions[0].bytes() == instr->definitions[1].bytes()) {
1687 ctx.info[instr->definitions[1].tempId()].set_split(instr.get());
1688 if (instr->operands[0].bytes() == 4) {
1689 /* D16 subdword split */
1690 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1691 ctx.info[instr->definitions[1].tempId()].set_extract(instr.get());
1692 }
1693 }
1694 break;
1695 }
1696
1697 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1698 unsigned split_offset = 0;
1699 unsigned vec_offset = 0;
1700 unsigned vec_index = 0;
1701 for (unsigned i = 0; i < instr->definitions.size();
1702 split_offset += instr->definitions[i++].bytes()) {
1703 while (vec_offset < split_offset && vec_index < vec->operands.size())
1704 vec_offset += vec->operands[vec_index++].bytes();
1705
1706 if (vec_offset != split_offset ||
1707 vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
1708 continue;
1709
1710 Operand vec_op = vec->operands[vec_index];
1711 if (vec_op.isConstant()) {
1712 ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->gfx_level,
1713 vec_op.constantValue64());
1714 } else if (vec_op.isTemp()) {
1715 ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
1716 }
1717 }
1718 break;
1719 }
1720 case aco_opcode::p_extract_vector: { /* mov */
1721 const unsigned index = instr->operands[1].constantValue();
1722
1723 if (instr->operands[0].isTemp()) {
1724 ssa_info& info = ctx.info[instr->operands[0].tempId()];
1725 const unsigned dst_offset = index * instr->definitions[0].bytes();
1726
1727 if (info.is_vec()) {
1728 /* check if we index directly into a vector element */
1729 Instruction* vec = info.instr;
1730 unsigned offset = 0;
1731
1732 for (const Operand& op : vec->operands) {
1733 if (offset < dst_offset) {
1734 offset += op.bytes();
1735 continue;
1736 } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
1737 break;
1738 }
1739 instr->operands[0] = op;
1740 break;
1741 }
1742 } else if (info.is_constant_or_literal(32)) {
1743 /* propagate constants */
1744 uint32_t mask = u_bit_consecutive(0, instr->definitions[0].bytes() * 8u);
1745 uint32_t val = (info.val >> (dst_offset * 8u)) & mask;
1746 instr->operands[0] =
1747 Operand::get_const(ctx.program->gfx_level, val, instr->definitions[0].bytes());
1748 ;
1749 }
1750 }
1751
1752 if (instr->operands[0].bytes() != instr->definitions[0].bytes()) {
1753 if (instr->operands[0].size() != 1 || !instr->operands[0].isTemp())
1754 break;
1755
1756 if (index == 0)
1757 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1758 else
1759 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
1760 break;
1761 }
1762
1763 /* convert this extract into a copy instruction */
1764 instr->opcode = aco_opcode::p_parallelcopy;
1765 instr->operands.pop_back();
1766 FALLTHROUGH;
1767 }
1768 case aco_opcode::p_parallelcopy: /* propagate */
1769 if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_vec() &&
1770 instr->operands[0].regClass() != instr->definitions[0].regClass()) {
1771 /* We might not be able to copy-propagate if it's a SGPR->VGPR copy, so
1772 * duplicate the vector instead.
1773 */
1774 Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
1775 aco_ptr<Instruction> old_copy = std::move(instr);
1776
1777 instr.reset(create_instruction(aco_opcode::p_create_vector, Format::PSEUDO,
1778 vec->operands.size(), 1));
1779 instr->definitions[0] = old_copy->definitions[0];
1780 std::copy(vec->operands.begin(), vec->operands.end(), instr->operands.begin());
1781 for (unsigned i = 0; i < vec->operands.size(); i++) {
1782 Operand& op = instr->operands[i];
1783 if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
1784 ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
1785 op.setTemp(ctx.info[op.tempId()].temp);
1786 }
1787 ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
1788 break;
1789 }
1790 FALLTHROUGH;
1791 case aco_opcode::p_as_uniform:
1792 if (instr->definitions[0].isFixed()) {
1793 /* don't copy-propagate copies into fixed registers */
1794 } else if (instr->operands[0].isConstant()) {
1795 ctx.info[instr->definitions[0].tempId()].set_constant(
1796 ctx.program->gfx_level, instr->operands[0].constantValue64());
1797 } else if (instr->operands[0].isTemp()) {
1798 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1799 if (ctx.info[instr->operands[0].tempId()].is_canonicalized())
1800 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
1801 } else {
1802 assert(instr->operands[0].isFixed());
1803 }
1804 break;
1805 case aco_opcode::v_mov_b32:
1806 if (instr->isDPP16()) {
1807 /* anything else doesn't make sense in SSA */
1808 assert(instr->dpp16().row_mask == 0xf && instr->dpp16().bank_mask == 0xf);
1809 ctx.info[instr->definitions[0].tempId()].set_dpp16(instr.get());
1810 } else if (instr->isDPP8()) {
1811 ctx.info[instr->definitions[0].tempId()].set_dpp8(instr.get());
1812 }
1813 break;
1814 case aco_opcode::p_is_helper:
1815 if (!ctx.program->needs_wqm)
1816 ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1817 break;
1818 case aco_opcode::v_mul_f64_e64:
1819 case aco_opcode::v_mul_f64: ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); break;
1820 case aco_opcode::v_mul_f16:
1821 case aco_opcode::v_mul_f32:
1822 case aco_opcode::v_mul_legacy_f32: { /* omod */
1823 ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
1824
1825 /* TODO: try to move the negate/abs modifier to the consumer instead */
1826 bool uses_mods = instr->usesModifiers();
1827 bool fp16 = instr->opcode == aco_opcode::v_mul_f16;
1828
1829 for (unsigned i = 0; i < 2; i++) {
1830 if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
1831 if (!instr->isDPP() && !instr->isSDWA() && !instr->valu().opsel &&
1832 (instr->operands[!i].constantEquals(fp16 ? 0x3c00 : 0x3f800000) || /* 1.0 */
1833 instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u))) { /* -1.0 */
1834 bool neg1 = instr->operands[!i].constantEquals(fp16 ? 0xbc00 : 0xbf800000u);
1835
1836 VALU_instruction* valu = &instr->valu();
1837 if (valu->abs[!i] || valu->neg[!i] || valu->omod)
1838 continue;
1839
1840 bool abs = valu->abs[i];
1841 bool neg = neg1 ^ valu->neg[i];
1842 Temp other = instr->operands[i].getTemp();
1843
1844 if (valu->clamp) {
1845 if (!abs && !neg && other.type() == RegType::vgpr)
1846 ctx.info[other.id()].set_clamp(instr.get());
1847 continue;
1848 }
1849
1850 if (abs && neg && other.type() == RegType::vgpr)
1851 ctx.info[instr->definitions[0].tempId()].set_neg_abs(other);
1852 else if (abs && !neg && other.type() == RegType::vgpr)
1853 ctx.info[instr->definitions[0].tempId()].set_abs(other);
1854 else if (!abs && neg && other.type() == RegType::vgpr)
1855 ctx.info[instr->definitions[0].tempId()].set_neg(other);
1856 else if (!abs && !neg)
1857 ctx.info[instr->definitions[0].tempId()].set_fcanonicalize(other);
1858 } else if (uses_mods || (instr->definitions[0].isSZPreserve() &&
1859 instr->opcode != aco_opcode::v_mul_legacy_f32)) {
1860 continue; /* omod uses a legacy multiplication. */
1861 } else if (instr->operands[!i].constantValue() == 0u &&
1862 ((!instr->definitions[0].isNaNPreserve() &&
1863 !instr->definitions[0].isInfPreserve()) ||
1864 instr->opcode == aco_opcode::v_mul_legacy_f32)) { /* 0.0 */
1865 ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->gfx_level, 0u);
1866 } else if ((fp16 ? ctx.fp_mode.denorm16_64 : ctx.fp_mode.denorm32) != fp_denorm_flush) {
1867 /* omod has no effect if denormals are enabled. */
1868 continue;
1869 } else if (instr->operands[!i].constantValue() ==
1870 (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */
1871 ctx.info[instr->operands[i].tempId()].set_omod2(instr.get());
1872 } else if (instr->operands[!i].constantValue() ==
1873 (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */
1874 ctx.info[instr->operands[i].tempId()].set_omod4(instr.get());
1875 } else if (instr->operands[!i].constantValue() ==
1876 (fp16 ? 0x3800 : 0x3f000000)) { /* 0.5 */
1877 ctx.info[instr->operands[i].tempId()].set_omod5(instr.get());
1878 } else {
1879 continue;
1880 }
1881 break;
1882 }
1883 }
1884 break;
1885 }
1886 case aco_opcode::v_mul_lo_u16:
1887 case aco_opcode::v_mul_lo_u16_e64:
1888 case aco_opcode::v_mul_u32_u24:
1889 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1890 break;
1891 case aco_opcode::v_med3_f16:
1892 case aco_opcode::v_med3_f32: { /* clamp */
1893 unsigned idx;
1894 if (detect_clamp(instr.get(), &idx) && !instr->valu().abs && !instr->valu().neg)
1895 ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get());
1896 break;
1897 }
1898 case aco_opcode::v_cndmask_b32:
1899 if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(0x3f800000u))
1900 ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
1901 else if (instr->operands[0].constantEquals(0) && instr->operands[1].constantEquals(1))
1902 ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
1903
1904 break;
1905 case aco_opcode::v_add_u32:
1906 case aco_opcode::v_add_co_u32:
1907 case aco_opcode::v_add_co_u32_e64:
1908 case aco_opcode::s_add_i32:
1909 case aco_opcode::s_add_u32:
1910 case aco_opcode::v_subbrev_co_u32:
1911 case aco_opcode::v_sub_u32:
1912 case aco_opcode::v_sub_i32:
1913 case aco_opcode::v_sub_co_u32:
1914 case aco_opcode::v_sub_co_u32_e64:
1915 case aco_opcode::s_sub_u32:
1916 case aco_opcode::s_sub_i32:
1917 case aco_opcode::v_subrev_u32:
1918 case aco_opcode::v_subrev_co_u32:
1919 case aco_opcode::v_subrev_co_u32_e64:
1920 ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
1921 break;
1922 case aco_opcode::s_not_b32:
1923 case aco_opcode::s_not_b64:
1924 if (!instr->operands[0].isTemp()) {
1925 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1926 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1927 ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1928 ctx.info[instr->operands[0].tempId()].temp);
1929 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1930 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1931 ctx.info[instr->definitions[1].tempId()].set_scc_invert(
1932 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1933 }
1934 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1935 break;
1936 case aco_opcode::s_and_b32:
1937 case aco_opcode::s_and_b64:
1938 if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
1939 if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
1940 /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a
1941 * uniform bool into divergent */
1942 ctx.info[instr->definitions[1].tempId()].set_temp(
1943 ctx.info[instr->operands[0].tempId()].temp);
1944 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1945 ctx.info[instr->operands[0].tempId()].temp);
1946 break;
1947 } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
1948 /* Try to get rid of the superfluous s_and_b64, since the uniform bitwise instruction
1949 * already produces the same SCC */
1950 ctx.info[instr->definitions[1].tempId()].set_temp(
1951 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1952 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(
1953 ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
1954 break;
1955 } else if ((ctx.program->stage.num_sw_stages() > 1 ||
1956 ctx.program->stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER) &&
1957 instr->pass_flags == 1) {
1958 /* In case of merged shaders, pass_flags=1 means that all lanes are active (exec=-1), so
1959 * s_and is unnecessary. */
1960 ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
1961 break;
1962 }
1963 }
1964 FALLTHROUGH;
1965 case aco_opcode::s_or_b32:
1966 case aco_opcode::s_or_b64:
1967 case aco_opcode::s_xor_b32:
1968 case aco_opcode::s_xor_b64:
1969 if (std::all_of(instr->operands.begin(), instr->operands.end(),
1970 [&ctx](const Operand& op)
1971 {
1972 return op.isTemp() && (ctx.info[op.tempId()].is_uniform_bool() ||
1973 ctx.info[op.tempId()].is_uniform_bitwise());
1974 })) {
1975 ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
1976 }
1977 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
1978 break;
1979 case aco_opcode::s_lshl_b32:
1980 case aco_opcode::v_or_b32:
1981 case aco_opcode::v_lshlrev_b32:
1982 case aco_opcode::v_bcnt_u32_b32:
1983 case aco_opcode::v_and_b32:
1984 case aco_opcode::v_xor_b32:
1985 case aco_opcode::v_not_b32:
1986 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
1987 break;
1988 case aco_opcode::v_min_f32:
1989 case aco_opcode::v_min_f16:
1990 case aco_opcode::v_min_u32:
1991 case aco_opcode::v_min_i32:
1992 case aco_opcode::v_min_u16:
1993 case aco_opcode::v_min_i16:
1994 case aco_opcode::v_min_u16_e64:
1995 case aco_opcode::v_min_i16_e64:
1996 case aco_opcode::v_max_f32:
1997 case aco_opcode::v_max_f16:
1998 case aco_opcode::v_max_u32:
1999 case aco_opcode::v_max_i32:
2000 case aco_opcode::v_max_u16:
2001 case aco_opcode::v_max_i16:
2002 case aco_opcode::v_max_u16_e64:
2003 case aco_opcode::v_max_i16_e64:
2004 ctx.info[instr->definitions[0].tempId()].set_minmax(instr.get());
2005 break;
2006 case aco_opcode::s_cselect_b64:
2007 case aco_opcode::s_cselect_b32:
2008 if (instr->operands[0].constantEquals((unsigned)-1) && instr->operands[1].constantEquals(0)) {
2009 /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
2010 ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
2011 }
2012 if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
2013 /* Flip the operands to get rid of the scc_invert instruction */
2014 std::swap(instr->operands[0], instr->operands[1]);
2015 instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
2016 }
2017 break;
2018 case aco_opcode::s_mul_i32:
2019 /* Testing every uint32_t shows that 0x3f800000*n is never a denormal.
2020 * This pattern is created from a uniform nir_op_b2f. */
2021 if (instr->operands[0].constantEquals(0x3f800000u))
2022 ctx.info[instr->definitions[0].tempId()].set_canonicalized();
2023 break;
2024 case aco_opcode::p_extract: {
2025 if (instr->definitions[0].bytes() == 4 && instr->operands[0].isTemp()) {
2026 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2027 if (instr->operands[0].regClass() == v1 && parse_insert(instr.get()))
2028 ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2029 }
2030 break;
2031 }
2032 case aco_opcode::p_insert: {
2033 if (instr->operands[0].bytes() == 4 && instr->operands[0].isTemp()) {
2034 if (instr->operands[0].regClass() == v1)
2035 ctx.info[instr->operands[0].tempId()].set_insert(instr.get());
2036 if (parse_extract(instr.get()))
2037 ctx.info[instr->definitions[0].tempId()].set_extract(instr.get());
2038 ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
2039 }
2040 break;
2041 }
2042 case aco_opcode::ds_read_u8:
2043 case aco_opcode::ds_read_u8_d16:
2044 case aco_opcode::ds_read_u16:
2045 case aco_opcode::ds_read_u16_d16: {
2046 ctx.info[instr->definitions[0].tempId()].set_usedef(instr.get());
2047 break;
2048 }
2049 case aco_opcode::v_cvt_f16_f32: {
2050 if (instr->operands[0].isTemp()) {
2051 ssa_info& info = ctx.info[instr->operands[0].tempId()];
2052 if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
2053 info.set_f2f16(instr.get());
2054 }
2055 break;
2056 }
2057 case aco_opcode::v_cvt_f32_f16: {
2058 if (instr->operands[0].isTemp())
2059 ctx.info[instr->definitions[0].tempId()].set_f2f32(instr.get());
2060 break;
2061 }
2062 default: break;
2063 }
2064
2065 /* Don't remove label_extract if we can't apply the extract to
2066 * neg/abs instructions because we'll likely combine it into another valu. */
2067 if (!(ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)))
2068 check_sdwa_extract(ctx, instr);
2069 }
2070
2071 unsigned
original_temp_id(opt_ctx & ctx,Temp tmp)2072 original_temp_id(opt_ctx& ctx, Temp tmp)
2073 {
2074 if (ctx.info[tmp.id()].is_temp())
2075 return ctx.info[tmp.id()].temp.id();
2076 else
2077 return tmp.id();
2078 }
2079
2080 void
decrease_op_uses_if_dead(opt_ctx & ctx,Instruction * instr)2081 decrease_op_uses_if_dead(opt_ctx& ctx, Instruction* instr)
2082 {
2083 if (is_dead(ctx.uses, instr)) {
2084 for (const Operand& op : instr->operands) {
2085 if (op.isTemp())
2086 ctx.uses[op.tempId()]--;
2087 }
2088 }
2089 }
2090
2091 void
decrease_uses(opt_ctx & ctx,Instruction * instr)2092 decrease_uses(opt_ctx& ctx, Instruction* instr)
2093 {
2094 ctx.uses[instr->definitions[0].tempId()]--;
2095 decrease_op_uses_if_dead(ctx, instr);
2096 }
2097
2098 Operand
copy_operand(opt_ctx & ctx,Operand op)2099 copy_operand(opt_ctx& ctx, Operand op)
2100 {
2101 if (op.isTemp())
2102 ctx.uses[op.tempId()]++;
2103 return op;
2104 }
2105
2106 Instruction*
follow_operand(opt_ctx & ctx,Operand op,bool ignore_uses=false)2107 follow_operand(opt_ctx& ctx, Operand op, bool ignore_uses = false)
2108 {
2109 if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels))
2110 return nullptr;
2111 if (!ignore_uses && ctx.uses[op.tempId()] > 1)
2112 return nullptr;
2113
2114 Instruction* instr = ctx.info[op.tempId()].instr;
2115
2116 if (instr->definitions.size() == 2) {
2117 unsigned idx = ctx.info[op.tempId()].label & label_split ? 1 : 0;
2118 assert(instr->definitions[idx].isTemp() && instr->definitions[idx].tempId() == op.tempId());
2119 if (instr->definitions[!idx].isTemp() && ctx.uses[instr->definitions[!idx].tempId()])
2120 return nullptr;
2121 }
2122
2123 for (Operand& operand : instr->operands) {
2124 if (fixed_to_exec(operand))
2125 return nullptr;
2126 }
2127
2128 return instr;
2129 }
2130
2131 bool
is_operand_constant(opt_ctx & ctx,Operand op,unsigned bit_size,uint64_t * value)2132 is_operand_constant(opt_ctx& ctx, Operand op, unsigned bit_size, uint64_t* value)
2133 {
2134 if (op.isConstant()) {
2135 *value = op.constantValue64();
2136 return true;
2137 } else if (op.isTemp()) {
2138 unsigned id = original_temp_id(ctx, op.getTemp());
2139 if (!ctx.info[id].is_constant_or_literal(bit_size))
2140 return false;
2141 *value = get_constant_op(ctx, ctx.info[id], bit_size).constantValue64();
2142 return true;
2143 }
2144 return false;
2145 }
2146
2147 /* s_not(cmp(a, b)) -> get_vcmp_inverse(cmp)(a, b) */
2148 bool
combine_inverse_comparison(opt_ctx & ctx,aco_ptr<Instruction> & instr)2149 combine_inverse_comparison(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2150 {
2151 if (ctx.uses[instr->definitions[1].tempId()])
2152 return false;
2153 if (!instr->operands[0].isTemp() || ctx.uses[instr->operands[0].tempId()] != 1)
2154 return false;
2155
2156 Instruction* cmp = follow_operand(ctx, instr->operands[0]);
2157 if (!cmp)
2158 return false;
2159
2160 aco_opcode new_opcode = get_vcmp_inverse(cmp->opcode);
2161 if (new_opcode == aco_opcode::num_opcodes)
2162 return false;
2163
2164 /* Invert compare instruction and assign this instruction's definition */
2165 cmp->opcode = new_opcode;
2166 ctx.info[instr->definitions[0].tempId()] = ctx.info[cmp->definitions[0].tempId()];
2167 std::swap(instr->definitions[0], cmp->definitions[0]);
2168
2169 ctx.uses[instr->operands[0].tempId()]--;
2170 return true;
2171 }
2172
2173 /* op1(op2(1, 2), 0) if swap = false
2174 * op1(0, op2(1, 2)) if swap = true */
2175 bool
match_op3_for_vop3(opt_ctx & ctx,aco_opcode op1,aco_opcode op2,Instruction * op1_instr,bool swap,const char * shuffle_str,Operand operands[3],bitarray8 & neg,bitarray8 & abs,bitarray8 & opsel,bool * op1_clamp,uint8_t * op1_omod,bool * inbetween_neg,bool * inbetween_abs,bool * inbetween_opsel,bool * precise)2176 match_op3_for_vop3(opt_ctx& ctx, aco_opcode op1, aco_opcode op2, Instruction* op1_instr, bool swap,
2177 const char* shuffle_str, Operand operands[3], bitarray8& neg, bitarray8& abs,
2178 bitarray8& opsel, bool* op1_clamp, uint8_t* op1_omod, bool* inbetween_neg,
2179 bool* inbetween_abs, bool* inbetween_opsel, bool* precise)
2180 {
2181 /* checks */
2182 if (op1_instr->opcode != op1)
2183 return false;
2184
2185 Instruction* op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
2186 if (!op2_instr || op2_instr->opcode != op2)
2187 return false;
2188
2189 VALU_instruction* op1_valu = op1_instr->isVALU() ? &op1_instr->valu() : NULL;
2190 VALU_instruction* op2_valu = op2_instr->isVALU() ? &op2_instr->valu() : NULL;
2191
2192 if (op1_instr->isSDWA() || op2_instr->isSDWA())
2193 return false;
2194 if (op1_instr->isDPP() || op2_instr->isDPP())
2195 return false;
2196
2197 /* don't support inbetween clamp/omod */
2198 if (op2_valu && (op2_valu->clamp || op2_valu->omod))
2199 return false;
2200
2201 /* get operands and modifiers and check inbetween modifiers */
2202 *op1_clamp = op1_valu ? (bool)op1_valu->clamp : false;
2203 *op1_omod = op1_valu ? (unsigned)op1_valu->omod : 0u;
2204
2205 if (inbetween_neg)
2206 *inbetween_neg = op1_valu ? op1_valu->neg[swap] : false;
2207 else if (op1_valu && op1_valu->neg[swap])
2208 return false;
2209
2210 if (inbetween_abs)
2211 *inbetween_abs = op1_valu ? op1_valu->abs[swap] : false;
2212 else if (op1_valu && op1_valu->abs[swap])
2213 return false;
2214
2215 if (inbetween_opsel)
2216 *inbetween_opsel = op1_valu ? op1_valu->opsel[swap] : false;
2217 else if (op1_valu && op1_valu->opsel[swap])
2218 return false;
2219
2220 *precise = op1_instr->definitions[0].isPrecise() || op2_instr->definitions[0].isPrecise();
2221
2222 int shuffle[3];
2223 shuffle[shuffle_str[0] - '0'] = 0;
2224 shuffle[shuffle_str[1] - '0'] = 1;
2225 shuffle[shuffle_str[2] - '0'] = 2;
2226
2227 operands[shuffle[0]] = op1_instr->operands[!swap];
2228 neg[shuffle[0]] = op1_valu ? op1_valu->neg[!swap] : false;
2229 abs[shuffle[0]] = op1_valu ? op1_valu->abs[!swap] : false;
2230 opsel[shuffle[0]] = op1_valu ? op1_valu->opsel[!swap] : false;
2231
2232 for (unsigned i = 0; i < 2; i++) {
2233 operands[shuffle[i + 1]] = op2_instr->operands[i];
2234 neg[shuffle[i + 1]] = op2_valu ? op2_valu->neg[i] : false;
2235 abs[shuffle[i + 1]] = op2_valu ? op2_valu->abs[i] : false;
2236 opsel[shuffle[i + 1]] = op2_valu ? op2_valu->opsel[i] : false;
2237 }
2238
2239 /* check operands */
2240 if (!check_vop3_operands(ctx, 3, operands))
2241 return false;
2242
2243 return true;
2244 }
2245
2246 void
create_vop3_for_op3(opt_ctx & ctx,aco_opcode opcode,aco_ptr<Instruction> & instr,Operand operands[3],uint8_t neg,uint8_t abs,uint8_t opsel,bool clamp,unsigned omod)2247 create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>& instr,
2248 Operand operands[3], uint8_t neg, uint8_t abs, uint8_t opsel, bool clamp,
2249 unsigned omod)
2250 {
2251 Instruction* new_instr = create_instruction(opcode, Format::VOP3, 3, 1);
2252 new_instr->valu().neg = neg;
2253 new_instr->valu().abs = abs;
2254 new_instr->valu().clamp = clamp;
2255 new_instr->valu().omod = omod;
2256 new_instr->valu().opsel = opsel;
2257 new_instr->operands[0] = operands[0];
2258 new_instr->operands[1] = operands[1];
2259 new_instr->operands[2] = operands[2];
2260 new_instr->definitions[0] = instr->definitions[0];
2261 new_instr->pass_flags = instr->pass_flags;
2262 ctx.info[instr->definitions[0].tempId()].label = 0;
2263
2264 instr.reset(new_instr);
2265 }
2266
2267 bool
combine_three_valu_op(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode op2,aco_opcode new_op,const char * shuffle,uint8_t ops)2268 combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op,
2269 const char* shuffle, uint8_t ops)
2270 {
2271 for (unsigned swap = 0; swap < 2; swap++) {
2272 if (!((1 << swap) & ops))
2273 continue;
2274
2275 Operand operands[3];
2276 bool clamp, precise;
2277 bitarray8 neg = 0, abs = 0, opsel = 0;
2278 uint8_t omod = 0;
2279 if (match_op3_for_vop3(ctx, instr->opcode, op2, instr.get(), swap, shuffle, operands, neg,
2280 abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2281 ctx.uses[instr->operands[swap].tempId()]--;
2282 create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
2283 return true;
2284 }
2285 }
2286 return false;
2287 }
2288
2289 /* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */
2290 bool
combine_add_or_then_and_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr)2291 combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2292 {
2293 bool is_or = instr->opcode == aco_opcode::v_or_b32;
2294 aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32;
2295
2296 if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32,
2297 "120", 1 | 2))
2298 return true;
2299 if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32,
2300 "120", 1 | 2))
2301 return true;
2302 if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2))
2303 return true;
2304 if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2))
2305 return true;
2306
2307 if (instr->isSDWA() || instr->isDPP())
2308 return false;
2309
2310 /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2311 * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b)
2312 * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b)
2313 * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b)
2314 */
2315 for (unsigned i = 0; i < 2; i++) {
2316 Instruction* extins = follow_operand(ctx, instr->operands[i]);
2317 if (!extins)
2318 continue;
2319
2320 aco_opcode op;
2321 Operand operands[3];
2322
2323 if (extins->opcode == aco_opcode::p_insert &&
2324 (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) {
2325 op = new_op_lshl;
2326 operands[1] =
2327 Operand::c32(extins->operands[1].constantValue() * extins->operands[2].constantValue());
2328 } else if (is_or &&
2329 (extins->opcode == aco_opcode::p_insert ||
2330 (extins->opcode == aco_opcode::p_extract &&
2331 extins->operands[3].constantEquals(0))) &&
2332 extins->operands[1].constantEquals(0)) {
2333 op = aco_opcode::v_and_or_b32;
2334 operands[1] = Operand::c32(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu);
2335 } else {
2336 continue;
2337 }
2338
2339 operands[0] = extins->operands[0];
2340 operands[2] = instr->operands[!i];
2341
2342 if (!check_vop3_operands(ctx, 3, operands))
2343 continue;
2344
2345 uint8_t neg = 0, abs = 0, opsel = 0, omod = 0;
2346 bool clamp = false;
2347 if (instr->isVOP3())
2348 clamp = instr->valu().clamp;
2349
2350 ctx.uses[instr->operands[i].tempId()]--;
2351 create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod);
2352 return true;
2353 }
2354
2355 return false;
2356 }
2357
2358 /* v_xor(a, s_not(b)) -> v_xnor(a, b)
2359 * v_xor(a, v_not(b)) -> v_xnor(a, b)
2360 */
2361 bool
combine_xor_not(opt_ctx & ctx,aco_ptr<Instruction> & instr)2362 combine_xor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2363 {
2364 if (instr->usesModifiers())
2365 return false;
2366
2367 for (unsigned i = 0; i < 2; i++) {
2368 Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
2369 if (!op_instr ||
2370 (op_instr->opcode != aco_opcode::v_not_b32 &&
2371 op_instr->opcode != aco_opcode::s_not_b32) ||
2372 op_instr->usesModifiers() || op_instr->operands[0].isLiteral())
2373 continue;
2374
2375 instr->opcode = aco_opcode::v_xnor_b32;
2376 instr->operands[i] = copy_operand(ctx, op_instr->operands[0]);
2377 decrease_uses(ctx, op_instr);
2378 if (instr->operands[0].isOfType(RegType::vgpr))
2379 std::swap(instr->operands[0], instr->operands[1]);
2380 if (!instr->operands[1].isOfType(RegType::vgpr))
2381 instr->format = asVOP3(instr->format);
2382
2383 return true;
2384 }
2385
2386 return false;
2387 }
2388
2389 /* v_not(v_xor(a, b)) -> v_xnor(a, b) */
2390 bool
combine_not_xor(opt_ctx & ctx,aco_ptr<Instruction> & instr)2391 combine_not_xor(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2392 {
2393 if (instr->usesModifiers())
2394 return false;
2395
2396 Instruction* op_instr = follow_operand(ctx, instr->operands[0]);
2397 if (!op_instr || op_instr->opcode != aco_opcode::v_xor_b32 || op_instr->isSDWA())
2398 return false;
2399
2400 ctx.uses[instr->operands[0].tempId()]--;
2401 std::swap(instr->definitions[0], op_instr->definitions[0]);
2402 op_instr->opcode = aco_opcode::v_xnor_b32;
2403 ctx.info[op_instr->definitions[0].tempId()].label = 0;
2404
2405 return true;
2406 }
2407
2408 bool
combine_minmax(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode opposite,aco_opcode op3src,aco_opcode minmax)2409 combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode op3src,
2410 aco_opcode minmax)
2411 {
2412 /* TODO: this can handle SDWA min/max instructions by using opsel */
2413
2414 /* min(min(a, b), c) -> min3(a, b, c)
2415 * max(max(a, b), c) -> max3(a, b, c)
2416 * gfx11: min(-min(a, b), c) -> maxmin(-a, -b, c)
2417 * gfx11: max(-max(a, b), c) -> minmax(-a, -b, c)
2418 */
2419 for (unsigned swap = 0; swap < 2; swap++) {
2420 Operand operands[3];
2421 bool clamp, precise;
2422 bitarray8 opsel = 0, neg = 0, abs = 0;
2423 uint8_t omod = 0;
2424 bool inbetween_neg;
2425 if (match_op3_for_vop3(ctx, instr->opcode, instr->opcode, instr.get(), swap, "120", operands,
2426 neg, abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL,
2427 &precise) &&
2428 (!inbetween_neg ||
2429 (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2430 ctx.uses[instr->operands[swap].tempId()]--;
2431 if (inbetween_neg) {
2432 neg[0] = !neg[0];
2433 neg[1] = !neg[1];
2434 create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2435 } else {
2436 create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2437 }
2438 return true;
2439 }
2440 }
2441
2442 /* min(-max(a, b), c) -> min3(-a, -b, c)
2443 * max(-min(a, b), c) -> max3(-a, -b, c)
2444 * gfx11: min(max(a, b), c) -> maxmin(a, b, c)
2445 * gfx11: max(min(a, b), c) -> minmax(a, b, c)
2446 */
2447 for (unsigned swap = 0; swap < 2; swap++) {
2448 Operand operands[3];
2449 bool clamp, precise;
2450 bitarray8 opsel = 0, neg = 0, abs = 0;
2451 uint8_t omod = 0;
2452 bool inbetween_neg;
2453 if (match_op3_for_vop3(ctx, instr->opcode, opposite, instr.get(), swap, "120", operands, neg,
2454 abs, opsel, &clamp, &omod, &inbetween_neg, NULL, NULL, &precise) &&
2455 (inbetween_neg ||
2456 (minmax != aco_opcode::num_opcodes && ctx.program->gfx_level >= GFX11))) {
2457 ctx.uses[instr->operands[swap].tempId()]--;
2458 if (inbetween_neg) {
2459 neg[0] = !neg[0];
2460 neg[1] = !neg[1];
2461 create_vop3_for_op3(ctx, op3src, instr, operands, neg, abs, opsel, clamp, omod);
2462 } else {
2463 create_vop3_for_op3(ctx, minmax, instr, operands, neg, abs, opsel, clamp, omod);
2464 }
2465 return true;
2466 }
2467 }
2468 return false;
2469 }
2470
2471 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
2472 * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
2473 * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
2474 * s_not_b64(s_and_b64(a, b)) -> s_nand_b64(a, b)
2475 * s_not_b64(s_or_b64(a, b)) -> s_nor_b64(a, b)
2476 * s_not_b64(s_xor_b64(a, b)) -> s_xnor_b64(a, b) */
2477 bool
combine_salu_not_bitwise(opt_ctx & ctx,aco_ptr<Instruction> & instr)2478 combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2479 {
2480 /* checks */
2481 if (!instr->operands[0].isTemp())
2482 return false;
2483 if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
2484 return false;
2485
2486 Instruction* op2_instr = follow_operand(ctx, instr->operands[0]);
2487 if (!op2_instr)
2488 return false;
2489 switch (op2_instr->opcode) {
2490 case aco_opcode::s_and_b32:
2491 case aco_opcode::s_or_b32:
2492 case aco_opcode::s_xor_b32:
2493 case aco_opcode::s_and_b64:
2494 case aco_opcode::s_or_b64:
2495 case aco_opcode::s_xor_b64: break;
2496 default: return false;
2497 }
2498
2499 /* create instruction */
2500 std::swap(instr->definitions[0], op2_instr->definitions[0]);
2501 std::swap(instr->definitions[1], op2_instr->definitions[1]);
2502 ctx.uses[instr->operands[0].tempId()]--;
2503 ctx.info[op2_instr->definitions[0].tempId()].label = 0;
2504
2505 switch (op2_instr->opcode) {
2506 case aco_opcode::s_and_b32: op2_instr->opcode = aco_opcode::s_nand_b32; break;
2507 case aco_opcode::s_or_b32: op2_instr->opcode = aco_opcode::s_nor_b32; break;
2508 case aco_opcode::s_xor_b32: op2_instr->opcode = aco_opcode::s_xnor_b32; break;
2509 case aco_opcode::s_and_b64: op2_instr->opcode = aco_opcode::s_nand_b64; break;
2510 case aco_opcode::s_or_b64: op2_instr->opcode = aco_opcode::s_nor_b64; break;
2511 case aco_opcode::s_xor_b64: op2_instr->opcode = aco_opcode::s_xnor_b64; break;
2512 default: break;
2513 }
2514
2515 return true;
2516 }
2517
2518 /* s_and_b32(a, s_not_b32(b)) -> s_andn2_b32(a, b)
2519 * s_or_b32(a, s_not_b32(b)) -> s_orn2_b32(a, b)
2520 * s_and_b64(a, s_not_b64(b)) -> s_andn2_b64(a, b)
2521 * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
2522 bool
combine_salu_n2(opt_ctx & ctx,aco_ptr<Instruction> & instr)2523 combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2524 {
2525 if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
2526 return false;
2527
2528 for (unsigned i = 0; i < 2; i++) {
2529 Instruction* op2_instr = follow_operand(ctx, instr->operands[i]);
2530 if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 &&
2531 op2_instr->opcode != aco_opcode::s_not_b64))
2532 continue;
2533 if (ctx.uses[op2_instr->definitions[1].tempId()])
2534 continue;
2535
2536 if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2537 instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2538 continue;
2539
2540 ctx.uses[instr->operands[i].tempId()]--;
2541 instr->operands[0] = instr->operands[!i];
2542 instr->operands[1] = op2_instr->operands[0];
2543 ctx.info[instr->definitions[0].tempId()].label = 0;
2544
2545 switch (instr->opcode) {
2546 case aco_opcode::s_and_b32: instr->opcode = aco_opcode::s_andn2_b32; break;
2547 case aco_opcode::s_or_b32: instr->opcode = aco_opcode::s_orn2_b32; break;
2548 case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_andn2_b64; break;
2549 case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_orn2_b64; break;
2550 default: break;
2551 }
2552
2553 return true;
2554 }
2555 return false;
2556 }
2557
2558 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
2559 bool
combine_salu_lshl_add(opt_ctx & ctx,aco_ptr<Instruction> & instr)2560 combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2561 {
2562 if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
2563 return false;
2564
2565 for (unsigned i = 0; i < 2; i++) {
2566 Instruction* op2_instr = follow_operand(ctx, instr->operands[i], true);
2567 if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
2568 ctx.uses[op2_instr->definitions[1].tempId()])
2569 continue;
2570 if (!op2_instr->operands[1].isConstant())
2571 continue;
2572
2573 uint32_t shift = op2_instr->operands[1].constantValue();
2574 if (shift < 1 || shift > 4)
2575 continue;
2576
2577 if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
2578 instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
2579 continue;
2580
2581 instr->operands[1] = instr->operands[!i];
2582 instr->operands[0] = copy_operand(ctx, op2_instr->operands[0]);
2583 decrease_uses(ctx, op2_instr);
2584 ctx.info[instr->definitions[0].tempId()].label = 0;
2585
2586 instr->opcode = std::array<aco_opcode, 4>{
2587 aco_opcode::s_lshl1_add_u32, aco_opcode::s_lshl2_add_u32, aco_opcode::s_lshl3_add_u32,
2588 aco_opcode::s_lshl4_add_u32}[shift - 1];
2589
2590 return true;
2591 }
2592 return false;
2593 }
2594
2595 /* s_abs_i32(s_sub_[iu]32(a, b)) -> s_absdiff_i32(a, b)
2596 * s_abs_i32(s_add_[iu]32(a, #b)) -> s_absdiff_i32(a, -b)
2597 */
2598 bool
combine_sabsdiff(opt_ctx & ctx,aco_ptr<Instruction> & instr)2599 combine_sabsdiff(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2600 {
2601 if (!instr->operands[0].isTemp() || !ctx.info[instr->operands[0].tempId()].is_add_sub())
2602 return false;
2603
2604 Instruction* op_instr = follow_operand(ctx, instr->operands[0], false);
2605 if (!op_instr)
2606 return false;
2607
2608 if (op_instr->opcode == aco_opcode::s_add_i32 || op_instr->opcode == aco_opcode::s_add_u32) {
2609 for (unsigned i = 0; i < 2; i++) {
2610 uint64_t constant;
2611 if (op_instr->operands[!i].isLiteral() ||
2612 !is_operand_constant(ctx, op_instr->operands[i], 32, &constant))
2613 continue;
2614
2615 if (op_instr->operands[i].isTemp())
2616 ctx.uses[op_instr->operands[i].tempId()]--;
2617 op_instr->operands[0] = op_instr->operands[!i];
2618 op_instr->operands[1] = Operand::c32(-int32_t(constant));
2619 goto use_absdiff;
2620 }
2621 return false;
2622 }
2623
2624 use_absdiff:
2625 op_instr->opcode = aco_opcode::s_absdiff_i32;
2626 std::swap(instr->definitions[0], op_instr->definitions[0]);
2627 std::swap(instr->definitions[1], op_instr->definitions[1]);
2628 ctx.uses[instr->operands[0].tempId()]--;
2629 ctx.info[op_instr->definitions[0].tempId()].label = 0;
2630
2631 return true;
2632 }
2633
2634 bool
combine_add_sub_b2i(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode new_op,uint8_t ops)2635 combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
2636 {
2637 if (instr->usesModifiers())
2638 return false;
2639
2640 for (unsigned i = 0; i < 2; i++) {
2641 if (!((1 << i) & ops))
2642 continue;
2643 if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2i() &&
2644 ctx.uses[instr->operands[i].tempId()] == 1) {
2645
2646 aco_ptr<Instruction> new_instr;
2647 if (instr->operands[!i].isTemp() &&
2648 instr->operands[!i].getTemp().type() == RegType::vgpr) {
2649 new_instr.reset(create_instruction(new_op, Format::VOP2, 3, 2));
2650 } else if (ctx.program->gfx_level >= GFX10 ||
2651 (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
2652 new_instr.reset(create_instruction(new_op, asVOP3(Format::VOP2), 3, 2));
2653 } else {
2654 return false;
2655 }
2656 ctx.uses[instr->operands[i].tempId()]--;
2657 new_instr->definitions[0] = instr->definitions[0];
2658 if (instr->definitions.size() == 2) {
2659 new_instr->definitions[1] = instr->definitions[1];
2660 } else {
2661 new_instr->definitions[1] =
2662 Definition(ctx.program->allocateTmp(ctx.program->lane_mask));
2663 /* Make sure the uses vector is large enough and the number of
2664 * uses properly initialized to 0.
2665 */
2666 ctx.uses.push_back(0);
2667 ctx.info.push_back(ssa_info{});
2668 }
2669 new_instr->operands[0] = Operand::zero();
2670 new_instr->operands[1] = instr->operands[!i];
2671 new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
2672 new_instr->pass_flags = instr->pass_flags;
2673 instr = std::move(new_instr);
2674 ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
2675 return true;
2676 }
2677 }
2678
2679 return false;
2680 }
2681
2682 bool
combine_add_bcnt(opt_ctx & ctx,aco_ptr<Instruction> & instr)2683 combine_add_bcnt(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2684 {
2685 if (instr->usesModifiers())
2686 return false;
2687
2688 for (unsigned i = 0; i < 2; i++) {
2689 Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
2690 if (op_instr && op_instr->opcode == aco_opcode::v_bcnt_u32_b32 &&
2691 !op_instr->usesModifiers() && op_instr->operands[0].isTemp() &&
2692 op_instr->operands[0].getTemp().type() == RegType::vgpr &&
2693 op_instr->operands[1].constantEquals(0)) {
2694 aco_ptr<Instruction> new_instr{
2695 create_instruction(aco_opcode::v_bcnt_u32_b32, Format::VOP3, 2, 1)};
2696 ctx.uses[instr->operands[i].tempId()]--;
2697 new_instr->operands[0] = op_instr->operands[0];
2698 new_instr->operands[1] = instr->operands[!i];
2699 new_instr->definitions[0] = instr->definitions[0];
2700 new_instr->pass_flags = instr->pass_flags;
2701 instr = std::move(new_instr);
2702 ctx.info[instr->definitions[0].tempId()].label = 0;
2703
2704 return true;
2705 }
2706 }
2707
2708 return false;
2709 }
2710
2711 bool
get_minmax_info(aco_opcode op,aco_opcode * min,aco_opcode * max,aco_opcode * min3,aco_opcode * max3,aco_opcode * med3,aco_opcode * minmax,bool * some_gfx9_only)2712 get_minmax_info(aco_opcode op, aco_opcode* min, aco_opcode* max, aco_opcode* min3, aco_opcode* max3,
2713 aco_opcode* med3, aco_opcode* minmax, bool* some_gfx9_only)
2714 {
2715 switch (op) {
2716 #define MINMAX(type, gfx9) \
2717 case aco_opcode::v_min_##type: \
2718 case aco_opcode::v_max_##type: \
2719 *min = aco_opcode::v_min_##type; \
2720 *max = aco_opcode::v_max_##type; \
2721 *med3 = aco_opcode::v_med3_##type; \
2722 *min3 = aco_opcode::v_min3_##type; \
2723 *max3 = aco_opcode::v_max3_##type; \
2724 *minmax = op == *min ? aco_opcode::v_maxmin_##type : aco_opcode::v_minmax_##type; \
2725 *some_gfx9_only = gfx9; \
2726 return true;
2727 #define MINMAX_INT16(type, gfx9) \
2728 case aco_opcode::v_min_##type: \
2729 case aco_opcode::v_max_##type: \
2730 *min = aco_opcode::v_min_##type; \
2731 *max = aco_opcode::v_max_##type; \
2732 *med3 = aco_opcode::v_med3_##type; \
2733 *min3 = aco_opcode::v_min3_##type; \
2734 *max3 = aco_opcode::v_max3_##type; \
2735 *minmax = aco_opcode::num_opcodes; \
2736 *some_gfx9_only = gfx9; \
2737 return true;
2738 #define MINMAX_INT16_E64(type, gfx9) \
2739 case aco_opcode::v_min_##type##_e64: \
2740 case aco_opcode::v_max_##type##_e64: \
2741 *min = aco_opcode::v_min_##type##_e64; \
2742 *max = aco_opcode::v_max_##type##_e64; \
2743 *med3 = aco_opcode::v_med3_##type; \
2744 *min3 = aco_opcode::v_min3_##type; \
2745 *max3 = aco_opcode::v_max3_##type; \
2746 *minmax = aco_opcode::num_opcodes; \
2747 *some_gfx9_only = gfx9; \
2748 return true;
2749 MINMAX(f32, false)
2750 MINMAX(u32, false)
2751 MINMAX(i32, false)
2752 MINMAX(f16, true)
2753 MINMAX_INT16(u16, true)
2754 MINMAX_INT16(i16, true)
2755 MINMAX_INT16_E64(u16, true)
2756 MINMAX_INT16_E64(i16, true)
2757 #undef MINMAX_INT16_E64
2758 #undef MINMAX_INT16
2759 #undef MINMAX
2760 default: return false;
2761 }
2762 }
2763
2764 /* when ub > lb:
2765 * v_min_{f,u,i}{16,32}(v_max_{f,u,i}{16,32}(a, lb), ub) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2766 * v_max_{f,u,i}{16,32}(v_min_{f,u,i}{16,32}(a, ub), lb) -> v_med3_{f,u,i}{16,32}(a, lb, ub)
2767 */
2768 bool
combine_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr,aco_opcode min,aco_opcode max,aco_opcode med)2769 combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode min, aco_opcode max,
2770 aco_opcode med)
2771 {
2772 /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
2773 * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
2774 * minVal > maxVal, which means we can always select it to a v_med3_f32 */
2775 aco_opcode other_op;
2776 if (instr->opcode == min)
2777 other_op = max;
2778 else if (instr->opcode == max)
2779 other_op = min;
2780 else
2781 return false;
2782
2783 for (unsigned swap = 0; swap < 2; swap++) {
2784 Operand operands[3];
2785 bool clamp, precise;
2786 bitarray8 opsel = 0, neg = 0, abs = 0;
2787 uint8_t omod = 0;
2788 if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap, "012", operands, neg,
2789 abs, opsel, &clamp, &omod, NULL, NULL, NULL, &precise)) {
2790 /* max(min(src, upper), lower) returns upper if src is NaN, but
2791 * med3(src, lower, upper) returns lower.
2792 */
2793 if (precise && instr->opcode != min &&
2794 (min == aco_opcode::v_min_f16 || min == aco_opcode::v_min_f32))
2795 continue;
2796
2797 int const0_idx = -1, const1_idx = -1;
2798 uint32_t const0 = 0, const1 = 0;
2799 for (int i = 0; i < 3; i++) {
2800 uint32_t val;
2801 bool hi16 = opsel & (1 << i);
2802 if (operands[i].isConstant()) {
2803 val = hi16 ? operands[i].constantValue16(true) : operands[i].constantValue();
2804 } else if (operands[i].isTemp() &&
2805 ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
2806 val = ctx.info[operands[i].tempId()].val >> (hi16 ? 16 : 0);
2807 } else {
2808 continue;
2809 }
2810 if (const0_idx >= 0) {
2811 const1_idx = i;
2812 const1 = val;
2813 } else {
2814 const0_idx = i;
2815 const0 = val;
2816 }
2817 }
2818 if (const0_idx < 0 || const1_idx < 0)
2819 continue;
2820
2821 int lower_idx = const0_idx;
2822 switch (min) {
2823 case aco_opcode::v_min_f32:
2824 case aco_opcode::v_min_f16: {
2825 float const0_f, const1_f;
2826 if (min == aco_opcode::v_min_f32) {
2827 memcpy(&const0_f, &const0, 4);
2828 memcpy(&const1_f, &const1, 4);
2829 } else {
2830 const0_f = _mesa_half_to_float(const0);
2831 const1_f = _mesa_half_to_float(const1);
2832 }
2833 if (abs[const0_idx])
2834 const0_f = fabsf(const0_f);
2835 if (abs[const1_idx])
2836 const1_f = fabsf(const1_f);
2837 if (neg[const0_idx])
2838 const0_f = -const0_f;
2839 if (neg[const1_idx])
2840 const1_f = -const1_f;
2841 lower_idx = const0_f < const1_f ? const0_idx : const1_idx;
2842 break;
2843 }
2844 case aco_opcode::v_min_u32: {
2845 lower_idx = const0 < const1 ? const0_idx : const1_idx;
2846 break;
2847 }
2848 case aco_opcode::v_min_u16:
2849 case aco_opcode::v_min_u16_e64: {
2850 lower_idx = (uint16_t)const0 < (uint16_t)const1 ? const0_idx : const1_idx;
2851 break;
2852 }
2853 case aco_opcode::v_min_i32: {
2854 int32_t const0_i =
2855 const0 & 0x80000000u ? -2147483648 + (int32_t)(const0 & 0x7fffffffu) : const0;
2856 int32_t const1_i =
2857 const1 & 0x80000000u ? -2147483648 + (int32_t)(const1 & 0x7fffffffu) : const1;
2858 lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2859 break;
2860 }
2861 case aco_opcode::v_min_i16:
2862 case aco_opcode::v_min_i16_e64: {
2863 int16_t const0_i = const0 & 0x8000u ? -32768 + (int16_t)(const0 & 0x7fffu) : const0;
2864 int16_t const1_i = const1 & 0x8000u ? -32768 + (int16_t)(const1 & 0x7fffu) : const1;
2865 lower_idx = const0_i < const1_i ? const0_idx : const1_idx;
2866 break;
2867 }
2868 default: break;
2869 }
2870 int upper_idx = lower_idx == const0_idx ? const1_idx : const0_idx;
2871
2872 if (instr->opcode == min) {
2873 if (upper_idx != 0 || lower_idx == 0)
2874 return false;
2875 } else {
2876 if (upper_idx == 0 || lower_idx != 0)
2877 return false;
2878 }
2879
2880 ctx.uses[instr->operands[swap].tempId()]--;
2881 create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
2882
2883 return true;
2884 }
2885 }
2886
2887 return false;
2888 }
2889
/* Replace VALU operands whose value is a copy of an SGPR (or a sub-dword extract
 * from one) with the SGPR itself, within the instruction's SGPR read limit.
 *
 * Candidates are processed one at a time, fewest-uses first. Each is substituted
 * directly, swapped into operand 0, or applied after promoting instr to VOP3,
 * and ctx.uses is updated for both the replaced temp and the SGPR.
 */
void
apply_sgprs(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* 64-bit shifts keep the one-SGPR limit even on GFX10+ (see max_sgprs below). */
   bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
                     instr->opcode == aco_opcode::v_lshlrev_b64 ||
                     instr->opcode == aco_opcode::v_lshrrev_b64 ||
                     instr->opcode == aco_opcode::v_ashrrev_i64;

   /* find candidates and create the set of sgprs already read */
   /* sgpr_ids: ids of (up to two) distinct SGPR temps already read; 0 = empty slot.
    * operand_mask: bit i set when operand i is a candidate for SGPR substitution. */
   unsigned sgpr_ids[2] = {0, 0};
   uint32_t operand_mask = 0;
   bool has_literal = false;
   for (unsigned i = 0; i < instr->operands.size(); i++) {
      if (instr->operands[i].isLiteral())
         has_literal = true;
      if (!instr->operands[i].isTemp())
         continue;
      if (instr->operands[i].getTemp().type() == RegType::sgpr) {
         /* Fill slot 0 first, then slot 1; a third distinct id would overwrite slot 1. */
         if (instr->operands[i].tempId() != sgpr_ids[0])
            sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
      }
      ssa_info& info = ctx.info[instr->operands[i].tempId()];
      if (is_copy_label(ctx, instr, info, i) && info.temp.type() == RegType::sgpr)
         operand_mask |= 1u << i;
      if (info.is_extract() && info.instr->operands[0].getTemp().type() == RegType::sgpr)
         operand_mask |= 1u << i;
   }
   /* One SGPR before GFX10 (and for 64-bit shifts), two otherwise; a literal
    * takes one of these slots. */
   unsigned max_sgprs = 1;
   if (ctx.program->gfx_level >= GFX10 && !is_shift64)
      max_sgprs = 2;
   if (has_literal)
      max_sgprs--;

   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];

   /* keep on applying sgprs until there is nothing left to be done */
   while (operand_mask) {
      uint32_t sgpr_idx = 0;
      uint32_t sgpr_info_id = 0;
      uint32_t mask = operand_mask;
      /* choose a sgpr */
      /* Pick the candidate whose temp has the fewest uses (the first one on ties). */
      while (mask) {
         unsigned i = u_bit_scan(&mask);
         uint16_t uses = ctx.uses[instr->operands[i].tempId()];
         if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
            sgpr_idx = i;
            sgpr_info_id = instr->operands[i].tempId();
         }
      }
      operand_mask &= ~(1u << sgpr_idx);

      ssa_info& info = ctx.info[sgpr_info_id];

      /* Applying two sgprs require making it VOP3, so don't do it unless it's
       * definitively beneficial.
       * TODO: this is too conservative because later the use count could be reduced to 1 */
      if (!info.is_extract() && num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3() &&
          !instr->isSDWA() && instr->format != Format::VOP3P)
         break;

      /* The SGPR to read: the extract's source, or the temp this operand copies. */
      Temp sgpr = info.is_extract() ? info.instr->operands[0].getTemp() : info.temp;
      bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
      if (new_sgpr && num_sgprs >= max_sgprs)
         continue;

      /* Substituting operand 0 drops any DPP encoding. NOTE(review): this happens
       * before the checks below, so a later `continue` (e.g. an extract that
       * can_apply_extract() rejects) leaves DPP stripped with the operand
       * unchanged. */
      if (sgpr_idx == 0)
         instr->format = withoutDPP(instr->format);

      if (sgpr_idx == 1 && instr->isDPP())
         continue;

      if (sgpr_idx == 0 || instr->isVOP3() || instr->isSDWA() || instr->isVOP3P() ||
          info.is_extract()) {
         /* can_apply_extract() checks SGPR encoding restrictions */
         if (info.is_extract() && can_apply_extract(ctx, instr, sgpr_idx, info))
            apply_extract(ctx, instr, sgpr_idx, info);
         else if (info.is_extract())
            continue;
         instr->operands[sgpr_idx] = Operand(sgpr);
      } else if (can_swap_operands(instr, &instr->opcode) && !instr->valu().opsel[sgpr_idx]) {
         /* Put the SGPR in operand 0 and move the old operand 0 to sgpr_idx. */
         instr->operands[sgpr_idx] = instr->operands[0];
         instr->operands[0] = Operand(sgpr);
         instr->valu().opsel[0].swap(instr->valu().opsel[sgpr_idx]);
         /* swap bits using a 4-entry LUT */
         /* NOTE(review): the shift is not scaled by 4, so this does not index the
          * 4-bit entries of 0x3120. sgpr_idx is non-zero here and its bit was
          * cleared above. For the inputs 0b00 and 0b01 (the only ones possible when
          * sgpr_idx == 1) the result is 0, so a candidate that was in operand 0 is
          * dropped. Inputs 0b10/0b11 would give 0b1000/0b0100. A corrected LUT
          * would re-queue that operand at index 1, and the next iteration could
          * take this swap path again and place an SGPR in operand 1 of a non-VOP3
          * instruction. Verify before changing. */
         uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
         operand_mask = (operand_mask & ~0x3) | swapped;
      } else if (can_use_VOP3(ctx, instr) && !info.is_extract()) {
         instr->format = asVOP3(instr->format);
         instr->operands[sgpr_idx] = Operand(sgpr);
      } else {
         continue;
      }

      /* Bookkeeping: record the SGPR, move one use from the replaced temp to it. */
      if (new_sgpr)
         sgpr_ids[num_sgprs++] = sgpr.id();
      ctx.uses[sgpr_info_id]--;
      ctx.uses[sgpr.id()]++;

      /* TODO: handle when it's a VGPR */
      /* If the substituted SGPR is itself an extract/copy of an SGPR, revisit this operand. */
      if ((ctx.info[sgpr.id()].label & (label_extract | label_temp)) &&
          ctx.info[sgpr.id()].temp.type() == RegType::sgpr)
         operand_mask |= 1u << sgpr_idx;
   }
}
2994
2995 bool
interp_can_become_fma(opt_ctx & ctx,aco_ptr<Instruction> & instr)2996 interp_can_become_fma(opt_ctx& ctx, aco_ptr<Instruction>& instr)
2997 {
2998 if (instr->opcode != aco_opcode::v_interp_p2_f32_inreg)
2999 return false;
3000
3001 instr->opcode = aco_opcode::v_fma_f32;
3002 instr->format = Format::VOP3;
3003 bool dpp_allowed = can_use_DPP(ctx.program->gfx_level, instr, false);
3004 instr->opcode = aco_opcode::v_interp_p2_f32_inreg;
3005 instr->format = Format::VINTERP_INREG;
3006
3007 return dpp_allowed;
3008 }
3009
3010 void
interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction> & instr)3011 interp_p2_f32_inreg_to_fma_dpp(aco_ptr<Instruction>& instr)
3012 {
3013 static_assert(sizeof(DPP16_instruction) == sizeof(VINTERP_inreg_instruction),
3014 "Invalid instr cast.");
3015 instr->format = asVOP3(Format::DPP16);
3016 instr->opcode = aco_opcode::v_fma_f32;
3017 instr->dpp16().dpp_ctrl = dpp_quad_perm(2, 2, 2, 2);
3018 instr->dpp16().row_mask = 0xf;
3019 instr->dpp16().bank_mask = 0xf;
3020 instr->dpp16().bound_ctrl = 0;
3021 instr->dpp16().fetch_inactive = 1;
3022 }
3023
3024 /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
3025 bool
apply_omod_clamp(opt_ctx & ctx,aco_ptr<Instruction> & instr)3026 apply_omod_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3027 {
3028 if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 ||
3029 !instr_info.can_use_output_modifiers[(int)instr->opcode])
3030 return false;
3031
3032 bool can_vop3 = can_use_VOP3(ctx, instr);
3033 bool is_mad_mix =
3034 instr->opcode == aco_opcode::v_fma_mix_f32 || instr->opcode == aco_opcode::v_fma_mixlo_f16;
3035 bool needs_vop3 = !instr->isSDWA() && !instr->isVINTERP_INREG() && !is_mad_mix;
3036 if (needs_vop3 && !can_vop3)
3037 return false;
3038
3039 /* SDWA omod is GFX9+. */
3040 bool can_use_omod = (can_vop3 || ctx.program->gfx_level >= GFX9) && !instr->isVOP3P() &&
3041 (!instr->isVINTERP_INREG() || interp_can_become_fma(ctx, instr));
3042
3043 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3044
3045 uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5;
3046 if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels)))
3047 return false;
3048 /* if the omod/clamp instruction is dead, then the single user of this
3049 * instruction is a different instruction */
3050 if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3051 return false;
3052
3053 if (def_info.instr->definitions[0].bytes() != instr->definitions[0].bytes())
3054 return false;
3055
3056 /* MADs/FMAs are created later, so we don't have to update the original add */
3057 assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3058
3059 if (!def_info.is_clamp() && (instr->valu().clamp || instr->valu().omod))
3060 return false;
3061
3062 if (needs_vop3)
3063 instr->format = asVOP3(instr->format);
3064
3065 if (!def_info.is_clamp() && instr->opcode == aco_opcode::v_interp_p2_f32_inreg)
3066 interp_p2_f32_inreg_to_fma_dpp(instr);
3067
3068 if (def_info.is_omod2())
3069 instr->valu().omod = 1;
3070 else if (def_info.is_omod4())
3071 instr->valu().omod = 2;
3072 else if (def_info.is_omod5())
3073 instr->valu().omod = 3;
3074 else if (def_info.is_clamp())
3075 instr->valu().clamp = true;
3076
3077 instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3078 ctx.info[instr->definitions[0].tempId()].label &= label_clamp | label_insert | label_f2f16;
3079 ctx.uses[def_info.instr->definitions[0].tempId()]--;
3080
3081 return true;
3082 }
3083
3084 /* Combine an p_insert (or p_extract, in some cases) instruction with instr.
3085 * p_insert(instr(...)) -> instr_insert().
3086 */
3087 bool
apply_insert(opt_ctx & ctx,aco_ptr<Instruction> & instr)3088 apply_insert(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3089 {
3090 if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1)
3091 return false;
3092
3093 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3094 if (!def_info.is_insert())
3095 return false;
3096 /* if the insert instruction is dead, then the single user of this
3097 * instruction is a different instruction */
3098 if (!ctx.uses[def_info.instr->definitions[0].tempId()])
3099 return false;
3100
3101 /* MADs/FMAs are created later, so we don't have to update the original add */
3102 assert(!ctx.info[instr->definitions[0].tempId()].is_mad());
3103
3104 SubdwordSel sel = parse_insert(def_info.instr);
3105 assert(sel);
3106
3107 if (!can_use_SDWA(ctx.program->gfx_level, instr, true))
3108 return false;
3109
3110 convert_to_SDWA(ctx.program->gfx_level, instr);
3111 if (instr->sdwa().dst_sel.size() != 4)
3112 return false;
3113 instr->sdwa().dst_sel = sel;
3114
3115 instr->definitions[0].swapTemp(def_info.instr->definitions[0]);
3116 ctx.info[instr->definitions[0].tempId()].label = 0;
3117 ctx.uses[def_info.instr->definitions[0].tempId()]--;
3118
3119 return true;
3120 }
3121
3122 /* Remove superfluous extract after ds_read like so:
3123 * p_extract(ds_read_uN(), 0, N, 0) -> ds_read_uN()
3124 */
3125 bool
apply_ds_extract(opt_ctx & ctx,aco_ptr<Instruction> & extract)3126 apply_ds_extract(opt_ctx& ctx, aco_ptr<Instruction>& extract)
3127 {
3128 /* Check if p_extract has a usedef operand and is the only user. */
3129 if (!ctx.info[extract->operands[0].tempId()].is_usedef() ||
3130 ctx.uses[extract->operands[0].tempId()] > 1)
3131 return false;
3132
3133 /* Check if the usedef is a DS instruction. */
3134 Instruction* ds = ctx.info[extract->operands[0].tempId()].instr;
3135 if (ds->format != Format::DS)
3136 return false;
3137
3138 unsigned extract_idx = extract->operands[1].constantValue();
3139 unsigned bits_extracted = extract->operands[2].constantValue();
3140 unsigned sign_ext = extract->operands[3].constantValue();
3141 unsigned dst_bitsize = extract->definitions[0].bytes() * 8u;
3142
3143 /* TODO: These are doable, but probably don't occur too often. */
3144 if (extract_idx || sign_ext || dst_bitsize != 32)
3145 return false;
3146
3147 unsigned bits_loaded = 0;
3148 if (ds->opcode == aco_opcode::ds_read_u8 || ds->opcode == aco_opcode::ds_read_u8_d16)
3149 bits_loaded = 8;
3150 else if (ds->opcode == aco_opcode::ds_read_u16 || ds->opcode == aco_opcode::ds_read_u16_d16)
3151 bits_loaded = 16;
3152 else
3153 return false;
3154
3155 /* Shrink the DS load if the extracted bit size is smaller. */
3156 bits_loaded = MIN2(bits_loaded, bits_extracted);
3157
3158 /* Change the DS opcode so it writes the full register. */
3159 if (bits_loaded == 8)
3160 ds->opcode = aco_opcode::ds_read_u8;
3161 else if (bits_loaded == 16)
3162 ds->opcode = aco_opcode::ds_read_u16;
3163 else
3164 unreachable("Forgot to add DS opcode above.");
3165
3166 /* The DS now produces the exact same thing as the extract, remove the extract. */
3167 std::swap(ds->definitions[0], extract->definitions[0]);
3168 ctx.uses[extract->definitions[0].tempId()] = 0;
3169 ctx.info[ds->definitions[0].tempId()].label = 0;
3170 return true;
3171 }
3172
3173 /* v_and(a, v_subbrev_co(0, 0, vcc)) -> v_cndmask(0, a, vcc) */
3174 bool
combine_and_subbrev(opt_ctx & ctx,aco_ptr<Instruction> & instr)3175 combine_and_subbrev(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3176 {
3177 if (instr->usesModifiers())
3178 return false;
3179
3180 for (unsigned i = 0; i < 2; i++) {
3181 Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3182 if (op_instr && op_instr->opcode == aco_opcode::v_subbrev_co_u32 &&
3183 op_instr->operands[0].constantEquals(0) && op_instr->operands[1].constantEquals(0) &&
3184 !op_instr->usesModifiers()) {
3185
3186 aco_ptr<Instruction> new_instr;
3187 if (instr->operands[!i].isTemp() &&
3188 instr->operands[!i].getTemp().type() == RegType::vgpr) {
3189 new_instr.reset(create_instruction(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1));
3190 } else if (ctx.program->gfx_level >= GFX10 ||
3191 (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
3192 new_instr.reset(
3193 create_instruction(aco_opcode::v_cndmask_b32, asVOP3(Format::VOP2), 3, 1));
3194 } else {
3195 return false;
3196 }
3197
3198 new_instr->operands[0] = Operand::zero();
3199 new_instr->operands[1] = instr->operands[!i];
3200 new_instr->operands[2] = copy_operand(ctx, op_instr->operands[2]);
3201 new_instr->definitions[0] = instr->definitions[0];
3202 new_instr->pass_flags = instr->pass_flags;
3203 instr = std::move(new_instr);
3204 decrease_uses(ctx, op_instr);
3205 ctx.info[instr->definitions[0].tempId()].label = 0;
3206 return true;
3207 }
3208 }
3209
3210 return false;
3211 }
3212
3213 /* v_and(a, not(b)) -> v_bfi_b32(b, 0, a)
3214 * v_or(a, not(b)) -> v_bfi_b32(b, a, -1)
3215 */
3216 bool
combine_v_andor_not(opt_ctx & ctx,aco_ptr<Instruction> & instr)3217 combine_v_andor_not(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3218 {
3219 if (instr->usesModifiers())
3220 return false;
3221
3222 for (unsigned i = 0; i < 2; i++) {
3223 Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
3224 if (op_instr && !op_instr->usesModifiers() &&
3225 (op_instr->opcode == aco_opcode::v_not_b32 ||
3226 op_instr->opcode == aco_opcode::s_not_b32)) {
3227
3228 Operand ops[3] = {
3229 op_instr->operands[0],
3230 Operand::zero(),
3231 instr->operands[!i],
3232 };
3233 if (instr->opcode == aco_opcode::v_or_b32) {
3234 ops[1] = instr->operands[!i];
3235 ops[2] = Operand::c32(-1);
3236 }
3237 if (!check_vop3_operands(ctx, 3, ops))
3238 continue;
3239
3240 Instruction* new_instr = create_instruction(aco_opcode::v_bfi_b32, Format::VOP3, 3, 1);
3241
3242 if (op_instr->operands[0].isTemp())
3243 ctx.uses[op_instr->operands[0].tempId()]++;
3244 for (unsigned j = 0; j < 3; j++)
3245 new_instr->operands[j] = ops[j];
3246 new_instr->definitions[0] = instr->definitions[0];
3247 new_instr->pass_flags = instr->pass_flags;
3248 instr.reset(new_instr);
3249 decrease_uses(ctx, op_instr);
3250 ctx.info[instr->definitions[0].tempId()].label = 0;
3251 return true;
3252 }
3253 }
3254
3255 return false;
3256 }
3257
3258 /* v_add_co(c, s_lshl(a, b)) -> v_mad_u32_u24(a, 1<<b, c)
3259 * v_add_co(c, v_lshlrev(a, b)) -> v_mad_u32_u24(b, 1<<a, c)
3260 * v_sub(c, s_lshl(a, b)) -> v_mad_i32_i24(a, -(1<<b), c)
3261 * v_sub(c, v_lshlrev(a, b)) -> v_mad_i32_i24(b, -(1<<a), c)
3262 */
3263 bool
combine_add_lshl(opt_ctx & ctx,aco_ptr<Instruction> & instr,bool is_sub)3264 combine_add_lshl(opt_ctx& ctx, aco_ptr<Instruction>& instr, bool is_sub)
3265 {
3266 if (instr->usesModifiers())
3267 return false;
3268
3269 /* Substractions: start at operand 1 to avoid mixup such as
3270 * turning v_sub(v_lshlrev(a, b), c) into v_mad_i32_i24(b, -(1<<a), c)
3271 */
3272 unsigned start_op_idx = is_sub ? 1 : 0;
3273
3274 /* Don't allow 24-bit operands on subtraction because
3275 * v_mad_i32_i24 applies a sign extension.
3276 */
3277 bool allow_24bit = !is_sub;
3278
3279 for (unsigned i = start_op_idx; i < 2; i++) {
3280 Instruction* op_instr = follow_operand(ctx, instr->operands[i]);
3281 if (!op_instr)
3282 continue;
3283
3284 if (op_instr->opcode != aco_opcode::s_lshl_b32 &&
3285 op_instr->opcode != aco_opcode::v_lshlrev_b32)
3286 continue;
3287
3288 int shift_op_idx = op_instr->opcode == aco_opcode::s_lshl_b32 ? 1 : 0;
3289
3290 if (op_instr->operands[shift_op_idx].isConstant() &&
3291 ((allow_24bit && op_instr->operands[!shift_op_idx].is24bit()) ||
3292 op_instr->operands[!shift_op_idx].is16bit())) {
3293 uint32_t multiplier = 1 << (op_instr->operands[shift_op_idx].constantValue() % 32u);
3294 if (is_sub)
3295 multiplier = -multiplier;
3296 if (is_sub ? (multiplier < 0xff800000) : (multiplier > 0xffffff))
3297 continue;
3298
3299 Operand ops[3] = {
3300 op_instr->operands[!shift_op_idx],
3301 Operand::c32(multiplier),
3302 instr->operands[!i],
3303 };
3304 if (!check_vop3_operands(ctx, 3, ops))
3305 return false;
3306
3307 ctx.uses[instr->operands[i].tempId()]--;
3308
3309 aco_opcode mad_op = is_sub ? aco_opcode::v_mad_i32_i24 : aco_opcode::v_mad_u32_u24;
3310 aco_ptr<Instruction> new_instr{create_instruction(mad_op, Format::VOP3, 3, 1)};
3311 for (unsigned op_idx = 0; op_idx < 3; ++op_idx)
3312 new_instr->operands[op_idx] = ops[op_idx];
3313 new_instr->definitions[0] = instr->definitions[0];
3314 new_instr->pass_flags = instr->pass_flags;
3315 instr = std::move(new_instr);
3316 ctx.info[instr->definitions[0].tempId()].label = 0;
3317 return true;
3318 }
3319 }
3320
3321 return false;
3322 }
3323
3324 void
propagate_swizzles(VALU_instruction * instr,bool opsel_lo,bool opsel_hi)3325 propagate_swizzles(VALU_instruction* instr, bool opsel_lo, bool opsel_hi)
3326 {
3327 /* propagate swizzles which apply to a result down to the instruction's operands:
3328 * result = a.xy + b.xx -> result.yx = a.yx + b.xx */
3329 uint8_t tmp_lo = instr->opsel_lo;
3330 uint8_t tmp_hi = instr->opsel_hi;
3331 uint8_t neg_lo = instr->neg_lo;
3332 uint8_t neg_hi = instr->neg_hi;
3333 if (opsel_lo == 1) {
3334 instr->opsel_lo = tmp_hi;
3335 instr->neg_lo = neg_hi;
3336 }
3337 if (opsel_hi == 0) {
3338 instr->opsel_hi = tmp_lo;
3339 instr->neg_hi = neg_lo;
3340 }
3341 }
3342
/* Peephole combines for packed (VOP3P) instructions. In order:
 *  1. v_pk_mul_f16(x, 1.0) with clamp: move the clamp onto the (single-use, VOP3P) producer
 *     of x and let that producer take over this instruction's result.
 *  2. For each operand produced by v_pk_mul_f16(y, 1.0) (a negate/abs-style copy), read y
 *     directly and fold that mul's neg/opsel bits into this instruction's modifiers.
 *  3. v_pk_add_f16 / v_pk_add_u16 of a multiply -> v_pk_fma_f16 / v_pk_mad_u16.
 * instr may be replaced by a new instruction.
 */
void
combine_vop3p(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   VALU_instruction* vop3p = &instr->valu();

   /* apply clamp: the 1.0 operand must not be swizzled, and x must have no other uses */
   if (instr->opcode == aco_opcode::v_pk_mul_f16 && instr->operands[1].constantEquals(0x3C00) &&
       vop3p->clamp && instr->operands[0].isTemp() && ctx.uses[instr->operands[0].tempId()] == 1 &&
       !vop3p->opsel_lo[1] && !vop3p->opsel_hi[1]) {

      ssa_info& info = ctx.info[instr->operands[0].tempId()];
      if (info.is_vop3p() && instr_info.can_use_output_modifiers[(int)info.instr->opcode]) {
         VALU_instruction* candidate = &ctx.info[instr->operands[0].tempId()].instr->valu();
         candidate->clamp = true;
         /* this instruction's swizzle of x must now be applied inside the producer */
         propagate_swizzles(candidate, vop3p->opsel_lo[0], vop3p->opsel_hi[0]);
         /* the producer now writes our result temp; we are left defining x's old temp */
         instr->definitions[0].swapTemp(candidate->definitions[0]);
         ctx.info[candidate->definitions[0].tempId()].instr = candidate;
         /* x's single use was this instruction, so it no longer has any uses */
         ctx.uses[instr->definitions[0].tempId()]--;
         return;
      }
   }

   /* check for fneg modifiers */
   for (unsigned i = 0; i < instr->operands.size(); i++) {
      if (!can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, i))
         continue;
      Operand& op = instr->operands[i];
      if (!op.isTemp())
         continue;

      ssa_info& info = ctx.info[op.tempId()];
      if (info.is_vop3p() && info.instr->opcode == aco_opcode::v_pk_mul_f16 &&
          (info.instr->operands[0].constantEquals(0x3C00) ||
           info.instr->operands[1].constantEquals(0x3C00))) {

         VALU_instruction* fneg = &info.instr->valu();

         /* index of the non-1.0 source: 1 if operand 0 is the 1.0 constant, else 0 */
         unsigned fneg_src = fneg->operands[0].constantEquals(0x3C00);

         /* a swizzled 1.0 constant is not a plain multiply by one */
         if (fneg->opsel_lo[1 - fneg_src] || fneg->opsel_hi[1 - fneg_src])
            continue;

         Operand ops[3];
         for (unsigned j = 0; j < instr->operands.size(); j++)
            ops[j] = instr->operands[j];
         ops[i] = fneg->operands[fneg_src];
         if (!check_vop3_operands(ctx, instr->operands.size(), ops))
            continue;

         if (fneg->clamp)
            continue;
         instr->operands[i] = fneg->operands[fneg_src];

         /* opsel_lo/hi is either 0 or 1:
          * if 0 - pick selection from fneg->lo
          * if 1 - pick selection from fneg->hi
          */
         bool opsel_lo = vop3p->opsel_lo[i];
         bool opsel_hi = vop3p->opsel_hi[i];
         /* the mul's total negation per half is the XOR of both of its sources' neg bits */
         bool neg_lo = fneg->neg_lo[0] ^ fneg->neg_lo[1];
         bool neg_hi = fneg->neg_hi[0] ^ fneg->neg_hi[1];
         vop3p->neg_lo[i] ^= opsel_lo ? neg_hi : neg_lo;
         vop3p->neg_hi[i] ^= opsel_hi ? neg_hi : neg_lo;
         vop3p->opsel_lo[i] ^= opsel_lo ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];
         vop3p->opsel_hi[i] ^= opsel_hi ? !fneg->opsel_hi[fneg_src] : fneg->opsel_lo[fneg_src];

         /* we now read the mul's source directly; if the mul stays alive its source gains a use */
         if (--ctx.uses[fneg->definitions[0].tempId()])
            ctx.uses[fneg->operands[fneg_src].tempId()]++;
      }
   }

   if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) {
      bool fadd = instr->opcode == aco_opcode::v_pk_add_f16;
      /* a precise float add must not be fused with the multiply */
      if (fadd && instr->definitions[0].isPrecise())
         return;
      if (!fadd && instr->valu().clamp)
         return;

      Instruction* mul_instr = nullptr;
      unsigned add_op_idx = 0;
      bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0;
      uint32_t uses = UINT32_MAX;

      /* find the 'best' mul instruction to combine with the add (fewest remaining uses) */
      for (unsigned i = 0; i < 2; i++) {
         Instruction* op_instr = follow_operand(ctx, instr->operands[i], true);
         if (!op_instr)
            continue;

         if (ctx.info[instr->operands[i].tempId()].is_vop3p()) {
            /* packed multiply: v_pk_mul_f16 (non-precise) or v_pk_mul_lo_u16 */
            if (fadd) {
               if (op_instr->opcode != aco_opcode::v_pk_mul_f16 ||
                   op_instr->definitions[0].isPrecise())
                  continue;
            } else {
               if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16)
                  continue;
            }

            Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
            if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
               continue;

            /* no clamp allowed between mul and add */
            if (op_instr->valu().clamp)
               continue;

            mul_instr = op_instr;
            add_op_idx = 1 - i;
            uses = ctx.uses[instr->operands[i].tempId()];
            mul_neg_lo = mul_instr->valu().neg_lo;
            mul_neg_hi = mul_instr->valu().neg_hi;
            mul_opsel_lo = mul_instr->valu().opsel_lo;
            mul_opsel_hi = mul_instr->valu().opsel_hi;
         } else if (instr->operands[i].bytes() == 2) {
            /* scalar 16-bit multiply used as a 2-byte operand: its neg/opsel are applied to both
             * halves of the packed fma (mul_*_hi copies mul_*_lo below) */
            if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 ||
                          op_instr->definitions[0].isPrecise())) ||
                (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 &&
                 op_instr->opcode != aco_opcode::v_mul_lo_u16_e64))
               continue;

            if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs)
               continue;

            if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 ||
                                                              op_instr->sdwa().sel[1].size() < 2)))
               continue;

            Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]};
            if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op))
               continue;

            mul_instr = op_instr;
            add_op_idx = 1 - i;
            uses = ctx.uses[instr->operands[i].tempId()];
            mul_neg_lo = mul_instr->valu().neg;
            mul_neg_hi = mul_instr->valu().neg;
            /* half selection comes from the SDWA selector offset or from opsel */
            if (mul_instr->isSDWA()) {
               for (unsigned j = 0; j < 2; j++)
                  mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset();
            } else {
               mul_opsel_lo = mul_instr->valu().opsel;
            }
            mul_opsel_hi = mul_opsel_lo;
         }
      }

      if (!mul_instr)
         return;

      /* turn mul + packed add into v_pk_fma_f16 */
      aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16;
      aco_ptr<Instruction> fma{create_instruction(mad, Format::VOP3P, 3, 1)};
      fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]);
      fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]);
      fma->operands[2] = instr->operands[add_op_idx];
      fma->valu().clamp = vop3p->clamp;
      fma->valu().neg_lo = mul_neg_lo;
      fma->valu().neg_hi = mul_neg_hi;
      fma->valu().opsel_lo = mul_opsel_lo;
      fma->valu().opsel_hi = mul_opsel_hi;
      /* the add's swizzle of the mul result is pushed into the mul's operands */
      propagate_swizzles(&fma->valu(), vop3p->opsel_lo[1 - add_op_idx],
                         vop3p->opsel_hi[1 - add_op_idx]);
      fma->valu().opsel_lo[2] = vop3p->opsel_lo[add_op_idx];
      fma->valu().opsel_hi[2] = vop3p->opsel_hi[add_op_idx];
      fma->valu().neg_lo[2] = vop3p->neg_lo[add_op_idx];
      fma->valu().neg_hi[2] = vop3p->neg_hi[add_op_idx];
      /* negating the product is expressed by negating one of the factors */
      fma->valu().neg_lo[1] = fma->valu().neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx];
      fma->valu().neg_hi[1] = fma->valu().neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx];
      fma->definitions[0] = instr->definitions[0];
      fma->pass_flags = instr->pass_flags;
      instr = std::move(fma);
      ctx.info[instr->definitions[0].tempId()].set_vop3p(instr.get());
      decrease_uses(ctx, mul_instr);
      return;
   }
}
3520
3521 bool
can_use_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3522 can_use_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3523 {
3524 if (ctx.program->gfx_level < GFX9)
3525 return false;
3526
3527 /* v_mad_mix* on GFX9 always flushes denormals for 16-bit inputs/outputs */
3528 if (ctx.program->gfx_level == GFX9 && ctx.fp_mode.denorm16_64)
3529 return false;
3530
3531 if (instr->valu().omod)
3532 return false;
3533
3534 switch (instr->opcode) {
3535 case aco_opcode::v_add_f32:
3536 case aco_opcode::v_sub_f32:
3537 case aco_opcode::v_subrev_f32:
3538 case aco_opcode::v_mul_f32: return !instr->isSDWA() && !instr->isDPP();
3539 case aco_opcode::v_fma_f32:
3540 return ctx.program->dev.fused_mad_mix || !instr->definitions[0].isPrecise();
3541 case aco_opcode::v_fma_mix_f32:
3542 case aco_opcode::v_fma_mixlo_f16: return true;
3543 default: return false;
3544 }
3545 }
3546
3547 void
to_mad_mix(opt_ctx & ctx,aco_ptr<Instruction> & instr)3548 to_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3549 {
3550 ctx.info[instr->definitions[0].tempId()].label &= label_f2f16 | label_clamp | label_mul;
3551
3552 if (instr->opcode == aco_opcode::v_fma_f32) {
3553 instr->format = (Format)((uint32_t)withoutVOP3(instr->format) | (uint32_t)(Format::VOP3P));
3554 instr->opcode = aco_opcode::v_fma_mix_f32;
3555 return;
3556 }
3557
3558 bool is_add = instr->opcode != aco_opcode::v_mul_f32;
3559
3560 aco_ptr<Instruction> vop3p{create_instruction(aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1)};
3561
3562 for (unsigned i = 0; i < instr->operands.size(); i++) {
3563 vop3p->operands[is_add + i] = instr->operands[i];
3564 vop3p->valu().neg_lo[is_add + i] = instr->valu().neg[i];
3565 vop3p->valu().neg_hi[is_add + i] = instr->valu().abs[i];
3566 }
3567 if (instr->opcode == aco_opcode::v_mul_f32) {
3568 vop3p->operands[2] = Operand::zero();
3569 vop3p->valu().neg_lo[2] = true;
3570 } else if (is_add) {
3571 vop3p->operands[0] = Operand::c32(0x3f800000);
3572 if (instr->opcode == aco_opcode::v_sub_f32)
3573 vop3p->valu().neg_lo[2] ^= true;
3574 else if (instr->opcode == aco_opcode::v_subrev_f32)
3575 vop3p->valu().neg_lo[1] ^= true;
3576 }
3577 vop3p->definitions[0] = instr->definitions[0];
3578 vop3p->valu().clamp = instr->valu().clamp;
3579 vop3p->pass_flags = instr->pass_flags;
3580 instr = std::move(vop3p);
3581
3582 if (ctx.info[instr->definitions[0].tempId()].label & label_mul)
3583 ctx.info[instr->definitions[0].tempId()].instr = instr.get();
3584 }
3585
3586 bool
combine_output_conversion(opt_ctx & ctx,aco_ptr<Instruction> & instr)3587 combine_output_conversion(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3588 {
3589 ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
3590 if (!def_info.is_f2f16())
3591 return false;
3592 Instruction* conv = def_info.instr;
3593
3594 if (!ctx.uses[conv->definitions[0].tempId()] || ctx.uses[instr->definitions[0].tempId()] != 1)
3595 return false;
3596
3597 if (conv->usesModifiers())
3598 return false;
3599
3600 if (interp_can_become_fma(ctx, instr))
3601 interp_p2_f32_inreg_to_fma_dpp(instr);
3602
3603 if (!can_use_mad_mix(ctx, instr))
3604 return false;
3605
3606 if (!instr->isVOP3P())
3607 to_mad_mix(ctx, instr);
3608
3609 instr->opcode = aco_opcode::v_fma_mixlo_f16;
3610 instr->definitions[0].swapTemp(conv->definitions[0]);
3611 if (conv->definitions[0].isPrecise())
3612 instr->definitions[0].setPrecise(true);
3613 ctx.info[instr->definitions[0].tempId()].label &= label_clamp;
3614 ctx.uses[conv->definitions[0].tempId()]--;
3615
3616 return true;
3617 }
3618
/* Folds f16->f32 conversions feeding instr's operands into instr by converting it to
 * v_fma_mix_f32 (via to_mad_mix()) and reading the f16 source directly: opsel_hi[i] marks the
 * operand as f16, and opsel_lo[i] selects its high half when the conversion read the high half.
 */
void
combine_mad_mix(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   if (!can_use_mad_mix(ctx, instr))
      return;

   for (unsigned i = 0; i < instr->operands.size(); i++) {
      if (!instr->operands[i].isTemp())
         continue;
      Temp tmp = instr->operands[i].getTemp();
      if (!ctx.info[tmp.id()].is_f2f32())
         continue;

      /* the conversion must be a plain read of a 16-bit value: no output modifiers, no DPP,
       * and for SDWA a full 32-bit destination with a 16-bit source selection */
      Instruction* conv = ctx.info[tmp.id()].instr;
      if (conv->valu().clamp || conv->valu().omod) {
         continue;
      } else if (conv->isSDWA() &&
                 (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2)) {
         continue;
      } else if (conv->isDPP()) {
         continue;
      }

      if (get_operand_size(instr, i) != 32)
         continue;

      /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect
       * check_vop3_operands(). */
      Operand op[3];
      for (unsigned j = 0; j < instr->operands.size(); j++)
         op[j] = instr->operands[j];
      op[i] = conv->operands[0];
      if (!check_vop3_operands(ctx, instr->operands.size(), op))
         continue;
      if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP())
         continue;

      if (!instr->isVOP3P()) {
         /* to_mad_mix() moves add/sub sources up one slot (after the 1.0 multiplicand), so
          * operand i is now at i + 1 */
         bool is_add =
            instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32;
         to_mad_mix(ctx, instr);
         i += is_add;
      }

      /* we now read the conversion's source; if the conversion stays alive, its source
       * gains a use */
      if (--ctx.uses[tmp.id()])
         ctx.uses[conv->operands[0].tempId()]++;
      instr->operands[i].setTemp(conv->operands[0].getTemp());
      if (conv->definitions[0].isPrecise())
         instr->definitions[0].setPrecise(true);
      instr->valu().opsel_hi[i] = true;
      if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2)
         instr->valu().opsel_lo[i] = true;
      else
         instr->valu().opsel_lo[i] = conv->valu().opsel[0];
      bool neg = conv->valu().neg[0];
      bool abs = conv->valu().abs[0];
      /* with abs already on this operand, the conversion's neg/abs make no difference */
      if (!instr->valu().abs[i]) {
         instr->valu().neg[i] ^= neg;
         instr->valu().abs[i] = abs;
      }
   }
}
3681
// TODO: we could possibly move the whole label_instruction pass to combine_instruction:
// this would mean that we'd have to fix up the instruction use counts during value propagation
3684
3685 /* also returns true for inf */
3686 bool
is_pow_of_two(opt_ctx & ctx,Operand op)3687 is_pow_of_two(opt_ctx& ctx, Operand op)
3688 {
3689 if (op.isTemp() && ctx.info[op.tempId()].is_constant_or_literal(op.bytes() * 8))
3690 return is_pow_of_two(ctx, get_constant_op(ctx, ctx.info[op.tempId()], op.bytes() * 8));
3691 else if (!op.isConstant())
3692 return false;
3693
3694 uint64_t val = op.constantValue64();
3695
3696 if (op.bytes() == 4) {
3697 uint32_t exponent = (val & 0x7f800000) >> 23;
3698 uint32_t fraction = val & 0x007fffff;
3699 return (exponent >= 127) && (fraction == 0);
3700 } else if (op.bytes() == 2) {
3701 uint32_t exponent = (val & 0x7c00) >> 10;
3702 uint32_t fraction = val & 0x03ff;
3703 return (exponent >= 15) && (fraction == 0);
3704 } else {
3705 assert(op.bytes() == 8);
3706 uint64_t exponent = (val & UINT64_C(0x7ff0000000000000)) >> 52;
3707 uint64_t fraction = val & UINT64_C(0x000fffffffffffff);
3708 return (exponent >= 1023) && (fraction == 0);
3709 }
3710 }
3711
3712 void
combine_instruction(opt_ctx & ctx,aco_ptr<Instruction> & instr)3713 combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
3714 {
3715 if (instr->definitions.empty() || is_dead(ctx.uses, instr.get()))
3716 return;
3717
3718 if (instr->isVALU() || instr->isSALU()) {
3719 /* Apply SDWA. Do this after label_instruction() so it can remove
3720 * label_extract if not all instructions can take SDWA. */
3721 for (unsigned i = 0; i < instr->operands.size(); i++) {
3722 Operand& op = instr->operands[i];
3723 if (!op.isTemp())
3724 continue;
3725 ssa_info& info = ctx.info[op.tempId()];
3726 if (!info.is_extract())
3727 continue;
3728 /* if there are that many uses, there are likely better combinations */
3729 // TODO: delay applying extract to a point where we know better
3730 if (ctx.uses[op.tempId()] > 4) {
3731 info.label &= ~label_extract;
3732 continue;
3733 }
3734 if (info.is_extract() &&
3735 (info.instr->operands[0].getTemp().type() == RegType::vgpr ||
3736 instr->operands[i].getTemp().type() == RegType::sgpr) &&
3737 can_apply_extract(ctx, instr, i, info)) {
3738 /* Increase use count of the extract's operand if the extract still has uses. */
3739 apply_extract(ctx, instr, i, info);
3740 if (--ctx.uses[instr->operands[i].tempId()])
3741 ctx.uses[info.instr->operands[0].tempId()]++;
3742 instr->operands[i].setTemp(info.instr->operands[0].getTemp());
3743 }
3744 }
3745 }
3746
3747 if (instr->isVALU()) {
3748 if (can_apply_sgprs(ctx, instr))
3749 apply_sgprs(ctx, instr);
3750 combine_mad_mix(ctx, instr);
3751 while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr))
3752 ;
3753 apply_insert(ctx, instr);
3754 }
3755
3756 if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 &&
3757 instr->opcode != aco_opcode::v_fma_mixlo_f16)
3758 return combine_vop3p(ctx, instr);
3759
3760 if (instr->isSDWA() || instr->isDPP())
3761 return;
3762
3763 if (instr->opcode == aco_opcode::p_extract) {
3764 ssa_info& info = ctx.info[instr->operands[0].tempId()];
3765 if (info.is_extract() && can_apply_extract(ctx, instr, 0, info)) {
3766 apply_extract(ctx, instr, 0, info);
3767 if (--ctx.uses[instr->operands[0].tempId()])
3768 ctx.uses[info.instr->operands[0].tempId()]++;
3769 instr->operands[0].setTemp(info.instr->operands[0].getTemp());
3770 }
3771
3772 apply_ds_extract(ctx, instr);
3773 }
3774
3775 /* TODO: There are still some peephole optimizations that could be done:
3776 * - abs(a - b) -> s_absdiff_i32
3777 * - various patterns for s_bitcmp{0,1}_b32 and s_bitset{0,1}_b32
3778 * - patterns for v_alignbit_b32 and v_alignbyte_b32
3779 * These aren't probably too interesting though.
3780 * There are also patterns for v_cmp_class_f{16,32,64}. This is difficult but
3781 * probably more useful than the previously mentioned optimizations.
3782 * The various comparison optimizations also currently only work with 32-bit
3783 * floats. */
3784
3785 /* neg(mul(a, b)) -> mul(neg(a), b), abs(mul(a, b)) -> mul(abs(a), abs(b)) */
3786 if ((ctx.info[instr->definitions[0].tempId()].label & (label_neg | label_abs)) &&
3787 ctx.uses[instr->operands[1].tempId()] == 1) {
3788 Temp val = ctx.info[instr->definitions[0].tempId()].temp;
3789
3790 if (!ctx.info[val.id()].is_mul())
3791 return;
3792
3793 Instruction* mul_instr = ctx.info[val.id()].instr;
3794
3795 if (mul_instr->operands[0].isLiteral())
3796 return;
3797 if (mul_instr->valu().clamp)
3798 return;
3799 if (mul_instr->isSDWA() || mul_instr->isDPP())
3800 return;
3801 if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32 &&
3802 mul_instr->definitions[0].isSZPreserve())
3803 return;
3804 if (mul_instr->definitions[0].bytes() != instr->definitions[0].bytes())
3805 return;
3806
3807 /* convert to mul(neg(a), b), mul(abs(a), abs(b)) or mul(neg(abs(a)), abs(b)) */
3808 ctx.uses[mul_instr->definitions[0].tempId()]--;
3809 Definition def = instr->definitions[0];
3810 bool is_neg = ctx.info[instr->definitions[0].tempId()].is_neg();
3811 bool is_abs = ctx.info[instr->definitions[0].tempId()].is_abs();
3812 uint32_t pass_flags = instr->pass_flags;
3813 Format format = mul_instr->format == Format::VOP2 ? asVOP3(Format::VOP2) : mul_instr->format;
3814 instr.reset(create_instruction(mul_instr->opcode, format, mul_instr->operands.size(), 1));
3815 std::copy(mul_instr->operands.cbegin(), mul_instr->operands.cend(), instr->operands.begin());
3816 instr->pass_flags = pass_flags;
3817 instr->definitions[0] = def;
3818 VALU_instruction& new_mul = instr->valu();
3819 VALU_instruction& mul = mul_instr->valu();
3820 new_mul.neg = mul.neg;
3821 new_mul.abs = mul.abs;
3822 new_mul.omod = mul.omod;
3823 new_mul.opsel = mul.opsel;
3824 new_mul.opsel_lo = mul.opsel_lo;
3825 new_mul.opsel_hi = mul.opsel_hi;
3826 if (is_abs) {
3827 new_mul.neg[0] = new_mul.neg[1] = false;
3828 new_mul.abs[0] = new_mul.abs[1] = true;
3829 }
3830 new_mul.neg[0] ^= is_neg;
3831 new_mul.clamp = false;
3832
3833 ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
3834 return;
3835 }
3836
3837 /* combine mul+add -> mad */
3838 bool is_add_mix =
3839 (instr->opcode == aco_opcode::v_fma_mix_f32 ||
3840 instr->opcode == aco_opcode::v_fma_mixlo_f16) &&
3841 !instr->valu().neg_lo[0] &&
3842 ((instr->operands[0].constantEquals(0x3f800000) && !instr->valu().opsel_hi[0]) ||
3843 (instr->operands[0].constantEquals(0x3C00) && instr->valu().opsel_hi[0] &&
3844 !instr->valu().opsel_lo[0]));
3845 bool mad32 = instr->opcode == aco_opcode::v_add_f32 || instr->opcode == aco_opcode::v_sub_f32 ||
3846 instr->opcode == aco_opcode::v_subrev_f32;
3847 bool mad16 = instr->opcode == aco_opcode::v_add_f16 || instr->opcode == aco_opcode::v_sub_f16 ||
3848 instr->opcode == aco_opcode::v_subrev_f16;
3849 bool mad64 =
3850 instr->opcode == aco_opcode::v_add_f64_e64 || instr->opcode == aco_opcode::v_add_f64;
3851 if (is_add_mix || mad16 || mad32 || mad64) {
3852 Instruction* mul_instr = nullptr;
3853 unsigned add_op_idx = 0;
3854 uint32_t uses = UINT32_MAX;
3855 bool emit_fma = false;
3856 /* find the 'best' mul instruction to combine with the add */
3857 for (unsigned i = is_add_mix ? 1 : 0; i < instr->operands.size(); i++) {
3858 if (!instr->operands[i].isTemp() || !ctx.info[instr->operands[i].tempId()].is_mul())
3859 continue;
3860 ssa_info& info = ctx.info[instr->operands[i].tempId()];
3861
3862 /* no clamp/omod allowed between mul and add */
3863 if (info.instr->isVOP3() && (info.instr->valu().clamp || info.instr->valu().omod))
3864 continue;
3865 if (info.instr->isVOP3P() && info.instr->valu().clamp)
3866 continue;
3867 /* v_fma_mix_f32/etc can't do omod */
3868 if (info.instr->isVOP3P() && instr->isVOP3() && instr->valu().omod)
3869 continue;
3870 /* don't promote fp16 to fp32 or remove fp32->fp16->fp32 conversions */
3871 if (is_add_mix && info.instr->definitions[0].bytes() == 2)
3872 continue;
3873
3874 if (get_operand_size(instr, i) != info.instr->definitions[0].bytes() * 8)
3875 continue;
3876
3877 bool legacy = info.instr->opcode == aco_opcode::v_mul_legacy_f32;
3878 bool mad_mix = is_add_mix || info.instr->isVOP3P();
3879
3880 /* Multiplication by power-of-two should never need rounding. 1/power-of-two also works,
3881 * but using fma removes denormal flushing (0xfffffe * 0.5 + 0x810001a2).
3882 */
3883 bool is_fma_precise = is_pow_of_two(ctx, info.instr->operands[0]) ||
3884 is_pow_of_two(ctx, info.instr->operands[1]);
3885
3886 bool has_fma = mad16 || mad64 || (legacy && ctx.program->gfx_level >= GFX10_3) ||
3887 (mad32 && !legacy && !mad_mix && ctx.program->dev.has_fast_fma32) ||
3888 (mad_mix && ctx.program->dev.fused_mad_mix);
3889 bool has_mad = mad_mix ? !ctx.program->dev.fused_mad_mix
3890 : ((mad32 && ctx.program->gfx_level < GFX10_3) ||
3891 (mad16 && ctx.program->gfx_level <= GFX9));
3892 bool can_use_fma =
3893 has_fma &&
3894 (!(info.instr->definitions[0].isPrecise() || instr->definitions[0].isPrecise()) ||
3895 is_fma_precise);
3896 bool can_use_mad =
3897 has_mad && (mad_mix || mad32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64) == 0;
3898 if (mad_mix && legacy)
3899 continue;
3900 if (!can_use_fma && !can_use_mad)
3901 continue;
3902
3903 unsigned candidate_add_op_idx = is_add_mix ? (3 - i) : (1 - i);
3904 Operand op[3] = {info.instr->operands[0], info.instr->operands[1],
3905 instr->operands[candidate_add_op_idx]};
3906 if (info.instr->isSDWA() || info.instr->isDPP() || !check_vop3_operands(ctx, 3, op) ||
3907 ctx.uses[instr->operands[i].tempId()] > uses)
3908 continue;
3909
3910 if (ctx.uses[instr->operands[i].tempId()] == uses) {
3911 unsigned cur_idx = mul_instr->definitions[0].tempId();
3912 unsigned new_idx = info.instr->definitions[0].tempId();
3913 if (cur_idx > new_idx)
3914 continue;
3915 }
3916
3917 mul_instr = info.instr;
3918 add_op_idx = candidate_add_op_idx;
3919 uses = ctx.uses[instr->operands[i].tempId()];
3920 emit_fma = !can_use_mad;
3921 }
3922
3923 if (mul_instr) {
3924 /* turn mul+add into v_mad/v_fma */
3925 Operand op[3] = {mul_instr->operands[0], mul_instr->operands[1],
3926 instr->operands[add_op_idx]};
3927 ctx.uses[mul_instr->definitions[0].tempId()]--;
3928 if (ctx.uses[mul_instr->definitions[0].tempId()]) {
3929 if (op[0].isTemp())
3930 ctx.uses[op[0].tempId()]++;
3931 if (op[1].isTemp())
3932 ctx.uses[op[1].tempId()]++;
3933 }
3934
3935 bool neg[3] = {false, false, false};
3936 bool abs[3] = {false, false, false};
3937 unsigned omod = 0;
3938 bool clamp = false;
3939 bitarray8 opsel_lo = 0;
3940 bitarray8 opsel_hi = 0;
3941 bitarray8 opsel = 0;
3942 unsigned mul_op_idx = (instr->isVOP3P() ? 3 : 1) - add_op_idx;
3943
3944 VALU_instruction& valu_mul = mul_instr->valu();
3945 neg[0] = valu_mul.neg[0];
3946 neg[1] = valu_mul.neg[1];
3947 abs[0] = valu_mul.abs[0];
3948 abs[1] = valu_mul.abs[1];
3949 opsel_lo = valu_mul.opsel_lo & 0x3;
3950 opsel_hi = valu_mul.opsel_hi & 0x3;
3951 opsel = valu_mul.opsel & 0x3;
3952
3953 VALU_instruction& valu = instr->valu();
3954 neg[2] = valu.neg[add_op_idx];
3955 abs[2] = valu.abs[add_op_idx];
3956 opsel_lo[2] = valu.opsel_lo[add_op_idx];
3957 opsel_hi[2] = valu.opsel_hi[add_op_idx];
3958 opsel[2] = valu.opsel[add_op_idx];
3959 opsel[3] = valu.opsel[3];
3960 omod = valu.omod;
3961 clamp = valu.clamp;
3962 /* abs of the multiplication result */
3963 if (valu.abs[mul_op_idx]) {
3964 neg[0] = false;
3965 neg[1] = false;
3966 abs[0] = true;
3967 abs[1] = true;
3968 }
3969 /* neg of the multiplication result */
3970 neg[1] ^= valu.neg[mul_op_idx];
3971
3972 if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
3973 neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
3974 else if (instr->opcode == aco_opcode::v_subrev_f32 ||
3975 instr->opcode == aco_opcode::v_subrev_f16)
3976 neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
3977
3978 aco_ptr<Instruction> add_instr = std::move(instr);
3979 aco_ptr<Instruction> mad;
3980 if (add_instr->isVOP3P() || mul_instr->isVOP3P()) {
3981 assert(!omod);
3982 assert(!opsel);
3983
3984 aco_opcode mad_op = add_instr->definitions[0].bytes() == 2 ? aco_opcode::v_fma_mixlo_f16
3985 : aco_opcode::v_fma_mix_f32;
3986 mad.reset(create_instruction(mad_op, Format::VOP3P, 3, 1));
3987 } else {
3988 assert(!opsel_lo);
3989 assert(!opsel_hi);
3990
3991 aco_opcode mad_op = emit_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
3992 if (mul_instr->opcode == aco_opcode::v_mul_legacy_f32) {
3993 assert(emit_fma == (ctx.program->gfx_level >= GFX10_3));
3994 mad_op = emit_fma ? aco_opcode::v_fma_legacy_f32 : aco_opcode::v_mad_legacy_f32;
3995 } else if (mad16) {
3996 mad_op = emit_fma ? (ctx.program->gfx_level == GFX8 ? aco_opcode::v_fma_legacy_f16
3997 : aco_opcode::v_fma_f16)
3998 : (ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_f16
3999 : aco_opcode::v_mad_f16);
4000 } else if (mad64) {
4001 mad_op = aco_opcode::v_fma_f64;
4002 }
4003
4004 mad.reset(create_instruction(mad_op, Format::VOP3, 3, 1));
4005 }
4006
4007 for (unsigned i = 0; i < 3; i++) {
4008 mad->operands[i] = op[i];
4009 mad->valu().neg[i] = neg[i];
4010 mad->valu().abs[i] = abs[i];
4011 }
4012 mad->valu().omod = omod;
4013 mad->valu().clamp = clamp;
4014 mad->valu().opsel_lo = opsel_lo;
4015 mad->valu().opsel_hi = opsel_hi;
4016 mad->valu().opsel = opsel;
4017 mad->definitions[0] = add_instr->definitions[0];
4018 mad->definitions[0].setPrecise(add_instr->definitions[0].isPrecise() ||
4019 mul_instr->definitions[0].isPrecise());
4020 mad->pass_flags = add_instr->pass_flags;
4021
4022 instr = std::move(mad);
4023
4024 /* mark this ssa_def to be re-checked for profitability and literals */
4025 ctx.mad_infos.emplace_back(std::move(add_instr), mul_instr->definitions[0].tempId());
4026 ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4027 return;
4028 }
4029 }
4030 /* v_mul_f32(v_cndmask_b32(0, 1.0, cond), a) -> v_cndmask_b32(0, a, cond) */
4031 else if (((instr->opcode == aco_opcode::v_mul_f32 && !instr->definitions[0].isNaNPreserve() &&
4032 !instr->definitions[0].isInfPreserve()) ||
4033 (instr->opcode == aco_opcode::v_mul_legacy_f32 &&
4034 !instr->definitions[0].isSZPreserve())) &&
4035 !instr->usesModifiers() && !ctx.fp_mode.must_flush_denorms32) {
4036 for (unsigned i = 0; i < 2; i++) {
4037 if (instr->operands[i].isTemp() && ctx.info[instr->operands[i].tempId()].is_b2f() &&
4038 ctx.uses[instr->operands[i].tempId()] == 1 && instr->operands[!i].isTemp() &&
4039 instr->operands[!i].getTemp().type() == RegType::vgpr) {
4040 ctx.uses[instr->operands[i].tempId()]--;
4041 ctx.uses[ctx.info[instr->operands[i].tempId()].temp.id()]++;
4042
4043 aco_ptr<Instruction> new_instr{
4044 create_instruction(aco_opcode::v_cndmask_b32, Format::VOP2, 3, 1)};
4045 new_instr->operands[0] = Operand::zero();
4046 new_instr->operands[1] = instr->operands[!i];
4047 new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
4048 new_instr->definitions[0] = instr->definitions[0];
4049 new_instr->pass_flags = instr->pass_flags;
4050 instr = std::move(new_instr);
4051 ctx.info[instr->definitions[0].tempId()].label = 0;
4052 return;
4053 }
4054 }
4055 } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->gfx_level >= GFX9) {
4056 if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012",
4057 1 | 2)) {
4058 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32,
4059 "012", 1 | 2)) {
4060 } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4061 } else if (combine_v_andor_not(ctx, instr)) {
4062 }
4063 } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->gfx_level >= GFX10) {
4064 if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012",
4065 1 | 2)) {
4066 } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32,
4067 "012", 1 | 2)) {
4068 } else if (combine_xor_not(ctx, instr)) {
4069 }
4070 } else if (instr->opcode == aco_opcode::v_not_b32 && ctx.program->gfx_level >= GFX10) {
4071 combine_not_xor(ctx, instr);
4072 } else if (instr->opcode == aco_opcode::v_add_u16 && !instr->valu().clamp) {
4073 combine_three_valu_op(
4074 ctx, instr, aco_opcode::v_mul_lo_u16,
4075 ctx.program->gfx_level == GFX8 ? aco_opcode::v_mad_legacy_u16 : aco_opcode::v_mad_u16,
4076 "120", 1 | 2);
4077 } else if (instr->opcode == aco_opcode::v_add_u16_e64 && !instr->valu().clamp) {
4078 combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16_e64, aco_opcode::v_mad_u16, "120",
4079 1 | 2);
4080 } else if (instr->opcode == aco_opcode::v_add_u32 && !instr->usesModifiers()) {
4081 if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4082 } else if (combine_add_bcnt(ctx, instr)) {
4083 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4084 aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4085 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_i32_i24,
4086 aco_opcode::v_mad_i32_i24, "120", 1 | 2)) {
4087 } else if (ctx.program->gfx_level >= GFX9) {
4088 if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120",
4089 1 | 2)) {
4090 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32,
4091 "120", 1 | 2)) {
4092 } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32,
4093 "012", 1 | 2)) {
4094 } else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32,
4095 "012", 1 | 2)) {
4096 } else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32,
4097 "012", 1 | 2)) {
4098 } else if (combine_add_or_then_and_lshl(ctx, instr)) {
4099 }
4100 }
4101 } else if ((instr->opcode == aco_opcode::v_add_co_u32 ||
4102 instr->opcode == aco_opcode::v_add_co_u32_e64) &&
4103 !instr->usesModifiers()) {
4104 bool carry_out = ctx.uses[instr->definitions[1].tempId()] > 0;
4105 if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) {
4106 } else if (!carry_out && combine_add_bcnt(ctx, instr)) {
4107 } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_u32_u24,
4108 aco_opcode::v_mad_u32_u24, "120", 1 | 2)) {
4109 } else if (!carry_out && combine_three_valu_op(ctx, instr, aco_opcode::v_mul_i32_i24,
4110 aco_opcode::v_mad_i32_i24, "120", 1 | 2)) {
4111 } else if (!carry_out && combine_add_lshl(ctx, instr, false)) {
4112 }
4113 } else if (instr->opcode == aco_opcode::v_sub_u32 || instr->opcode == aco_opcode::v_sub_co_u32 ||
4114 instr->opcode == aco_opcode::v_sub_co_u32_e64) {
4115 bool carry_out =
4116 instr->opcode != aco_opcode::v_sub_u32 && ctx.uses[instr->definitions[1].tempId()] > 0;
4117 if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2)) {
4118 } else if (!carry_out && combine_add_lshl(ctx, instr, true)) {
4119 }
4120 } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
4121 instr->opcode == aco_opcode::v_subrev_co_u32 ||
4122 instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
4123 combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
4124 } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->gfx_level >= GFX9) {
4125 combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120",
4126 2);
4127 } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) &&
4128 ctx.program->gfx_level >= GFX9) {
4129 combine_salu_lshl_add(ctx, instr);
4130 } else if (instr->opcode == aco_opcode::s_not_b32 || instr->opcode == aco_opcode::s_not_b64) {
4131 if (!combine_salu_not_bitwise(ctx, instr))
4132 combine_inverse_comparison(ctx, instr);
4133 } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
4134 instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
4135 combine_salu_n2(ctx, instr);
4136 } else if (instr->opcode == aco_opcode::s_abs_i32) {
4137 combine_sabsdiff(ctx, instr);
4138 } else if (instr->opcode == aco_opcode::v_and_b32) {
4139 if (combine_and_subbrev(ctx, instr)) {
4140 } else if (combine_v_andor_not(ctx, instr)) {
4141 }
4142 } else if (instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) {
4143 /* set existing v_fma_f32 with label_mad so we can create v_fmamk_f32/v_fmaak_f32.
4144 * since ctx.uses[mad_info::mul_temp_id] is always 0, we don't have to worry about
4145 * select_instruction() using mad_info::add_instr.
4146 */
4147 ctx.mad_infos.emplace_back(nullptr, 0);
4148 ctx.info[instr->definitions[0].tempId()].set_mad(ctx.mad_infos.size() - 1);
4149 } else if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) {
4150 /* Optimize v_med3 to v_add so that it can be dual issued on GFX11. We start with v_med3 in
4151 * case omod can be applied.
4152 */
4153 unsigned idx;
4154 if (detect_clamp(instr.get(), &idx)) {
4155 instr->format = asVOP3(Format::VOP2);
4156 instr->operands[0] = instr->operands[idx];
4157 instr->operands[1] = Operand::zero();
4158 instr->opcode =
4159 instr->opcode == aco_opcode::v_med3_f32 ? aco_opcode::v_add_f32 : aco_opcode::v_add_f16;
4160 instr->valu().clamp = true;
4161 instr->valu().abs = (uint8_t)instr->valu().abs[idx];
4162 instr->valu().neg = (uint8_t)instr->valu().neg[idx];
4163 instr->operands.pop_back();
4164 }
4165 } else {
4166 aco_opcode min, max, min3, max3, med3, minmax;
4167 bool some_gfx9_only;
4168 if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &minmax,
4169 &some_gfx9_only) &&
4170 (!some_gfx9_only || ctx.program->gfx_level >= GFX9)) {
4171 if (combine_minmax(ctx, instr, instr->opcode == min ? max : min,
4172 instr->opcode == min ? min3 : max3, minmax)) {
4173 } else {
4174 combine_clamp(ctx, instr, min, max, med3);
4175 }
4176 }
4177 }
4178 }
4179
/* Most recent materialization of a constant, as tracked by rematerialize_constants(). */
struct remat_entry {
   /* The single-operand p_parallelcopy defining the constant: either the original instruction
    * or the copy most recently re-emitted by remat_constants_instr(). */
   Instruction* instr;
   /* Index of the block that contains instr. */
   uint32_t block;
};
4184
4185 inline bool
is_constant(Instruction * instr)4186 is_constant(Instruction* instr)
4187 {
4188 if (instr->opcode != aco_opcode::p_parallelcopy || instr->operands.size() != 1)
4189 return false;
4190
4191 return instr->operands[0].isConstant() && instr->definitions[0].isTemp();
4192 }
4193
4194 void
remat_constants_instr(opt_ctx & ctx,aco::map<Temp,remat_entry> & constants,Instruction * instr,uint32_t block_idx)4195 remat_constants_instr(opt_ctx& ctx, aco::map<Temp, remat_entry>& constants, Instruction* instr,
4196 uint32_t block_idx)
4197 {
4198 for (Operand& op : instr->operands) {
4199 if (!op.isTemp())
4200 continue;
4201
4202 auto it = constants.find(op.getTemp());
4203 if (it == constants.end())
4204 continue;
4205
4206 /* Check if we already emitted the same constant in this block. */
4207 if (it->second.block != block_idx) {
4208 /* Rematerialize the constant. */
4209 Builder bld(ctx.program, &ctx.instructions);
4210 Operand const_op = it->second.instr->operands[0];
4211 it->second.instr = bld.copy(bld.def(op.regClass()), const_op);
4212 it->second.block = block_idx;
4213 ctx.uses.push_back(0);
4214 ctx.info.push_back(ctx.info[op.tempId()]);
4215 }
4216
4217 /* Use the rematerialized constant and update information about latest use. */
4218 if (op.getTemp() != it->second.instr->definitions[0].getTemp()) {
4219 ctx.uses[op.tempId()]--;
4220 op.setTemp(it->second.instr->definitions[0].getTemp());
4221 ctx.uses[op.tempId()]++;
4222 }
4223 }
4224 }
4225
4226 /**
4227 * This pass implements a simple constant rematerialization.
4228 * As common subexpression elimination (CSE) might increase the live-ranges
4229 * of loaded constants over large distances, this pass splits the live-ranges
4230 * again by re-emitting constants in every basic block.
4231 */
4232 void
rematerialize_constants(opt_ctx & ctx)4233 rematerialize_constants(opt_ctx& ctx)
4234 {
4235 aco::monotonic_buffer_resource memory(1024);
4236 aco::map<Temp, remat_entry> constants(memory);
4237
4238 for (Block& block : ctx.program->blocks) {
4239 if (block.logical_idom == -1)
4240 continue;
4241
4242 if (block.logical_idom == (int)block.index)
4243 constants.clear();
4244
4245 ctx.instructions.reserve(block.instructions.size());
4246
4247 for (aco_ptr<Instruction>& instr : block.instructions) {
4248 if (is_dead(ctx.uses, instr.get()))
4249 continue;
4250
4251 if (is_constant(instr.get())) {
4252 Temp tmp = instr->definitions[0].getTemp();
4253 constants[tmp] = {instr.get(), block.index};
4254 } else if (!is_phi(instr)) {
4255 remat_constants_instr(ctx, constants, instr.get(), block.index);
4256 }
4257
4258 ctx.instructions.emplace_back(instr.release());
4259 }
4260
4261 block.instructions = std::move(ctx.instructions);
4262 }
4263 }
4264
4265 bool
to_uniform_bool_instr(opt_ctx & ctx,aco_ptr<Instruction> & instr)4266 to_uniform_bool_instr(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4267 {
4268 /* Check every operand to make sure they are suitable. */
4269 for (Operand& op : instr->operands) {
4270 if (!op.isTemp())
4271 return false;
4272 if (!ctx.info[op.tempId()].is_uniform_bool() && !ctx.info[op.tempId()].is_uniform_bitwise())
4273 return false;
4274 }
4275
4276 switch (instr->opcode) {
4277 case aco_opcode::s_and_b32:
4278 case aco_opcode::s_and_b64: instr->opcode = aco_opcode::s_and_b32; break;
4279 case aco_opcode::s_or_b32:
4280 case aco_opcode::s_or_b64: instr->opcode = aco_opcode::s_or_b32; break;
4281 case aco_opcode::s_xor_b32:
4282 case aco_opcode::s_xor_b64: instr->opcode = aco_opcode::s_absdiff_i32; break;
4283 default:
4284 /* Don't transform other instructions. They are very unlikely to appear here. */
4285 return false;
4286 }
4287
4288 for (Operand& op : instr->operands) {
4289 ctx.uses[op.tempId()]--;
4290
4291 if (ctx.info[op.tempId()].is_uniform_bool()) {
4292 /* Just use the uniform boolean temp. */
4293 op.setTemp(ctx.info[op.tempId()].temp);
4294 } else if (ctx.info[op.tempId()].is_uniform_bitwise()) {
4295 /* Use the SCC definition of the predecessor instruction.
4296 * This allows the predecessor to get picked up by the same optimization (if it has no
4297 * divergent users), and it also makes sure that the current instruction will keep working
4298 * even if the predecessor won't be transformed.
4299 */
4300 Instruction* pred_instr = ctx.info[op.tempId()].instr;
4301 assert(pred_instr->definitions.size() >= 2);
4302 assert(pred_instr->definitions[1].isFixed() &&
4303 pred_instr->definitions[1].physReg() == scc);
4304 op.setTemp(pred_instr->definitions[1].getTemp());
4305 } else {
4306 unreachable("Invalid operand on uniform bitwise instruction.");
4307 }
4308
4309 ctx.uses[op.tempId()]++;
4310 }
4311
4312 instr->definitions[0].setTemp(Temp(instr->definitions[0].tempId(), s1));
4313 ctx.program->temp_rc[instr->definitions[0].tempId()] = s1;
4314 assert(instr->operands[0].regClass() == s1);
4315 assert(instr->operands[1].regClass() == s1);
4316 return true;
4317 }
4318
/**
 * Instruction selection, run on every instruction in reverse program order (step 4 of
 * optimize()).
 *
 * Besides removing dead instructions, this decides which rewrites prepared by earlier passes
 * are actually performed. Literal selection is communicated to apply_literals() only through
 * use counts: a literal is "selected" by decrementing the use count of the temporary holding
 * it, and apply_literals() substitutes literals whose temporaries have no uses left. For
 * mad-labelled instructions, the chosen operands are also recorded in mad_info.
 */
void
select_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Literals whose temporaries still have this many uses (or more) are not selected, except
    * for v_madak/v_fmaak below, to limit code-size growth. */
   const uint32_t threshold = 4;

   /* Drop instructions without used definitions; apply_literals() skips null entries. */
   if (is_dead(ctx.uses, instr.get())) {
      instr.reset();
      return;
   }

   /* convert split_vector into a copy or extract_vector if only one definition is ever used */
   if (instr->opcode == aco_opcode::p_split_vector) {
      /* Find the used definition (idx) and its byte offset within operand 0 (split_offset). */
      unsigned num_used = 0;
      unsigned idx = 0;
      unsigned split_offset = 0;
      for (unsigned i = 0, offset = 0; i < instr->definitions.size();
           offset += instr->definitions[i++].bytes()) {
         if (ctx.uses[instr->definitions[i].tempId()]) {
            num_used++;
            idx = i;
            split_offset = offset;
         }
      }
      bool done = false;
      /* If the source is an is_vec-labelled instruction used only by this split, forward the
       * vector operand that starts at split_offset directly, provided its size matches. */
      if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
          ctx.uses[instr->operands[0].tempId()] == 1) {
         Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;

         unsigned off = 0;
         Operand op;
         for (Operand& vec_op : vec->operands) {
            if (off == split_offset) {
               op = vec_op;
               break;
            }
            off += vec_op.bytes();
         }
         if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
            /* The vector becomes dead: release its uses, then add one for the forwarded op. */
            ctx.uses[instr->operands[0].tempId()]--;
            for (Operand& vec_op : vec->operands) {
               if (vec_op.isTemp())
                  ctx.uses[vec_op.tempId()]--;
            }
            if (op.isTemp())
               ctx.uses[op.tempId()]++;

            aco_ptr<Instruction> copy{
               create_instruction(aco_opcode::p_parallelcopy, Format::PSEUDO, 1, 1)};
            copy->operands[0] = op;
            copy->definitions[0] = instr->definitions[idx];
            instr = std::move(copy);

            done = true;
         }
      }

      /* Otherwise use p_extract_vector when the used definition is aligned to its own size. */
      if (!done && num_used == 1 &&
          instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
          split_offset % instr->definitions[idx].bytes() == 0) {
         aco_ptr<Instruction> extract{
            create_instruction(aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
         extract->operands[0] = instr->operands[0];
         extract->operands[1] =
            Operand::c32((uint32_t)split_offset / instr->definitions[idx].bytes());
         extract->definitions[0] = instr->definitions[idx];
         instr = std::move(extract);
      }
   }

   mad_info* mad_info = NULL;
   if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
      mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
      /* re-check mad instructions */
      /* If the multiplication result is still used elsewhere, revert to the original add:
       * restore the mul result's use, release the uses held by operands 0 and 1 of the fused
       * instruction, and swap the saved add instruction back in. */
      if (ctx.uses[mad_info->mul_temp_id] && mad_info->add_instr) {
         ctx.uses[mad_info->mul_temp_id]++;
         if (instr->operands[0].isTemp())
            ctx.uses[instr->operands[0].tempId()]--;
         if (instr->operands[1].isTemp())
            ctx.uses[instr->operands[1].tempId()]--;
         instr.swap(mad_info->add_instr);
         mad_info = NULL;
      }
      /* check literals */
      else if (!instr->isDPP() && !instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_f64 &&
               instr->opcode != aco_opcode::v_mad_legacy_f32 &&
               instr->opcode != aco_opcode::v_fma_legacy_f32) {
         /* FMA can only take literals on GFX10+ */
         if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
             ctx.program->gfx_level < GFX10)
            return;
         /* There are no v_fmaak_legacy_f16/v_fmamk_legacy_f16 and on chips where VOP3 can take
          * literals (GFX10+), these instructions don't exist.
          */
         if (instr->opcode == aco_opcode::v_fma_legacy_f16)
            return;

         /* Bit i of literal_mask: operand i holds the chosen 32-bit literal.
          * Bit i of fp16_mask: operand i holds a literal exactly representable as fp16.
          * sgpr_mask/vgpr_mask: register types of the operands not folded into literal_mask. */
         uint32_t literal_mask = 0;
         uint32_t fp16_mask = 0;
         uint32_t sgpr_mask = 0;
         uint32_t vgpr_mask = 0;
         uint32_t literal_uses = UINT32_MAX;
         uint32_t literal_value = 0;

         /* Iterate in reverse to prefer v_madak/v_fmaak. */
         for (int i = 2; i >= 0; i--) {
            Operand& op = instr->operands[i];
            if (!op.isTemp())
               continue;
            if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) {
               uint32_t new_literal = ctx.info[op.tempId()].val;
               float value = uif(new_literal);
               uint16_t fp16_val = _mesa_float_to_half(value);
               bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
               if (_mesa_half_to_float(fp16_val) == value &&
                   (!is_denorm || (ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
                  fp16_mask |= 1 << i;

               /* Only one distinct 32-bit literal value can be folded. */
               if (!literal_mask || literal_value == new_literal) {
                  literal_value = new_literal;
                  literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]);
                  literal_mask |= 1 << i;
                  continue;
               }
            }
            sgpr_mask |= op.isOfType(RegType::sgpr) << i;
            vgpr_mask |= op.isOfType(RegType::vgpr) << i;
         }

         /* The constant bus limitations before GFX10 disallows SGPRs. */
         if (sgpr_mask && ctx.program->gfx_level < GFX10)
            literal_mask = 0;

         /* Encoding needs a vgpr. */
         if (!vgpr_mask)
            literal_mask = 0;

         /* v_madmk/v_fmamk needs a vgpr in the third source. */
         if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100))
            literal_mask = 0;

         /* opsel with GFX11+ is the only modifier supported by fmamk/fmaak*/
         if (instr->valu().abs || instr->valu().neg || instr->valu().omod || instr->valu().clamp ||
             (instr->valu().opsel && ctx.program->gfx_level < GFX11))
            literal_mask = 0;

         if (instr->valu().opsel & ~vgpr_mask)
            literal_mask = 0;

         /* We can't use three unique fp16 literals */
         if (fp16_mask == 0b111)
            fp16_mask = 0b11;

         /* Prefer v_fma_mix_f32 with packed fp16 literals when that folds more operands than
          * the single 32-bit literal would; apply_literals() reads mad_info->fp16_mask. */
         if ((instr->opcode == aco_opcode::v_fma_f32 ||
              (instr->opcode == aco_opcode::v_mad_f32 && !instr->definitions[0].isPrecise())) &&
             !instr->valu().omod && ctx.program->gfx_level >= GFX10 &&
             util_bitcount(fp16_mask) > std::max<uint32_t>(util_bitcount(literal_mask), 1)) {
            assert(ctx.program->dev.fused_mad_mix);
            u_foreach_bit (i, fp16_mask)
               ctx.uses[instr->operands[i].tempId()]--;
            mad_info->fp16_mask = fp16_mask;
            return;
         }

         /* Limit the number of literals to apply to not increase the code
          * size too much, but always apply literals for v_mad->v_madak
          * because both instructions are 64-bit and this doesn't increase
          * code size.
          * TODO: try to apply the literals earlier to lower the number of
          * uses below threshold
          */
         if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) {
            u_foreach_bit (i, literal_mask)
               ctx.uses[instr->operands[i].tempId()]--;
            mad_info->literal_mask = literal_mask;
            return;
         }
      }
   }

   /* Mark SCC needed, so the uniform boolean transformation won't swap the definitions
    * when it isn't beneficial */
   if (instr->isBranch() && instr->operands.size() && instr->operands[0].isTemp() &&
       instr->operands[0].isFixed() && instr->operands[0].physReg() == scc) {
      ctx.info[instr->operands[0].tempId()].set_scc_needed();
      return;
   } else if ((instr->opcode == aco_opcode::s_cselect_b64 ||
               instr->opcode == aco_opcode::s_cselect_b32) &&
              instr->operands[2].isTemp()) {
      ctx.info[instr->operands[2].tempId()].set_scc_needed();
   }

   /* check for literals */
   if (!instr->isSALU() && !instr->isVALU())
      return;

   /* Transform uniform bitwise boolean operations to 32-bit when there are no divergent uses. */
   if (instr->definitions.size() && ctx.uses[instr->definitions[0].tempId()] == 0 &&
       ctx.info[instr->definitions[0].tempId()].is_uniform_bitwise()) {
      bool transform_done = to_uniform_bool_instr(ctx, instr);

      if (transform_done && !ctx.info[instr->definitions[1].tempId()].is_scc_needed()) {
         /* Swap the two definition IDs in order to avoid overusing the SCC.
          * This reduces extra moves generated by RA. */
         uint32_t def0_id = instr->definitions[0].getTemp().id();
         uint32_t def1_id = instr->definitions[1].getTemp().id();
         instr->definitions[0].setTemp(Temp(def1_id, s1));
         instr->definitions[1].setTemp(Temp(def0_id, s1));
      }

      return;
   }

   /* This optimization is done late in order to be able to apply otherwise
    * unsafe optimizations such as the inverse comparison optimization.
    */
   /* s_and(x, exec) with x used only here and SCC unused: if can_eliminate_and_exec() allows
    * it, let x's defining instruction write this definition directly and drop the s_and. */
   if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_and_b64) {
      if (instr->operands[0].isTemp() && fixed_to_exec(instr->operands[1]) &&
          ctx.uses[instr->operands[0].tempId()] == 1 &&
          ctx.uses[instr->definitions[1].tempId()] == 0 &&
          can_eliminate_and_exec(ctx, instr->operands[0].getTemp(), instr->pass_flags)) {
         ctx.uses[instr->operands[0].tempId()]--;
         ctx.info[instr->operands[0].tempId()].instr->definitions[0].setTemp(
            instr->definitions[0].getTemp());
         instr.reset();
         return;
      }
   }

   /* Combine DPP copies into VALU. This should be done after creating MAD/FMA. */
   if (instr->isVALU() && !instr->isDPP()) {
      for (unsigned i = 0; i < instr->operands.size(); i++) {
         if (!instr->operands[i].isTemp())
            continue;
         ssa_info info = ctx.info[instr->operands[i].tempId()];

         if (!info.is_dpp() || info.instr->pass_flags != instr->pass_flags)
            continue;

         /* We won't eliminate the DPP mov if the operand is used twice */
         bool op_used_twice = false;
         for (unsigned j = 0; j < instr->operands.size(); j++)
            op_used_twice |= i != j && instr->operands[i] == instr->operands[j];
         if (op_used_twice)
            continue;

         /* The DPP mov is folded into operand 0, so move it there first if possible. */
         if (i != 0) {
            if (!can_swap_operands(instr, &instr->opcode, 0, i))
               continue;
            instr->valu().swapOperands(0, i);
         }

         if (!can_use_DPP(ctx.program->gfx_level, instr, info.is_dpp8()))
            continue;

         /* neg/abs on the mov can only be carried over if this instruction accepts input
          * modifiers on a 32-bit operand 0 (and, for DPP8, only on GFX11+). */
         bool dpp8 = info.is_dpp8();
         bool input_mods = can_use_input_modifiers(ctx.program->gfx_level, instr->opcode, 0) &&
                           get_operand_size(instr, 0) == 32;
         bool mov_uses_mods = info.instr->valu().neg[0] || info.instr->valu().abs[0];
         if (((dpp8 && ctx.program->gfx_level < GFX11) || !input_mods) && mov_uses_mods)
            continue;

         convert_to_DPP(ctx.program->gfx_level, instr, dpp8);

         if (dpp8) {
            DPP8_instruction* dpp = &instr->dpp8();
            dpp->lane_sel = info.instr->dpp8().lane_sel;
            dpp->fetch_inactive = info.instr->dpp8().fetch_inactive;
            if (mov_uses_mods)
               instr->format = asVOP3(instr->format);
         } else {
            DPP16_instruction* dpp = &instr->dpp16();
            dpp->dpp_ctrl = info.instr->dpp16().dpp_ctrl;
            dpp->bound_ctrl = info.instr->dpp16().bound_ctrl;
            dpp->fetch_inactive = info.instr->dpp16().fetch_inactive;
         }

         /* Compose the mov's modifiers with ours: an existing abs absorbs the mov's neg. */
         instr->valu().neg[0] ^= info.instr->valu().neg[0] && !instr->valu().abs[0];
         instr->valu().abs[0] |= info.instr->valu().abs[0];

         /* If the mov stays alive, its source gains a use from this instruction; if the mov
          * becomes dead, its use of the source transfers here (net unchanged). */
         if (--ctx.uses[info.instr->definitions[0].tempId()])
            ctx.uses[info.instr->operands[0].tempId()]++;
         instr->operands[0].setTemp(info.instr->operands[0].getTemp());
         break;
      }
   }

   /* Use v_fma_mix for f2f32/f2f16 if it has higher throughput.
    * Do this late to not disturb other optimizations.
    */
   /* The conversion is expressed as src * 1.0 + (-0). */
   if ((instr->opcode == aco_opcode::v_cvt_f32_f16 || instr->opcode == aco_opcode::v_cvt_f16_f32) &&
       ctx.program->gfx_level >= GFX11 && ctx.program->wave_size == 64 && !instr->valu().omod &&
       !instr->isDPP()) {
      bool is_f2f16 = instr->opcode == aco_opcode::v_cvt_f16_f32;
      Instruction* fma = create_instruction(
         is_f2f16 ? aco_opcode::v_fma_mixlo_f16 : aco_opcode::v_fma_mix_f32, Format::VOP3P, 3, 1);
      fma->definitions[0] = instr->definitions[0];
      fma->operands[0] = instr->operands[0];
      fma->valu().opsel_hi[0] = !is_f2f16;
      fma->valu().opsel_lo[0] = instr->valu().opsel[0];
      fma->valu().clamp = instr->valu().clamp;
      fma->valu().abs[0] = instr->valu().abs[0];
      fma->valu().neg[0] = instr->valu().neg[0];
      fma->operands[1] = Operand::c32(fui(1.0f));
      fma->operands[2] = Operand::zero();
      fma->valu().neg[2] = true;
      instr.reset(fma);
      ctx.info[instr->definitions[0].tempId()].label = 0;
   }

   if (instr->isSDWA() || (instr->isVOP3() && ctx.program->gfx_level < GFX10) ||
       (instr->isVOP3P() && ctx.program->gfx_level < GFX10))
      return; /* some encodings can't ever take literals */

   /* we do not apply the literals yet as we don't know if it is profitable */
   Operand current_literal(s1);

   unsigned literal_id = 0;
   unsigned literal_uses = UINT32_MAX;
   Operand literal(s1);
   /* Number of leading operands that are candidates for a literal. */
   unsigned num_operands = 1;
   if (instr->isSALU() || (ctx.program->gfx_level >= GFX10 &&
                           (can_use_VOP3(ctx, instr) || instr->isVOP3P()) && !instr->isDPP()))
      num_operands = instr->operands.size();
   /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
   else if (instr->isVALU() && instr->operands.size() >= 3)
      return;

   /* Up to two distinct SGPR temporaries read by a VALU, for the constant bus check below. */
   unsigned sgpr_ids[2] = {0, 0};
   bool is_literal_sgpr = false;
   /* Operands that reference the chosen literal temporary (literal_id). */
   uint32_t mask = 0;

   /* choose a literal to apply */
   /* The candidate with the fewest remaining uses wins. */
   for (unsigned i = 0; i < num_operands; i++) {
      Operand op = instr->operands[i];
      unsigned bits = get_operand_size(instr, i);

      if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
          op.tempId() != sgpr_ids[0])
         sgpr_ids[!!sgpr_ids[0]] = op.tempId();

      if (op.isLiteral()) {
         current_literal = op;
         continue;
      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
         continue;
      }

      if (!alu_can_accept_constant(instr, i))
         continue;

      if (ctx.uses[op.tempId()] < literal_uses) {
         is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
         mask = 0;
         literal = Operand::c32(ctx.info[op.tempId()].val);
         literal_uses = ctx.uses[op.tempId()];
         literal_id = op.tempId();
      }

      mask |= (op.tempId() == literal_id) << i;
   }

   /* don't go over the constant bus limit */
   bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64_e64 ||
                     instr->opcode == aco_opcode::v_lshlrev_b64 ||
                     instr->opcode == aco_opcode::v_lshrrev_b64 ||
                     instr->opcode == aco_opcode::v_ashrrev_i64;
   unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
   if (ctx.program->gfx_level >= GFX10 && !is_shift64)
      const_bus_limit = 2;

   /* The literal would take a constant bus slot; that is fine only if it replaces an SGPR. */
   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
   if (num_sgprs == const_bus_limit && !is_literal_sgpr)
      return;

   /* Select the literal if it is below the use threshold and does not conflict with a literal
    * the instruction already has. */
   if (literal_id && literal_uses < threshold &&
       (current_literal.isUndefined() ||
        (current_literal.size() == literal.size() &&
         current_literal.constantValue() == literal.constantValue()))) {
      /* mark the literal to be applied */
      while (mask) {
         unsigned i = u_bit_scan(&mask);
         if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
            ctx.uses[instr->operands[i].tempId()]--;
      }
   }
}
4705
4706 static aco_opcode
sopk_opcode_for_sopc(aco_opcode opcode)4707 sopk_opcode_for_sopc(aco_opcode opcode)
4708 {
4709 #define CTOK(op) \
4710 case aco_opcode::s_cmp_##op##_i32: return aco_opcode::s_cmpk_##op##_i32; \
4711 case aco_opcode::s_cmp_##op##_u32: return aco_opcode::s_cmpk_##op##_u32;
4712 switch (opcode) {
4713 CTOK(eq)
4714 CTOK(lg)
4715 CTOK(gt)
4716 CTOK(ge)
4717 CTOK(lt)
4718 CTOK(le)
4719 default: return aco_opcode::num_opcodes;
4720 }
4721 #undef CTOK
4722 }
4723
4724 static bool
sopc_is_signed(aco_opcode opcode)4725 sopc_is_signed(aco_opcode opcode)
4726 {
4727 #define SOPC(op) \
4728 case aco_opcode::s_cmp_##op##_i32: return true; \
4729 case aco_opcode::s_cmp_##op##_u32: return false;
4730 switch (opcode) {
4731 SOPC(eq)
4732 SOPC(lg)
4733 SOPC(gt)
4734 SOPC(ge)
4735 SOPC(lt)
4736 SOPC(le)
4737 default: unreachable("Not a valid SOPC instruction.");
4738 }
4739 #undef SOPC
4740 }
4741
4742 static aco_opcode
sopc_32_swapped(aco_opcode opcode)4743 sopc_32_swapped(aco_opcode opcode)
4744 {
4745 #define SOPC(op1, op2) \
4746 case aco_opcode::s_cmp_##op1##_i32: return aco_opcode::s_cmp_##op2##_i32; \
4747 case aco_opcode::s_cmp_##op1##_u32: return aco_opcode::s_cmp_##op2##_u32;
4748 switch (opcode) {
4749 SOPC(eq, eq)
4750 SOPC(lg, lg)
4751 SOPC(gt, lt)
4752 SOPC(ge, le)
4753 SOPC(lt, gt)
4754 SOPC(le, ge)
4755 default: return aco_opcode::num_opcodes;
4756 }
4757 #undef SOPC
4758 }
4759
4760 static void
try_convert_sopc_to_sopk(aco_ptr<Instruction> & instr)4761 try_convert_sopc_to_sopk(aco_ptr<Instruction>& instr)
4762 {
4763 if (sopk_opcode_for_sopc(instr->opcode) == aco_opcode::num_opcodes)
4764 return;
4765
4766 if (instr->operands[0].isLiteral()) {
4767 std::swap(instr->operands[0], instr->operands[1]);
4768 instr->opcode = sopc_32_swapped(instr->opcode);
4769 }
4770
4771 if (!instr->operands[1].isLiteral())
4772 return;
4773
4774 if (instr->operands[0].isFixed() && instr->operands[0].physReg() >= 128)
4775 return;
4776
4777 uint32_t value = instr->operands[1].constantValue();
4778
4779 const uint32_t i16_mask = 0xffff8000u;
4780
4781 bool value_is_i16 = (value & i16_mask) == 0 || (value & i16_mask) == i16_mask;
4782 bool value_is_u16 = !(value & 0xffff0000u);
4783
4784 if (!value_is_i16 && !value_is_u16)
4785 return;
4786
4787 if (!value_is_i16 && sopc_is_signed(instr->opcode)) {
4788 if (instr->opcode == aco_opcode::s_cmp_lg_i32)
4789 instr->opcode = aco_opcode::s_cmp_lg_u32;
4790 else if (instr->opcode == aco_opcode::s_cmp_eq_i32)
4791 instr->opcode = aco_opcode::s_cmp_eq_u32;
4792 else
4793 return;
4794 } else if (!value_is_u16 && !sopc_is_signed(instr->opcode)) {
4795 if (instr->opcode == aco_opcode::s_cmp_lg_u32)
4796 instr->opcode = aco_opcode::s_cmp_lg_i32;
4797 else if (instr->opcode == aco_opcode::s_cmp_eq_u32)
4798 instr->opcode = aco_opcode::s_cmp_eq_i32;
4799 else
4800 return;
4801 }
4802
4803 instr->format = Format::SOPK;
4804 SALU_instruction* instr_sopk = &instr->salu();
4805
4806 instr_sopk->imm = instr_sopk->operands[1].constantValue() & 0xffff;
4807 instr_sopk->opcode = sopk_opcode_for_sopc(instr_sopk->opcode);
4808 instr_sopk->operands.pop_back();
4809 }
4810
4811 static void
opt_fma_mix_acc(opt_ctx & ctx,aco_ptr<Instruction> & instr)4812 opt_fma_mix_acc(opt_ctx& ctx, aco_ptr<Instruction>& instr)
4813 {
4814 /* fma_mix is only dual issued on gfx11 if dst and acc type match */
4815 bool f2f16 = instr->opcode == aco_opcode::v_fma_mixlo_f16;
4816
4817 if (instr->valu().opsel_hi[2] == f2f16 || instr->isDPP())
4818 return;
4819
4820 bool is_add = false;
4821 for (unsigned i = 0; i < 2; i++) {
4822 uint32_t one = instr->valu().opsel_hi[i] ? 0x3800 : 0x3f800000;
4823 is_add = instr->operands[i].constantEquals(one) && !instr->valu().neg[i] &&
4824 !instr->valu().opsel_lo[i];
4825 if (is_add) {
4826 instr->valu().swapOperands(0, i);
4827 break;
4828 }
4829 }
4830
4831 if (is_add && instr->valu().opsel_hi[1] == f2f16) {
4832 instr->valu().swapOperands(1, 2);
4833 return;
4834 }
4835
4836 unsigned literal_count = instr->operands[0].isLiteral() + instr->operands[1].isLiteral() +
4837 instr->operands[2].isLiteral();
4838
4839 if (!f2f16 || literal_count > 1)
4840 return;
4841
4842 /* try to convert constant operand to fp16 */
4843 for (unsigned i = 2 - is_add; i < 3; i++) {
4844 if (!instr->operands[i].isConstant())
4845 continue;
4846
4847 float value = uif(instr->operands[i].constantValue());
4848 uint16_t fp16_val = _mesa_float_to_half(value);
4849 bool is_denorm = (fp16_val & 0x7fff) != 0 && (fp16_val & 0x7fff) <= 0x3ff;
4850
4851 if (_mesa_half_to_float(fp16_val) != value ||
4852 (is_denorm && !(ctx.fp_mode.denorm16_64 & fp_denorm_keep_in)))
4853 continue;
4854
4855 instr->valu().swapOperands(i, 2);
4856
4857 Operand op16 = Operand::c16(fp16_val);
4858 assert(!op16.isLiteral() || instr->operands[2].isLiteral());
4859
4860 instr->operands[2] = op16;
4861 instr->valu().opsel_lo[2] = false;
4862 instr->valu().opsel_hi[2] = true;
4863 return;
4864 }
4865 }
4866
/**
 * Final pass (step 5 of optimize()): append each surviving instruction to ctx.instructions
 * and substitute the literals selected by select_instruction().
 *
 * A literal temporary whose use count has dropped to zero is replaced by an inline 32-bit
 * literal. Instructions labelled as mad may instead be rewritten to v_fma_mix_f32 with packed
 * fp16 literals (mad_info::fp16_mask), or to the VOP2 v_madak/v_madmk/v_fmaak/v_fmamk forms
 * (mad_info::literal_mask). Afterwards, SOPC->SOPK and fma_mix accumulator tweaks are applied.
 */
void
apply_literals(opt_ctx& ctx, aco_ptr<Instruction>& instr)
{
   /* Cleanup Dead Instructions */
   /* Instructions reset by select_instruction() are simply not carried over. */
   if (!instr)
      return;

   /* apply literals on MAD */
   if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
      /* madak: the literal is the addend (operand 2); otherwise madmk (literal multiplicand). */
      const bool madak = (info->literal_mask & 0b100);
      bool has_dead_literal = false;
      u_foreach_bit (i, info->literal_mask | info->fp16_mask)
         has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0;

      if (has_dead_literal && info->fp16_mask) {
         instr->format = Format::VOP3P;
         instr->opcode = aco_opcode::v_fma_mix_f32;

         /* Pack up to two fp16 values into one 32-bit literal: the first goes to the low half,
          * the second to the high half. opsel_lo selects the half, opsel_hi marks fp16. */
         uint32_t literal = 0;
         bool second = false;
         u_foreach_bit (i, info->fp16_mask) {
            float value = uif(ctx.info[instr->operands[i].tempId()].val);
            literal |= _mesa_float_to_half(value) << (second * 16);
            instr->valu().opsel_lo[i] = second;
            instr->valu().opsel_hi[i] = true;
            second = true;
         }

         for (unsigned i = 0; i < 3; i++) {
            if (info->fp16_mask & (1 << i))
               instr->operands[i] = Operand::literal32(literal);
         }

         ctx.instructions.emplace_back(std::move(instr));
         return;
      }

      if (has_dead_literal || madak) {
         aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
         if (instr->opcode == aco_opcode::v_fma_f32)
            new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
         else if (instr->opcode == aco_opcode::v_mad_f16 ||
                  instr->opcode == aco_opcode::v_mad_legacy_f16)
            new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
         else if (instr->opcode == aco_opcode::v_fma_f16)
            new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;

         /* All operands in literal_mask hold the same value; read it from the first one. */
         uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val;
         instr->format = Format::VOP2;
         instr->opcode = new_op;
         for (unsigned i = 0; i < 3; i++) {
            if (info->literal_mask & (1 << i))
               instr->operands[i] = Operand::literal32(literal);
         }
         if (madak) { /* add literal -> madak */
            /* Keep a VGPR in operand 1. */
            if (!instr->operands[1].isOfType(RegType::vgpr))
               instr->valu().swapOperands(0, 1);
         } else { /* mul literal -> madmk */
            /* Bring the literal multiplicand to operand 1, then move it into operand 2. */
            if (!(info->literal_mask & 0b10))
               instr->valu().swapOperands(0, 1);
            instr->valu().swapOperands(1, 2);
         }
         ctx.instructions.emplace_back(std::move(instr));
         return;
      }
   }

   /* apply literals on other SALU/VALU */
   if (instr->isSALU() || instr->isVALU()) {
      for (unsigned i = 0; i < instr->operands.size(); i++) {
         Operand op = instr->operands[i];
         unsigned bits = get_operand_size(instr, i);
         if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
            Operand literal = Operand::literal32(ctx.info[op.tempId()].val);
            /* DPP is dropped; a literal in a VALU operand other than 0 requires VOP3 (VOP3P
             * instructions keep their format). */
            instr->format = withoutDPP(instr->format);
            if (instr->isVALU() && i > 0 && instr->format != Format::VOP3P)
               instr->format = asVOP3(instr->format);
            instr->operands[i] = literal;
         }
      }
   }

   if (instr->isSOPC() && ctx.program->gfx_level < GFX12)
      try_convert_sopc_to_sopk(instr);

   if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || instr->opcode == aco_opcode::v_fma_mix_f32)
      opt_fma_mix_acc(ctx, instr);

   ctx.instructions.emplace_back(std::move(instr));
}
4958
4959 } /* end namespace */
4960
4961 void
optimize(Program * program)4962 optimize(Program* program)
4963 {
4964 opt_ctx ctx;
4965 ctx.program = program;
4966 ctx.info = std::vector<ssa_info>(program->peekAllocationId());
4967
4968 /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
4969 for (Block& block : program->blocks) {
4970 ctx.fp_mode = block.fp_mode;
4971 for (aco_ptr<Instruction>& instr : block.instructions)
4972 label_instruction(ctx, instr);
4973 }
4974
4975 ctx.uses = dead_code_analysis(program);
4976
4977 /* 2. Rematerialize constants in every block. */
4978 rematerialize_constants(ctx);
4979
4980 /* 3. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
4981 for (Block& block : program->blocks) {
4982 ctx.fp_mode = block.fp_mode;
4983 for (aco_ptr<Instruction>& instr : block.instructions)
4984 combine_instruction(ctx, instr);
4985 }
4986
4987 /* 4. Top-Down DAG pass (backward) to select instructions (includes DCE) */
4988 for (auto block_rit = program->blocks.rbegin(); block_rit != program->blocks.rend();
4989 ++block_rit) {
4990 Block* block = &(*block_rit);
4991 ctx.fp_mode = block->fp_mode;
4992 for (auto instr_rit = block->instructions.rbegin(); instr_rit != block->instructions.rend();
4993 ++instr_rit)
4994 select_instruction(ctx, *instr_rit);
4995 }
4996
4997 /* 5. Add literals to instructions */
4998 for (Block& block : program->blocks) {
4999 ctx.instructions.reserve(block.instructions.size());
5000 ctx.fp_mode = block.fp_mode;
5001 for (aco_ptr<Instruction>& instr : block.instructions)
5002 apply_literals(ctx, instr);
5003 block.instructions = std::move(ctx.instructions);
5004 }
5005 }
5006
5007 } // namespace aco
5008