xref: /aosp_15_r20/external/angle/src/compiler/translator/hlsl/ShaderStorageBlockFunctionHLSL.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2018 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // ShaderStorageBlockFunctionHLSL: Wrapper functions for RWByteAddressBuffer Load/Store functions.
7 //
8 
9 #include "compiler/translator/hlsl/ShaderStorageBlockFunctionHLSL.h"
10 
11 #include "common/utilities.h"
12 #include "compiler/translator/blocklayout.h"
13 #include "compiler/translator/hlsl/UtilsHLSL.h"
14 #include "compiler/translator/hlsl/blocklayoutHLSL.h"
15 #include "compiler/translator/util.h"
16 
17 namespace sh
18 {
19 
20 // static
OutputSSBOLoadFunctionBody(TInfoSinkBase & out,const ShaderStorageBlockFunction & ssboFunction)21 void ShaderStorageBlockFunctionHLSL::OutputSSBOLoadFunctionBody(
22     TInfoSinkBase &out,
23     const ShaderStorageBlockFunction &ssboFunction)
24 {
25     const char *convertString;
26     switch (ssboFunction.type.getBasicType())
27     {
28         case EbtFloat:
29             convertString = "asfloat(";
30             break;
31         case EbtInt:
32             convertString = "asint(";
33             break;
34         case EbtUInt:
35             convertString = "asuint(";
36             break;
37         case EbtBool:
38             convertString = "asint(";
39             break;
40         default:
41             UNREACHABLE();
42             return;
43     }
44 
45     size_t bytesPerComponent =
46         gl::VariableComponentSize(gl::VariableComponentType(GLVariableType(ssboFunction.type)));
47     out << "    " << ssboFunction.typeString << " result";
48     if (ssboFunction.type.isScalar())
49     {
50         size_t offset = ssboFunction.swizzleOffsets[0] * bytesPerComponent;
51         out << " = " << convertString << "buffer.Load(loc + " << offset << "));\n ";
52     }
53     else if (ssboFunction.type.isVector())
54     {
55         if (ssboFunction.rowMajor || !ssboFunction.isDefaultSwizzle)
56         {
57             size_t componentStride = bytesPerComponent;
58             if (ssboFunction.rowMajor)
59             {
60                 componentStride = ssboFunction.matrixStride;
61             }
62 
63             out << " = {";
64             for (const int offset : ssboFunction.swizzleOffsets)
65             {
66                 size_t offsetInBytes = offset * componentStride;
67                 out << convertString << "buffer.Load(loc + " << offsetInBytes << ")),";
68             }
69             out << "};\n";
70         }
71         else
72         {
73             out << " = " << convertString << "buffer.Load"
74                 << static_cast<uint32_t>(ssboFunction.type.getNominalSize()) << "(loc));\n";
75         }
76     }
77     else if (ssboFunction.type.isMatrix())
78     {
79         if (ssboFunction.rowMajor)
80         {
81             out << ";";
82             out << "    float" << static_cast<uint32_t>(ssboFunction.type.getRows()) << "x"
83                 << static_cast<uint32_t>(ssboFunction.type.getCols()) << " tmp_ = {";
84             for (uint8_t rowIndex = 0; rowIndex < ssboFunction.type.getRows(); rowIndex++)
85             {
86                 out << "asfloat(buffer.Load" << static_cast<uint32_t>(ssboFunction.type.getCols())
87                     << "(loc + " << rowIndex * ssboFunction.matrixStride << ")), ";
88             }
89             out << "};\n";
90             out << "    result = transpose(tmp_);\n";
91         }
92         else
93         {
94             out << " = {";
95             for (uint8_t columnIndex = 0; columnIndex < ssboFunction.type.getCols(); columnIndex++)
96             {
97                 out << "asfloat(buffer.Load" << static_cast<uint32_t>(ssboFunction.type.getRows())
98                     << "(loc + " << columnIndex * ssboFunction.matrixStride << ")), ";
99             }
100             out << "};\n";
101         }
102     }
103     else
104     {
105         // TODO([email protected]): Process all possible return types.
106         // http://anglebug.com/40644618
107         out << ";\n";
108     }
109 
110     out << "    return result;\n";
111     return;
112 }
113 
114 // static
OutputSSBOStoreFunctionBody(TInfoSinkBase & out,const ShaderStorageBlockFunction & ssboFunction)115 void ShaderStorageBlockFunctionHLSL::OutputSSBOStoreFunctionBody(
116     TInfoSinkBase &out,
117     const ShaderStorageBlockFunction &ssboFunction)
118 {
119     size_t bytesPerComponent =
120         gl::VariableComponentSize(gl::VariableComponentType(GLVariableType(ssboFunction.type)));
121     if (ssboFunction.type.isScalar())
122     {
123         size_t offset = ssboFunction.swizzleOffsets[0] * bytesPerComponent;
124         if (ssboFunction.type.getBasicType() == EbtBool)
125         {
126             out << "    buffer.Store(loc + " << offset << ", uint(value));\n";
127         }
128         else
129         {
130             out << "    buffer.Store(loc + " << offset << ", asuint(value));\n";
131         }
132     }
133     else if (ssboFunction.type.isVector())
134     {
135         out << "    uint" << static_cast<uint32_t>(ssboFunction.type.getNominalSize())
136             << " _value;\n";
137         if (ssboFunction.type.getBasicType() == EbtBool)
138         {
139             out << "    _value = uint" << static_cast<uint32_t>(ssboFunction.type.getNominalSize())
140                 << "(value);\n";
141         }
142         else
143         {
144             out << "    _value = asuint(value);\n";
145         }
146 
147         if (ssboFunction.rowMajor || !ssboFunction.isDefaultSwizzle)
148         {
149             size_t componentStride = bytesPerComponent;
150             if (ssboFunction.rowMajor)
151             {
152                 componentStride = ssboFunction.matrixStride;
153             }
154             const TVector<int> &swizzleOffsets = ssboFunction.swizzleOffsets;
155             for (int index = 0; index < static_cast<int>(swizzleOffsets.size()); index++)
156             {
157                 size_t offsetInBytes = swizzleOffsets[index] * componentStride;
158                 out << "buffer.Store(loc + " << offsetInBytes << ", _value[" << index << "]);\n";
159             }
160         }
161         else
162         {
163             out << "    buffer.Store" << static_cast<uint32_t>(ssboFunction.type.getNominalSize())
164                 << "(loc, _value);\n";
165         }
166     }
167     else if (ssboFunction.type.isMatrix())
168     {
169         if (ssboFunction.rowMajor)
170         {
171             out << "    float" << static_cast<uint32_t>(ssboFunction.type.getRows()) << "x"
172                 << static_cast<uint32_t>(ssboFunction.type.getCols())
173                 << " tmp_ = transpose(value);\n";
174             for (uint8_t rowIndex = 0; rowIndex < ssboFunction.type.getRows(); rowIndex++)
175             {
176                 out << "    buffer.Store" << static_cast<uint32_t>(ssboFunction.type.getCols())
177                     << "(loc + " << rowIndex * ssboFunction.matrixStride << ", asuint(tmp_["
178                     << static_cast<uint32_t>(rowIndex) << "]));\n";
179             }
180         }
181         else
182         {
183             for (uint8_t columnIndex = 0; columnIndex < ssboFunction.type.getCols(); columnIndex++)
184             {
185                 out << "    buffer.Store" << static_cast<uint32_t>(ssboFunction.type.getRows())
186                     << "(loc + " << columnIndex * ssboFunction.matrixStride << ", asuint(value["
187                     << static_cast<uint32_t>(columnIndex) << "]));\n";
188             }
189         }
190     }
191     else
192     {
193         // TODO([email protected]): Process all possible return types.
194         // http://anglebug.com/40644618
195     }
196 }
197 
198 // static
OutputSSBOLengthFunctionBody(TInfoSinkBase & out,int unsizedArrayStride)199 void ShaderStorageBlockFunctionHLSL::OutputSSBOLengthFunctionBody(TInfoSinkBase &out,
200                                                                   int unsizedArrayStride)
201 {
202     out << "    uint dim = 0;\n";
203     out << "    buffer.GetDimensions(dim);\n";
204     out << "    return int((dim - loc)/uint(" << unsizedArrayStride << "));\n";
205 }
206 
207 // static
OutputSSBOAtomicMemoryFunctionBody(TInfoSinkBase & out,const ShaderStorageBlockFunction & ssboFunction)208 void ShaderStorageBlockFunctionHLSL::OutputSSBOAtomicMemoryFunctionBody(
209     TInfoSinkBase &out,
210     const ShaderStorageBlockFunction &ssboFunction)
211 {
212     out << "    " << ssboFunction.typeString << " original_value;\n";
213     switch (ssboFunction.method)
214     {
215         case SSBOMethod::ATOMIC_ADD:
216             out << "    buffer.InterlockedAdd(loc, value, original_value);\n";
217             break;
218         case SSBOMethod::ATOMIC_MIN:
219             out << "    buffer.InterlockedMin(loc, value, original_value);\n";
220             break;
221         case SSBOMethod::ATOMIC_MAX:
222             out << "    buffer.InterlockedMax(loc, value, original_value);\n";
223             break;
224         case SSBOMethod::ATOMIC_AND:
225             out << "    buffer.InterlockedAnd(loc, value, original_value);\n";
226             break;
227         case SSBOMethod::ATOMIC_OR:
228             out << "    buffer.InterlockedOr(loc, value, original_value);\n";
229             break;
230         case SSBOMethod::ATOMIC_XOR:
231             out << "    buffer.InterlockedXor(loc, value, original_value);\n";
232             break;
233         case SSBOMethod::ATOMIC_EXCHANGE:
234             out << "    buffer.InterlockedExchange(loc, value, original_value);\n";
235             break;
236         case SSBOMethod::ATOMIC_COMPSWAP:
237             out << "    buffer.InterlockedCompareExchange(loc, compare_value, value, "
238                    "original_value);\n";
239             break;
240         default:
241             UNREACHABLE();
242     }
243     out << "    return original_value;\n";
244 }
245 
operator <(const ShaderStorageBlockFunction & rhs) const246 bool ShaderStorageBlockFunctionHLSL::ShaderStorageBlockFunction::operator<(
247     const ShaderStorageBlockFunction &rhs) const
248 {
249     return functionName < rhs.functionName;
250 }
251 
registerShaderStorageBlockFunction(const TType & type,SSBOMethod method,TLayoutBlockStorage storage,bool rowMajor,int matrixStride,int unsizedArrayStride,TIntermSwizzle * swizzleNode)252 TString ShaderStorageBlockFunctionHLSL::registerShaderStorageBlockFunction(
253     const TType &type,
254     SSBOMethod method,
255     TLayoutBlockStorage storage,
256     bool rowMajor,
257     int matrixStride,
258     int unsizedArrayStride,
259     TIntermSwizzle *swizzleNode)
260 {
261     ShaderStorageBlockFunction ssboFunction;
262     ssboFunction.typeString = TypeString(type);
263     ssboFunction.method     = method;
264     switch (method)
265     {
266         case SSBOMethod::LOAD:
267             ssboFunction.functionName = "_Load_";
268             break;
269         case SSBOMethod::STORE:
270             ssboFunction.functionName = "_Store_";
271             break;
272         case SSBOMethod::LENGTH:
273             ssboFunction.unsizedArrayStride = unsizedArrayStride;
274             ssboFunction.functionName       = "_Length_" + str(unsizedArrayStride);
275             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
276             return ssboFunction.functionName;
277         case SSBOMethod::ATOMIC_ADD:
278             ssboFunction.functionName = "_ssbo_atomicAdd_" + ssboFunction.typeString;
279             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
280             return ssboFunction.functionName;
281         case SSBOMethod::ATOMIC_MIN:
282             ssboFunction.functionName = "_ssbo_atomicMin_" + ssboFunction.typeString;
283             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
284             return ssboFunction.functionName;
285         case SSBOMethod::ATOMIC_MAX:
286             ssboFunction.functionName = "_ssbo_atomicMax_" + ssboFunction.typeString;
287             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
288             return ssboFunction.functionName;
289         case SSBOMethod::ATOMIC_AND:
290             ssboFunction.functionName = "_ssbo_atomicAnd_" + ssboFunction.typeString;
291             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
292             return ssboFunction.functionName;
293         case SSBOMethod::ATOMIC_OR:
294             ssboFunction.functionName = "_ssbo_atomicOr_" + ssboFunction.typeString;
295             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
296             return ssboFunction.functionName;
297         case SSBOMethod::ATOMIC_XOR:
298             ssboFunction.functionName = "_ssbo_atomicXor_" + ssboFunction.typeString;
299             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
300             return ssboFunction.functionName;
301         case SSBOMethod::ATOMIC_EXCHANGE:
302             ssboFunction.functionName = "_ssbo_atomicExchange_" + ssboFunction.typeString;
303             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
304             return ssboFunction.functionName;
305         case SSBOMethod::ATOMIC_COMPSWAP:
306             ssboFunction.functionName = "_ssbo_atomicCompSwap_" + ssboFunction.typeString;
307             mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
308             return ssboFunction.functionName;
309         default:
310             UNREACHABLE();
311     }
312 
313     ssboFunction.functionName += ssboFunction.typeString;
314     ssboFunction.type = type;
315     if (swizzleNode != nullptr)
316     {
317         ssboFunction.swizzleOffsets   = swizzleNode->getSwizzleOffsets();
318         ssboFunction.isDefaultSwizzle = false;
319     }
320     else
321     {
322         if (ssboFunction.type.getNominalSize() > 1)
323         {
324             for (uint8_t index = 0; index < ssboFunction.type.getNominalSize(); index++)
325             {
326                 ssboFunction.swizzleOffsets.push_back(index);
327             }
328         }
329         else
330         {
331             ssboFunction.swizzleOffsets.push_back(0);
332         }
333 
334         ssboFunction.isDefaultSwizzle = true;
335     }
336     ssboFunction.rowMajor     = rowMajor;
337     ssboFunction.matrixStride = matrixStride;
338     ssboFunction.functionName += "_" + TString(getBlockStorageString(storage));
339 
340     if (rowMajor)
341     {
342         ssboFunction.functionName += "_rm_";
343     }
344     else
345     {
346         ssboFunction.functionName += "_cm_";
347     }
348 
349     for (const int offset : ssboFunction.swizzleOffsets)
350     {
351         switch (offset)
352         {
353             case 0:
354                 ssboFunction.functionName += "x";
355                 break;
356             case 1:
357                 ssboFunction.functionName += "y";
358                 break;
359             case 2:
360                 ssboFunction.functionName += "z";
361                 break;
362             case 3:
363                 ssboFunction.functionName += "w";
364                 break;
365             default:
366                 UNREACHABLE();
367         }
368     }
369 
370     mRegisteredShaderStorageBlockFunctions.insert(ssboFunction);
371     return ssboFunction.functionName;
372 }
373 
shaderStorageBlockFunctionHeader(TInfoSinkBase & out)374 void ShaderStorageBlockFunctionHLSL::shaderStorageBlockFunctionHeader(TInfoSinkBase &out)
375 {
376     for (const ShaderStorageBlockFunction &ssboFunction : mRegisteredShaderStorageBlockFunctions)
377     {
378         switch (ssboFunction.method)
379         {
380             case SSBOMethod::LOAD:
381             {
382                 // Function header
383                 out << ssboFunction.typeString << " " << ssboFunction.functionName
384                     << "(RWByteAddressBuffer buffer, uint loc)\n";
385                 out << "{\n";
386                 OutputSSBOLoadFunctionBody(out, ssboFunction);
387                 break;
388             }
389             case SSBOMethod::STORE:
390             {
391                 // Function header
392                 out << "void " << ssboFunction.functionName
393                     << "(RWByteAddressBuffer buffer, uint loc, " << ssboFunction.typeString
394                     << " value)\n";
395                 out << "{\n";
396                 OutputSSBOStoreFunctionBody(out, ssboFunction);
397                 break;
398             }
399             case SSBOMethod::LENGTH:
400             {
401                 out << "int " << ssboFunction.functionName
402                     << "(RWByteAddressBuffer buffer, uint loc)\n";
403                 out << "{\n";
404                 OutputSSBOLengthFunctionBody(out, ssboFunction.unsizedArrayStride);
405                 break;
406             }
407             case SSBOMethod::ATOMIC_ADD:
408             case SSBOMethod::ATOMIC_MIN:
409             case SSBOMethod::ATOMIC_MAX:
410             case SSBOMethod::ATOMIC_AND:
411             case SSBOMethod::ATOMIC_OR:
412             case SSBOMethod::ATOMIC_XOR:
413             case SSBOMethod::ATOMIC_EXCHANGE:
414             {
415                 out << ssboFunction.typeString << " " << ssboFunction.functionName
416                     << "(RWByteAddressBuffer buffer, uint loc, " << ssboFunction.typeString
417                     << " value)\n";
418                 out << "{\n";
419 
420                 OutputSSBOAtomicMemoryFunctionBody(out, ssboFunction);
421                 break;
422             }
423             case SSBOMethod::ATOMIC_COMPSWAP:
424             {
425                 out << ssboFunction.typeString << " " << ssboFunction.functionName
426                     << "(RWByteAddressBuffer buffer, uint loc, " << ssboFunction.typeString
427                     << " compare_value, " << ssboFunction.typeString << " value)\n";
428                 out << "{\n";
429                 OutputSSBOAtomicMemoryFunctionBody(out, ssboFunction);
430                 break;
431             }
432             default:
433                 UNREACHABLE();
434         }
435 
436         out << "}\n"
437                "\n";
438     }
439 }
440 
441 }  // namespace sh
442