xref: /aosp_15_r20/external/mesa3d/src/gallium/drivers/d3d12/d3d12_lower_image_casts.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 #include "nir_format_convert.h"
27 
28 #include "pipe/p_state.h"
29 #include "util/format/u_format.h"
30 
31 #include "d3d12_compiler.h"
32 #include "d3d12_nir_passes.h"
33 
34 static nir_def *
convert_value(nir_builder * b,nir_def * value,const struct util_format_description * from_desc,const struct util_format_description * to_desc)35 convert_value(nir_builder *b, nir_def *value,
36    const struct util_format_description *from_desc,
37    const struct util_format_description *to_desc)
38 {
39    if (from_desc->format == to_desc->format)
40       return value;
41 
42    assert(value->num_components == 4);
43    /* No support for 16 or 64 bit data in the shader for image loads/stores */
44    assert(value->bit_size == 32);
45    /* Overall format size needs to be the same */
46    assert(from_desc->block.bits == to_desc->block.bits);
47    assert(from_desc->nr_channels <= 4 && to_desc->nr_channels <= 4);
48 
49    const unsigned rgba1010102_bits[] = { 10, 10, 10, 2 };
50 
51    /* First, construct a "tightly packed" vector of the input values. For unorm/snorm, convert
52     * from the float we're given into the original bits (only happens while storing). For packed
53     * formats that don't fall on a nice bit size, convert/pack them into 32bit values. Otherwise,
54     * just produce a vecNx4 where N is the expected bit size.
55     */
56    nir_def *src_as_vec;
57    if (from_desc->format == PIPE_FORMAT_R10G10B10A2_UINT ||
58        from_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM) {
59       if (from_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM)
60          value = nir_format_float_to_unorm(b, value, rgba1010102_bits);
61       nir_def *channels[4];
62       for (unsigned i = 0; i < 4; ++i)
63          channels[i] = nir_channel(b, value, i);
64 
65       src_as_vec = channels[0];
66       src_as_vec = nir_mask_shift_or(b, src_as_vec, channels[1], (1 << 10) - 1, 10);
67       src_as_vec = nir_mask_shift_or(b, src_as_vec, channels[2], (1 << 10) - 1, 20);
68       src_as_vec = nir_mask_shift_or(b, src_as_vec, channels[3], (1 << 2) - 1, 30);
69    } else if (from_desc->format == PIPE_FORMAT_R11G11B10_FLOAT) {
70       src_as_vec = nir_format_pack_11f11f10f(b, value);
71    } else if (from_desc->is_unorm) {
72       if (from_desc->channel[0].size == 8)
73          src_as_vec = nir_pack_unorm_4x8(b, value);
74       else {
75          nir_def *packed_channels[2];
76          packed_channels[0] = nir_pack_unorm_2x16(b,
77                                                   nir_trim_vector(b, value, 2));
78          packed_channels[1] = nir_pack_unorm_2x16(b, nir_channels(b, value, 0x3 << 2));
79          src_as_vec = nir_vec(b, packed_channels, 2);
80       }
81    } else if (from_desc->is_snorm) {
82       if (from_desc->channel[0].size == 8)
83          src_as_vec = nir_pack_snorm_4x8(b, value);
84       else {
85          nir_def *packed_channels[2];
86          packed_channels[0] = nir_pack_snorm_2x16(b,
87                                                   nir_trim_vector(b, value, 2));
88          packed_channels[1] = nir_pack_snorm_2x16(b, nir_channels(b, value, 0x3 << 2));
89          src_as_vec = nir_vec(b, packed_channels, 2);
90       }
91    } else if (util_format_is_float(from_desc->format)) {
92       src_as_vec = nir_f2fN(b, value, from_desc->channel[0].size);
93    } else if (util_format_is_pure_sint(from_desc->format)) {
94       src_as_vec = nir_i2iN(b, value, from_desc->channel[0].size);
95    } else {
96       src_as_vec = nir_u2uN(b, value, from_desc->channel[0].size);
97    }
98 
99    /* Now that we have the tightly packed bits, we can use nir_extract_bits to get it into a
100     * vector of differently-sized components. For producing packed formats, get a 32-bit
101     * value and manually extract the bits. For unorm/snorm, get one or two 32-bit values,
102     * and extract it using helpers. Otherwise, get a format-sized dest vector and use a
103     * cast to expand it back to 32-bit.
104     *
105     * Pay extra attention for changing semantics for alpha as 1.
106     */
107    if (to_desc->format == PIPE_FORMAT_R10G10B10A2_UINT ||
108        to_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM) {
109       nir_def *u32 = nir_extract_bits(b, &src_as_vec, 1, 0, 1, 32);
110       nir_def *channels[4] = {
111          nir_iand_imm(b, u32,                      (1 << 10) - 1),
112          nir_iand_imm(b, nir_ushr_imm(b, u32, 10), (1 << 10) - 1),
113          nir_iand_imm(b, nir_ushr_imm(b, u32, 20), (1 << 10) - 1),
114                          nir_ushr_imm(b, u32, 30)
115       };
116       nir_def *vec = nir_vec(b, channels, 4);
117       if (to_desc->format == PIPE_FORMAT_R10G10B10A2_UNORM)
118          vec = nir_format_unorm_to_float(b, vec, rgba1010102_bits);
119       return vec;
120    } else if (to_desc->format == PIPE_FORMAT_R11G11B10_FLOAT) {
121       nir_def *u32 = nir_extract_bits(b, &src_as_vec, 1, 0, 1, 32);
122       nir_def *vec3 = nir_format_unpack_11f11f10f(b, u32);
123       return nir_vec4(b, nir_channel(b, vec3, 0),
124                          nir_channel(b, vec3, 1),
125                          nir_channel(b, vec3, 2),
126                          nir_imm_float(b, 1.0f));
127    } else if (to_desc->is_unorm || to_desc->is_snorm) {
128       nir_def *dest_packed = nir_extract_bits(b, &src_as_vec, 1, 0,
129          DIV_ROUND_UP(to_desc->nr_channels * to_desc->channel[0].size, 32), 32);
130       if (to_desc->is_unorm) {
131          if (to_desc->channel[0].size == 8) {
132             nir_def *unpacked = nir_unpack_unorm_4x8(b, nir_channel(b, dest_packed, 0));
133             if (to_desc->nr_channels < 4)
134                unpacked = nir_vector_insert_imm(b, unpacked, nir_imm_float(b, 1.0f), 3);
135             return unpacked;
136          }
137          nir_def *vec2s[2] = {
138             nir_unpack_unorm_2x16(b, nir_channel(b, dest_packed, 0)),
139             to_desc->nr_channels > 2 ?
140                nir_unpack_unorm_2x16(b, nir_channel(b, dest_packed, 1)) :
141                nir_vec2(b, nir_imm_float(b, 0.0f), nir_imm_float(b, 1.0f))
142          };
143          if (to_desc->nr_channels == 1)
144             vec2s[0] = nir_vector_insert_imm(b, vec2s[0], nir_imm_float(b, 0.0f), 1);
145          return nir_vec4(b, nir_channel(b, vec2s[0], 0),
146                             nir_channel(b, vec2s[0], 1),
147                             nir_channel(b, vec2s[1], 0),
148                             nir_channel(b, vec2s[1], 1));
149       } else {
150          if (to_desc->channel[0].size == 8) {
151             nir_def *unpacked = nir_unpack_snorm_4x8(b, nir_channel(b, dest_packed, 0));
152             if (to_desc->nr_channels < 4)
153                unpacked = nir_vector_insert_imm(b, unpacked, nir_imm_float(b, 1.0f), 3);
154             return unpacked;
155          }
156          nir_def *vec2s[2] = {
157             nir_unpack_snorm_2x16(b, nir_channel(b, dest_packed, 0)),
158             to_desc->nr_channels > 2 ?
159                nir_unpack_snorm_2x16(b, nir_channel(b, dest_packed, 1)) :
160                nir_vec2(b, nir_imm_float(b, 0.0f), nir_imm_float(b, 1.0f))
161          };
162          if (to_desc->nr_channels == 1)
163             vec2s[0] = nir_vector_insert_imm(b, vec2s[0], nir_imm_float(b, 0.0f), 1);
164          return nir_vec4(b, nir_channel(b, vec2s[0], 0),
165                             nir_channel(b, vec2s[0], 1),
166                             nir_channel(b, vec2s[1], 0),
167                             nir_channel(b, vec2s[1], 1));
168       }
169    } else {
170       nir_def *dest_packed = nir_extract_bits(b, &src_as_vec, 1, 0,
171          to_desc->nr_channels, to_desc->channel[0].size);
172       nir_def *final_channels[4];
173       for (unsigned i = 0; i < 4; ++i) {
174          if (i >= dest_packed->num_components)
175             final_channels[i] = util_format_is_float(to_desc->format) ?
176             nir_imm_floatN_t(b, i == 3 ? 1.0f : 0.0f, to_desc->channel[0].size) :
177             nir_imm_intN_t(b, i == 3 ? 1 : 0, to_desc->channel[0].size);
178          else
179             final_channels[i] = nir_channel(b, dest_packed, i);
180       }
181       nir_def *final_vec = nir_vec(b, final_channels, 4);
182       if (util_format_is_float(to_desc->format))
183          return nir_f2f32(b, final_vec);
184       else if (util_format_is_pure_sint(to_desc->format))
185          return nir_i2i32(b, final_vec);
186       else
187          return nir_u2u32(b, final_vec);
188    }
189 }
190 
191 static bool
lower_image_cast_instr(nir_builder * b,nir_intrinsic_instr * intr,void * _data)192 lower_image_cast_instr(nir_builder *b, nir_intrinsic_instr *intr, void *_data)
193 {
194    if (intr->intrinsic != nir_intrinsic_image_deref_load &&
195        intr->intrinsic != nir_intrinsic_image_deref_store)
196       return false;
197 
198    const struct d3d12_image_format_conversion_info_arr* info = _data;
199    nir_variable *image = nir_intrinsic_get_var(intr, 0);
200    assert(image);
201 
202    if (image->data.driver_location >= info->n_images)
203       return false;
204 
205    enum pipe_format emulation_format = info->image_format_conversion[image->data.driver_location].emulated_format;
206    if (emulation_format == PIPE_FORMAT_NONE)
207       return false;
208 
209    enum pipe_format real_format = info->image_format_conversion[image->data.driver_location].view_format;
210    assert(real_format != emulation_format);
211 
212    nir_def *value;
213    const struct util_format_description *from_desc, *to_desc;
214    if (intr->intrinsic == nir_intrinsic_image_deref_load) {
215       b->cursor = nir_after_instr(&intr->instr);
216       value = &intr->def;
217       from_desc = util_format_description(emulation_format);
218       to_desc = util_format_description(real_format);
219    } else {
220       b->cursor = nir_before_instr(&intr->instr);
221       value = intr->src[3].ssa;
222       from_desc = util_format_description(real_format);
223       to_desc = util_format_description(emulation_format);
224    }
225 
226    nir_def *new_value = convert_value(b, value, from_desc, to_desc);
227 
228    nir_alu_type alu_type = util_format_is_pure_uint(emulation_format) ?
229       nir_type_uint : (util_format_is_pure_sint(emulation_format) ?
230          nir_type_int : nir_type_float);
231 
232    if (intr->intrinsic == nir_intrinsic_image_deref_load) {
233       nir_def_rewrite_uses_after(value, new_value, new_value->parent_instr);
234       nir_intrinsic_set_dest_type(intr, alu_type);
235    } else {
236       nir_src_rewrite(&intr->src[3], new_value);
237       nir_intrinsic_set_src_type(intr, alu_type);
238    }
239    nir_intrinsic_set_format(intr, emulation_format);
240    return true;
241 }
242 
243 /* Given a shader that does image loads/stores expecting to load from the format embedded in the intrinsic,
244  * if the corresponding entry in formats is not PIPE_FORMAT_NONE, replace the image format and convert
245  * the data being loaded/stored to/from the app's expected format.
246  */
247 bool
d3d12_lower_image_casts(nir_shader * s,struct d3d12_image_format_conversion_info_arr * info)248 d3d12_lower_image_casts(nir_shader *s, struct d3d12_image_format_conversion_info_arr *info)
249 {
250    bool progress = nir_shader_intrinsics_pass(s, lower_image_cast_instr,
251                                               nir_metadata_control_flow,
252                                               info);
253 
254    if (progress) {
255       nir_foreach_image_variable(var, s) {
256          if (var->data.driver_location < info->n_images && info->image_format_conversion[var->data.driver_location].emulated_format != PIPE_FORMAT_NONE) {
257             var->data.image.format = info->image_format_conversion[var->data.driver_location].emulated_format;
258          }
259       }
260    }
261 
262    return progress;
263 }
264