xref: /aosp_15_r20/external/mesa3d/src/nouveau/compiler/nak/opt_uniform_instrs.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 // Copyright © 2024 Collabora, Ltd.
2 // SPDX-License-Identifier: MIT
3 
4 use crate::ir::*;
5 use std::collections::HashMap;
6 
should_lower_to_warp( sm: &dyn ShaderModel, instr: &Instr, r2ur: &HashMap<SSAValue, SSAValue>, ) -> bool7 fn should_lower_to_warp(
8     sm: &dyn ShaderModel,
9     instr: &Instr,
10     r2ur: &HashMap<SSAValue, SSAValue>,
11 ) -> bool {
12     if !sm.op_can_be_uniform(&instr.op) {
13         return true;
14     }
15 
16     let mut num_non_uniform_srcs = 0;
17     instr.for_each_ssa_use(|ssa| {
18         if !ssa.is_uniform() || r2ur.contains_key(ssa) {
19             num_non_uniform_srcs += 1;
20         }
21     });
22 
23     if num_non_uniform_srcs >= 2 {
24         return true;
25     }
26 
27     return false;
28 }
29 
propagate_r2ur( instr: &mut Instr, r2ur: &HashMap<SSAValue, SSAValue>, ) -> bool30 fn propagate_r2ur(
31     instr: &mut Instr,
32     r2ur: &HashMap<SSAValue, SSAValue>,
33 ) -> bool {
34     let mut progress = false;
35 
36     // We don't want Instr::for_each_ssa_use_mut() because it would treat
37     // bindless cbuf sources as SSA sources.
38     for src in instr.srcs_mut() {
39         if let SrcRef::SSA(vec) = &mut src.src_ref {
40             for ssa in &mut vec[..] {
41                 if let Some(r) = r2ur.get(ssa) {
42                     progress = true;
43                     *ssa = *r;
44                 }
45             }
46         }
47     }
48 
49     progress
50 }
51 
52 impl Shader<'_> {
opt_uniform_instrs(&mut self)53     pub fn opt_uniform_instrs(&mut self) {
54         let sm = self.sm;
55         let mut r2ur = HashMap::new();
56         let mut propagated_r2ur = false;
57         self.map_instrs(|mut instr, alloc| {
58             if matches!(
59                 &instr.op,
60                 Op::PhiDsts(_)
61                     | Op::PhiSrcs(_)
62                     | Op::Pin(_)
63                     | Op::Unpin(_)
64                     | Op::Vote(_)
65             ) {
66                 MappedInstrs::One(instr)
67             } else if instr.is_uniform() {
68                 let mut b = InstrBuilder::new(sm);
69                 if should_lower_to_warp(sm, &instr, &r2ur) {
70                     propagated_r2ur |= propagate_r2ur(&mut instr, &r2ur);
71                     instr.for_each_ssa_def_mut(|ssa| {
72                         let w = alloc.alloc(ssa.file().to_warp());
73                         r2ur.insert(*ssa, w);
74                         b.push_op(OpR2UR {
75                             dst: (*ssa).into(),
76                             src: w.into(),
77                         });
78                         *ssa = w;
79                     });
80                     let mut v = b.as_vec();
81                     v.insert(0, instr);
82                     MappedInstrs::Many(v)
83                 } else {
84                     // We may have non-uniform sources
85                     instr.for_each_ssa_use_mut(|ssa| {
86                         let file = ssa.file();
87                         if !file.is_uniform() {
88                             let u = alloc.alloc(file.to_uniform().unwrap());
89                             b.push_op(OpR2UR {
90                                 dst: u.into(),
91                                 src: (*ssa).into(),
92                             });
93                             *ssa = u;
94                         }
95                     });
96                     b.push_instr(instr);
97                     b.as_mapped_instrs()
98                 }
99             } else {
100                 propagated_r2ur |= propagate_r2ur(&mut instr, &r2ur);
101                 MappedInstrs::One(instr)
102             }
103         });
104 
105         if propagated_r2ur {
106             self.opt_dce();
107         }
108     }
109 }
110