xref: /aosp_15_r20/external/mesa3d/src/amd/compiler/aco_lower_to_cssa.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2019 Valve Corporation
3  *
4  * SPDX-License-Identifier: MIT
5  */
6 
7 #include "aco_builder.h"
8 #include "aco_ir.h"
9 
10 #include <algorithm>
11 #include <map>
12 #include <unordered_map>
13 #include <vector>
14 
/*
 * Implements an algorithm to lower to Conventional SSA Form (CSSA).
 * After "Revisiting Out-of-SSA Translation for Correctness, Code Quality, and Efficiency"
 * by B. Boissinot, A. Darte, F. Rastello, B. Dupont de Dinechin, C. Guillon.
 *
 * By lowering the IR to CSSA, the insertion of parallelcopies is separated from
 * the register coalescing problem. Additionally, correctness is ensured w.r.t. spilling.
 * The algorithm coalesces non-interfering phi-resources while taking value-equality
 * into account. It also re-indexes the SSA-defs.
 */
25 
26 namespace aco {
27 namespace {
28 
/* One merge set: a list of temporaries that are coalesced into the same resource.
 * The vector is kept ordered (see collect_parallelcopies/try_merge_merge_set). */
typedef std::vector<Temp> merge_set;

/* A single (virtual) parallelcopy: 'op' is copied into the fresh temporary 'def'
 * at the end of a phi's predecessor block. */
struct copy {
   Definition def;
   Operand op;
};
35 
/* Per-temporary bookkeeping for the coalescing algorithm (one entry per temp
 * participating in a merge set, keyed by temp-id in cssa_ctx::merge_node_table). */
struct merge_node {
   Operand value = Operand(); /* original value: can be an SSA-def or constant value */
   uint32_t index = -1u;      /* index into the vector of merge sets */
   uint32_t defined_at = -1u; /* defining block */

   /* We also remember two closest equal intersecting ancestors. Because they intersect with this
    * merge node, they must dominate it (intersection isn't possible otherwise) and have the same
    * value (or else they would not be allowed to be in the same merge set).
    */
   Temp equal_anc_in = Temp();  /* within the same merge set */
   Temp equal_anc_out = Temp(); /* from the other set we're currently trying to merge with */
};
48 
/* Shared state of the lowering pass. */
struct cssa_ctx {
   Program* program;
   std::vector<std::vector<copy>> parallelcopies; /* copies per block */
   std::vector<merge_set> merge_sets;             /* each vector is one (ordered) merge set */
   std::unordered_map<uint32_t, merge_node> merge_node_table; /* tempid -> merge node */
};
55 
/* create (virtual) parallelcopies for each phi instruction and
 * already merge copy-definitions with phi-defs into merge sets */
void
collect_parallelcopies(cssa_ctx& ctx)
{
   ctx.parallelcopies.resize(ctx.program->blocks.size());
   Builder bld(ctx.program);
   for (Block& block : ctx.program->blocks) {
      for (aco_ptr<Instruction>& phi : block.instructions) {
         /* phis are grouped at the top of the block: stop at the first non-phi */
         if (phi->opcode != aco_opcode::p_phi && phi->opcode != aco_opcode::p_linear_phi)
            break;

         const Definition& def = phi->definitions[0];

         /* if the definition is not temp, it is the exec mask.
          * We can reload the exec mask directly from the spill slot.
          */
         if (!def.isTemp() || def.isKill())
            continue;

         /* logical phis read along logical CF edges, linear phis along linear edges */
         Block::edge_vec& preds =
            phi->opcode == aco_opcode::p_phi ? block.logical_preds : block.linear_preds;
         uint32_t index = ctx.merge_sets.size();
         merge_set set;

         bool has_preheader_copy = false;
         for (unsigned i = 0; i < phi->operands.size(); i++) {
            Operand op = phi->operands[i];
            if (op.isUndefined())
               continue;

            if (def.regClass().type() == RegType::sgpr && !op.isTemp()) {
               /* SGPR inline constants and literals on GFX10+ can be spilled
                * and reloaded directly (without intermediate register) */
               if (op.isConstant()) {
                  if (ctx.program->gfx_level >= GFX10)
                     continue;
                  if (op.size() == 1 && !op.isLiteral())
                     continue;
               } else {
                  assert(op.isFixed() && op.physReg() == exec);
                  continue;
               }
            }

            /* create new temporary and rename operands */
            Temp tmp = bld.tmp(def.regClass());
            ctx.parallelcopies[preds[i]].emplace_back(copy{Definition(tmp), op});
            phi->operands[i] = Operand(tmp);
            phi->operands[i].setKill(true);

            /* place the new operands in the same merge set */
            set.emplace_back(tmp);
            ctx.merge_node_table[tmp.id()] = {op, index, preds[i]};

            /* operand 0 of a loop-header phi comes from the preheader copy, which is
             * executed before the phi-def and thus precedes it in dominance order */
            has_preheader_copy |= i == 0 && block.kind & block_kind_loop_header;
         }

         if (set.empty())
            continue;

         /* place the definition in dominance-order */
         if (def.isTemp()) { /* NOTE(review): always true here, non-temp defs were skipped above */
            if (has_preheader_copy)
               set.emplace(std::next(set.begin()), def.getTemp());
            else if (block.kind & block_kind_loop_header)
               set.emplace(set.begin(), def.getTemp());
            else
               set.emplace_back(def.getTemp());
            ctx.merge_node_table[def.tempId()] = {Operand(def.getTemp()), index, block.index};
         }
         ctx.merge_sets.emplace_back(set);
      }
   }
}
131 
132 /* check whether the definition of a comes after b. */
133 inline bool
defined_after(cssa_ctx & ctx,Temp a,Temp b)134 defined_after(cssa_ctx& ctx, Temp a, Temp b)
135 {
136    merge_node& node_a = ctx.merge_node_table[a.id()];
137    merge_node& node_b = ctx.merge_node_table[b.id()];
138    if (node_a.defined_at == node_b.defined_at)
139       return a.id() > b.id();
140 
141    return node_a.defined_at > node_b.defined_at;
142 }
143 
144 /* check whether a dominates b where b is defined after a */
145 inline bool
dominates(cssa_ctx & ctx,Temp a,Temp b)146 dominates(cssa_ctx& ctx, Temp a, Temp b)
147 {
148    assert(defined_after(ctx, b, a));
149    Block& parent = ctx.program->blocks[ctx.merge_node_table[a.id()].defined_at];
150    Block& child = ctx.program->blocks[ctx.merge_node_table[b.id()].defined_at];
151    if (b.regClass().type() == RegType::vgpr)
152       return dominates_logical(parent, child);
153    else
154       return dominates_linear(parent, child);
155 }
156 
157 /* Checks whether some variable is live-out, not considering any phi-uses. */
158 inline bool
is_live_out(cssa_ctx & ctx,Temp var,uint32_t block_idx)159 is_live_out(cssa_ctx& ctx, Temp var, uint32_t block_idx)
160 {
161    Block::edge_vec& succs = var.is_linear() ? ctx.program->blocks[block_idx].linear_succs
162                                             : ctx.program->blocks[block_idx].logical_succs;
163 
164    return std::any_of(succs.begin(), succs.end(), [&](unsigned succ)
165                       { return ctx.program->live.live_in[succ].count(var.id()); });
166 }
167 
/* check intersection between var and parent:
 * We already know that parent dominates var.
 * Returns true iff the live ranges of both temps overlap. */
inline bool
intersects(cssa_ctx& ctx, Temp var, Temp parent)
{
   merge_node& node_var = ctx.merge_node_table[var.id()];
   merge_node& node_parent = ctx.merge_node_table[parent.id()];
   assert(node_var.index != node_parent.index);
   uint32_t block_idx = node_var.defined_at;

   /* if parent is defined in a different block than var */
   if (node_parent.defined_at < node_var.defined_at) {
      /* if the parent is not live-in, they don't interfere */
      if (!ctx.program->live.live_in[block_idx].count(parent.id()))
         return false;
   }

   /* if the parent is live-out at the definition block of var, they intersect */
   bool parent_live = is_live_out(ctx, parent, block_idx);
   if (parent_live)
      return true;

   /* Scan the pending parallelcopies at the end of var's defining block:
    * these are where the new phi-operand temps are actually defined. */
   for (const copy& cp : ctx.parallelcopies[block_idx]) {
      /* if var is defined at the edge, they don't intersect */
      if (cp.def.getTemp() == var)
         return false;
      if (cp.op.isTemp() && cp.op.getTemp() == parent)
         parent_live = true;
   }
   /* if the parent is live at the edge, they intersect */
   if (parent_live)
      return true;

   /* both, parent and var, are present in the same block:
    * walk the instructions backwards to see whether parent is used after var's def */
   const Block& block = ctx.program->blocks[block_idx];
   for (auto it = block.instructions.crbegin(); it != block.instructions.crend(); ++it) {
      /* if the parent was not encountered yet, it can only be used by a phi */
      if (is_phi(it->get()))
         break;

      for (const Definition& def : (*it)->definitions) {
         if (!def.isTemp())
            continue;
         /* if parent was not found yet, they don't intersect */
         if (def.getTemp() == var)
            return false;
      }

      for (const Operand& op : (*it)->operands) {
         if (!op.isTemp())
            continue;
         /* if the var was defined before this point, they intersect */
         if (op.getTemp() == parent)
            return true;
      }
   }

   return false;
}
227 
/* check interference between var and parent:
 * i.e. they have different values and intersect.
 * If parent and var intersect and share the same value, also updates the equal ancestor.
 * 'parent' is expected to dominate 'var'. */
inline bool
interference(cssa_ctx& ctx, Temp var, Temp parent)
{
   assert(var != parent);
   merge_node& node_var = ctx.merge_node_table[var.id()];
   node_var.equal_anc_out = Temp();

   if (node_var.index == ctx.merge_node_table[parent.id()].index) {
      /* Check/update in other set. equal_anc_out is only present if it intersects with 'parent',
       * but that's fine since it has to for it to intersect with 'var'. */
      parent = ctx.merge_node_table[parent.id()].equal_anc_out;
   }

   Temp tmp = parent;
   /* Check if 'var' intersects with 'parent' or any ancestors which might intersect too.
    * Walking the equal_anc_in chain visits dominating, equal-valued ancestors. */
   while (tmp != Temp() && !intersects(ctx, var, tmp)) {
      merge_node& node_tmp = ctx.merge_node_table[tmp.id()];
      tmp = node_tmp.equal_anc_in;
   }

   /* no intersection found */
   if (tmp == Temp())
      return false;

   /* var and parent, same value and intersect, but in different sets:
    * remember the intersecting ancestor, this is not an interference */
   if (node_var.value == ctx.merge_node_table[parent.id()].value) {
      node_var.equal_anc_out = tmp;
      return false;
   }

   /* var and parent, different values and intersect */
   return true;
}
264 
/* tries to merge set_b into set_a of given temporary and
 * drops that temporary as it is being coalesced.
 * Returns false (and leaves both sets unchanged) on interference. */
bool
try_merge_merge_set(cssa_ctx& ctx, Temp dst, merge_set& set_b)
{
   auto def_node_it = ctx.merge_node_table.find(dst.id());
   uint32_t index = def_node_it->second.index;
   merge_set& set_a = ctx.merge_sets[index];
   std::vector<Temp> dom; /* stack of the traversal */
   merge_set union_set;   /* the new merged merge-set */
   uint32_t i_a = 0;
   uint32_t i_b = 0;

   /* Merge both (ordered) sets like a merge sort, checking each element against
    * its closest dominating ancestor on the 'dom' stack. */
   while (i_a < set_a.size() || i_b < set_b.size()) {
      Temp current;
      if (i_a == set_a.size())
         current = set_b[i_b++];
      else if (i_b == set_b.size())
         current = set_a[i_a++];
      /* else pick the one defined first */
      else if (defined_after(ctx, set_a[i_a], set_b[i_b]))
         current = set_b[i_b++];
      else
         current = set_a[i_a++];

      while (!dom.empty() && !dominates(ctx, dom.back(), current))
         dom.pop_back(); /* not the desired parent, remove */

      if (!dom.empty() && interference(ctx, current, dom.back())) {
         /* roll back the speculative equal_anc_out updates before bailing out */
         for (Temp t : union_set)
            ctx.merge_node_table[t.id()].equal_anc_out = Temp();
         return false; /* intersection detected */
      }

      dom.emplace_back(current); /* otherwise, keep checking */
      if (current != dst)
         union_set.emplace_back(current); /* maintain the new merge-set sorted */
   }

   /* update hashmap */
   for (Temp t : union_set) {
      merge_node& node = ctx.merge_node_table[t.id()];
      /* update the equal ancestors:
       * i.e. the 'closest' dominating def which intersects */
      Temp in = node.equal_anc_in;
      Temp out = node.equal_anc_out;
      if (in == Temp() || (out != Temp() && defined_after(ctx, out, in)))
         node.equal_anc_in = out;
      node.equal_anc_out = Temp();
      /* update merge-set index */
      node.index = index;
   }
   set_b = merge_set(); /* free the old set_b */
   ctx.merge_sets[index] = union_set;
   ctx.merge_node_table.erase(dst.id()); /* remove the temporary */

   return true;
}
323 
/* Tries to coalesce the copy's operand with its definition's merge set.
 * returns true if the copy can safely be omitted */
bool
try_coalesce_copy(cssa_ctx& ctx, copy copy, uint32_t block_idx)
{
   /* we can only coalesce temporaries */
   if (!copy.op.isTemp() || !copy.op.isKill())
      return false;

   /* we can only coalesce copies of the same register class */
   if (copy.op.regClass() != copy.def.regClass())
      return false;

   /* try emplace a merge_node for the copy operand
    * (operator[] default-constructs the node on first access) */
   merge_node& op_node = ctx.merge_node_table[copy.op.tempId()];
   if (op_node.defined_at == -1u) {
      /* find defining block of operand: walk up the dominator tree
       * as long as the operand is live-in */
      while (ctx.program->live.live_in[block_idx].count(copy.op.tempId()))
         block_idx = copy.op.regClass().type() == RegType::vgpr
                        ? ctx.program->blocks[block_idx].logical_idom
                        : ctx.program->blocks[block_idx].linear_idom;
      op_node.defined_at = block_idx;
      op_node.value = copy.op;
   }

   /* check if this operand has not yet been coalesced */
   if (op_node.index == -1u) {
      merge_set op_set = merge_set{copy.op.getTemp()};
      return try_merge_merge_set(ctx, copy.def.getTemp(), op_set);
   }

   /* check if this operand has been coalesced into the same set */
   assert(ctx.merge_node_table.count(copy.def.tempId()));
   if (op_node.index == ctx.merge_node_table[copy.def.tempId()].index)
      return true;

   /* otherwise, try to coalesce both merge sets */
   return try_merge_merge_set(ctx, copy.def.getTemp(), ctx.merge_sets[op_node.index]);
}
362 
/* node in the location-transfer-graph: one entry per not-yet-emitted copy,
 * keyed by the merge-set index of the copy's definition */
struct ltg_node {
   copy* cp;              /* the parallelcopy this node represents */
   uint32_t read_idx;     /* merge-set index of the copy's operand, -1u if not a temp */
   uint32_t num_uses = 0; /* number of other pending copies reading this node's set */
};
369 
/* emit the copies in an order that does not
 * create interferences within a merge-set:
 * a copy is only emitted once no other pending copy still reads its definition.
 * Remaining cycles are emitted as a single parallelcopy.
 * Also keeps the instructions' register_demand up to date. */
void
emit_copies_block(Builder& bld, std::map<uint32_t, ltg_node>& ltg, RegType type)
{
   RegisterDemand live_changes;
   /* demand before the insertion point, without the insertion point's own temps/changes */
   RegisterDemand reg_demand = bld.it->get()->register_demand - get_temp_registers(bld.it->get()) -
                               get_live_changes(bld.it->get());
   auto&& it = ltg.begin();
   while (it != ltg.end()) {
      copy& cp = *it->second.cp;

      /* wrong regclass or still needed as operand */
      if (cp.def.regClass().type() != type || it->second.num_uses > 0) {
         ++it;
         continue;
      }

      /* update the location transfer graph */
      if (it->second.read_idx != -1u) {
         auto&& other = ltg.find(it->second.read_idx);
         if (other != ltg.end())
            other->second.num_uses--;
      }
      ltg.erase(it);

      /* Remove the kill flag if we still need this operand for other copies. */
      if (cp.op.isKill() && std::any_of(ltg.begin(), ltg.end(),
                                        [&](auto& other) { return other.second.cp->op == cp.op; }))
         cp.op.setKill(false);

      /* emit the copy */
      Instruction* instr = bld.copy(cp.def, cp.op);
      live_changes += get_live_changes(instr);
      RegisterDemand temps = get_temp_registers(instr);
      instr->register_demand = reg_demand + live_changes + temps;

      /* emitting a copy may have unblocked earlier entries: restart the scan */
      it = ltg.begin();
   }

   /* count the number of remaining circular dependencies */
   unsigned num = std::count_if(
      ltg.begin(), ltg.end(), [&](auto& n) { return n.second.cp->def.regClass().type() == type; });

   /* if there are circular dependencies, we just emit them as single parallelcopy */
   if (num) {
      // TODO: this should be restricted to a feasible number of registers
      // and otherwise use a temporary to avoid having to reload more (spilled)
      // variables than we have registers.
      aco_ptr<Instruction> copy{
         create_instruction(aco_opcode::p_parallelcopy, Format::PSEUDO, num, num)};
      it = ltg.begin();
      for (unsigned i = 0; i < num; i++) {
         /* skip entries of the other register type */
         while (it->second.cp->def.regClass().type() != type)
            ++it;

         copy->definitions[i] = it->second.cp->def;
         copy->operands[i] = it->second.cp->op;
         it = ltg.erase(it);
      }
      live_changes += get_live_changes(copy.get());
      RegisterDemand temps = get_temp_registers(copy.get());
      copy->register_demand = reg_demand + live_changes + temps;
      bld.insert(std::move(copy));
   }

   /* Update RegisterDemand after inserted copies */
   for (auto instr_it = bld.it; instr_it != bld.instructions->end(); ++instr_it) {
      instr_it->get()->register_demand += live_changes;
   }
}
441 
/* either emits or coalesces all parallelcopies and
 * renames the phi-operands accordingly.
 * Afterwards, the per-block and program-wide register demand is recomputed. */
void
emit_parallelcopies(cssa_ctx& ctx)
{
   /* maps the id of a coalesced copy-def to the operand it was coalesced with */
   std::unordered_map<uint32_t, Operand> renames;

   /* we iterate backwards to prioritize coalescing in else-blocks */
   for (int i = ctx.program->blocks.size() - 1; i >= 0; i--) {
      if (ctx.parallelcopies[i].empty())
         continue;

      std::map<uint32_t, ltg_node> ltg;
      bool has_vgpr_copy = false;
      bool has_sgpr_copy = false;

      /* first, try to coalesce all parallelcopies */
      for (copy& cp : ctx.parallelcopies[i]) {
         if (try_coalesce_copy(ctx, cp, i)) {
            assert(cp.op.isTemp() && cp.op.isKill());
            /* As this temp will be used as phi operand and becomes live-out,
             * remove the kill flag from any other copy of this same temp.
             */
            for (copy& other : ctx.parallelcopies[i]) {
               if (&other != &cp && other.op.isTemp() && other.op.getTemp() == cp.op.getTemp())
                  other.op.setKill(false);
            }
            renames.emplace(cp.def.tempId(), cp.op);
         } else {
            /* copy could not be coalesced: queue it in the location-transfer-graph */
            uint32_t read_idx = -1u;
            if (cp.op.isTemp()) {
               read_idx = ctx.merge_node_table[cp.op.tempId()].index;
               /* In case the original phi-operand was killed, it might still be live-out
                * if the logical successor is not the same as linear successors.
                * Thus, re-check whether the temp is live-out.
                */
               cp.op.setKill(cp.op.isKill() && !is_live_out(ctx, cp.op.getTemp(), i));
               cp.op.setFirstKill(cp.op.isKill());
            }
            uint32_t write_idx = ctx.merge_node_table[cp.def.tempId()].index;
            assert(write_idx != -1u);
            ltg[write_idx] = {&cp, read_idx};

            bool is_vgpr = cp.def.regClass().type() == RegType::vgpr;
            has_vgpr_copy |= is_vgpr;
            has_sgpr_copy |= !is_vgpr;
         }
      }

      /* build location-transfer-graph: count the readers of each node's set */
      for (auto& pair : ltg) {
         if (pair.second.read_idx == -1u)
            continue;
         auto&& it = ltg.find(pair.second.read_idx);
         if (it != ltg.end())
            it->second.num_uses++;
      }

      /* emit parallelcopies ordered */
      Builder bld(ctx.program);
      Block& block = ctx.program->blocks[i];

      if (has_vgpr_copy) {
         /* emit VGPR copies before the logical end of the block */
         auto IsLogicalEnd = [](const aco_ptr<Instruction>& inst) -> bool
         { return inst->opcode == aco_opcode::p_logical_end; };
         auto it =
            std::find_if(block.instructions.rbegin(), block.instructions.rend(), IsLogicalEnd);
         bld.reset(&block.instructions, std::prev(it.base()));
         emit_copies_block(bld, ltg, RegType::vgpr);
      }

      if (has_sgpr_copy) {
         /* emit SGPR copies right before the branch at the end of the block */
         bld.reset(&block.instructions, std::prev(block.instructions.end()));
         emit_copies_block(bld, ltg, RegType::sgpr);
      }
   }

   RegisterDemand new_demand;
   for (Block& block : ctx.program->blocks) {
      /* Finally, rename coalesced phi operands */
      for (aco_ptr<Instruction>& phi : block.instructions) {
         if (phi->opcode != aco_opcode::p_phi && phi->opcode != aco_opcode::p_linear_phi)
            break;

         for (Operand& op : phi->operands) {
            if (!op.isTemp())
               continue;
            auto&& it = renames.find(op.tempId());
            if (it != renames.end()) {
               op = it->second;
               renames.erase(it);
            }
         }
      }

      /* Resummarize the block's register demand */
      block.register_demand = block.live_in_demand;
      for (const aco_ptr<Instruction>& instr : block.instructions)
         block.register_demand.update(instr->register_demand);
      new_demand.update(block.register_demand);
   }

   /* Update max_reg_demand and num_waves */
   update_vgpr_sgpr_demand(ctx.program, new_demand);

   /* every rename must have been applied exactly once */
   assert(renames.empty());
}
551 
552 } /* end namespace */
553 
554 void
lower_to_cssa(Program * program)555 lower_to_cssa(Program* program)
556 {
557    reindex_ssa(program, true);
558    cssa_ctx ctx = {program};
559    collect_parallelcopies(ctx);
560    emit_parallelcopies(ctx);
561 
562    /* Validate live variable information */
563    if (!validate_live_vars(program))
564       abort();
565 }
566 } // namespace aco
567