xref: /aosp_15_r20/external/angle/src/tests/compiler_tests/InitOutputVariables_test.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2017 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
// InitOutputVariables_test.cpp: Tests correctness of the AST pass enabled through the
// initOutputVariables compile option (historically the SH_INIT_OUTPUT_VARIABLES flag).
8 //
9 
10 #include "common/angleutils.h"
11 
12 #include "compiler/translator/SymbolTable.h"
13 #include "compiler/translator/tree_util/FindMain.h"
14 #include "compiler/translator/tree_util/IntermNode_util.h"
15 #include "compiler/translator/tree_util/IntermTraverse.h"
16 #include "tests/test_utils/ShaderCompileTreeTest.h"
17 
18 #include <algorithm>
19 
20 namespace sh
21 {
22 
23 namespace
24 {
25 
26 typedef std::vector<TIntermTyped *> ExpectedLValues;
27 
AreSymbolsTheSame(const TIntermSymbol * expected,const TIntermSymbol * candidate)28 bool AreSymbolsTheSame(const TIntermSymbol *expected, const TIntermSymbol *candidate)
29 {
30     if (expected == nullptr || candidate == nullptr)
31     {
32         return false;
33     }
34     const TType &expectedType  = expected->getType();
35     const TType &candidateType = candidate->getType();
36     const bool sameTypes       = expectedType == candidateType &&
37                            expectedType.getPrecision() == candidateType.getPrecision() &&
38                            expectedType.getQualifier() == candidateType.getQualifier();
39     const bool sameSymbols = (expected->variable().symbolType() == SymbolType::Empty &&
40                               candidate->variable().symbolType() == SymbolType::Empty) ||
41                              expected->getName() == candidate->getName();
42     return sameSymbols && sameTypes;
43 }
44 
AreLValuesTheSame(TIntermTyped * expected,TIntermTyped * candidate)45 bool AreLValuesTheSame(TIntermTyped *expected, TIntermTyped *candidate)
46 {
47     const TIntermBinary *expectedBinary = expected->getAsBinaryNode();
48     if (expectedBinary)
49     {
50         ASSERT(expectedBinary->getOp() == EOpIndexDirect);
51         const TIntermBinary *candidateBinary = candidate->getAsBinaryNode();
52         if (candidateBinary == nullptr || candidateBinary->getOp() != EOpIndexDirect)
53         {
54             return false;
55         }
56         if (expectedBinary->getRight()->getAsConstantUnion()->getIConst(0) !=
57             candidateBinary->getRight()->getAsConstantUnion()->getIConst(0))
58         {
59             return false;
60         }
61         return AreSymbolsTheSame(expectedBinary->getLeft()->getAsSymbolNode(),
62                                  candidateBinary->getLeft()->getAsSymbolNode());
63     }
64     return AreSymbolsTheSame(expected->getAsSymbolNode(), candidate->getAsSymbolNode());
65 }
66 
CreateLValueNode(const ImmutableString & lValueName,const TType & type)67 TIntermTyped *CreateLValueNode(const ImmutableString &lValueName, const TType &type)
68 {
69     // We're using a mock symbol table here, don't need to assign proper symbol ids to these nodes.
70     TSymbolTable symbolTable;
71     TVariable *variable =
72         new TVariable(&symbolTable, lValueName, new TType(type), SymbolType::UserDefined);
73     return new TIntermSymbol(variable);
74 }
75 
CreateIndexedLValueNodeList(const ImmutableString & lValueName,const TType & elementType,unsigned arraySize)76 ExpectedLValues CreateIndexedLValueNodeList(const ImmutableString &lValueName,
77                                             const TType &elementType,
78                                             unsigned arraySize)
79 {
80     ASSERT(elementType.isArray() == false);
81     TType *arrayType = new TType(elementType);
82     arrayType->makeArray(arraySize);
83 
84     // We're using a mock symbol table here, don't need to assign proper symbol ids to these nodes.
85     TSymbolTable symbolTable;
86     TVariable *variable =
87         new TVariable(&symbolTable, lValueName, arrayType, SymbolType::UserDefined);
88     TIntermSymbol *arraySymbol = new TIntermSymbol(variable);
89 
90     ExpectedLValues expected(arraySize);
91     for (unsigned index = 0u; index < arraySize; ++index)
92     {
93         expected[index] = new TIntermBinary(EOpIndexDirect, arraySymbol->deepCopy(),
94                                             CreateIndexNode(static_cast<int>(index)));
95     }
96     return expected;
97 }
98 
99 // VerifyOutputVariableInitializers traverses the subtree covering main and collects the lvalues in
100 // assignments for which the rvalue is an expression containing only zero constants.
101 class VerifyOutputVariableInitializers final : public TIntermTraverser
102 {
103   public:
VerifyOutputVariableInitializers(TIntermBlock * root)104     VerifyOutputVariableInitializers(TIntermBlock *root) : TIntermTraverser(true, false, false)
105     {
106         ASSERT(root != nullptr);
107 
108         // The traversal starts in the body of main because this is where the varyings and output
109         // variables are initialized.
110         sh::TIntermFunctionDefinition *main = FindMain(root);
111         ASSERT(main != nullptr);
112         main->traverse(this);
113     }
114 
visitBinary(Visit visit,TIntermBinary * node)115     bool visitBinary(Visit visit, TIntermBinary *node) override
116     {
117         if (node->getOp() == EOpAssign && IsZero(node->getRight()))
118         {
119             mCandidateLValues.push_back(node->getLeft());
120             return false;
121         }
122         return true;
123     }
124 
125     // The collected lvalues are considered valid if every expected lvalue in expectedLValues is
126     // matched by name and type with any lvalue in mCandidateLValues.
areAllExpectedLValuesFound(const ExpectedLValues & expectedLValues) const127     bool areAllExpectedLValuesFound(const ExpectedLValues &expectedLValues) const
128     {
129         for (size_t i = 0u; i < expectedLValues.size(); ++i)
130         {
131             if (!isExpectedLValueFound(expectedLValues[i]))
132             {
133                 return false;
134             }
135         }
136         return true;
137     }
138 
isExpectedLValueFound(TIntermTyped * expectedLValue) const139     bool isExpectedLValueFound(TIntermTyped *expectedLValue) const
140     {
141         bool isFound = false;
142         for (size_t j = 0; j < mCandidateLValues.size() && !isFound; ++j)
143         {
144             isFound = AreLValuesTheSame(expectedLValue, mCandidateLValues[j]);
145         }
146         return isFound;
147     }
148 
getCandidates() const149     const ExpectedLValues &getCandidates() const { return mCandidateLValues; }
150 
151   private:
152     ExpectedLValues mCandidateLValues;
153 };
154 
155 // Traverses the AST and records a pointer to a structure with a given name.
156 class FindStructByName final : public TIntermTraverser
157 {
158   public:
FindStructByName(const ImmutableString & structName)159     FindStructByName(const ImmutableString &structName)
160         : TIntermTraverser(true, false, false), mStructName(structName), mStructure(nullptr)
161     {}
162 
visitSymbol(TIntermSymbol * symbol)163     void visitSymbol(TIntermSymbol *symbol) override
164     {
165         if (isStructureFound())
166         {
167             return;
168         }
169 
170         const TStructure *structure = symbol->getType().getStruct();
171 
172         if (structure != nullptr && structure->symbolType() != SymbolType::Empty &&
173             structure->name() == mStructName)
174         {
175             mStructure = structure;
176         }
177     }
178 
isStructureFound() const179     bool isStructureFound() const { return mStructure != nullptr; }
getStructure() const180     const TStructure *getStructure() const { return mStructure; }
181 
182   private:
183     ImmutableString mStructName;
184     const TStructure *mStructure;
185 };
186 
187 }  // namespace
188 
189 class InitOutputVariablesWebGL2Test : public ShaderCompileTreeTest
190 {
191   public:
SetUp()192     void SetUp() override
193     {
194         mCompileOptions.initOutputVariables = true;
195         if (getShaderType() == GL_VERTEX_SHADER)
196         {
197             mCompileOptions.initGLPosition = true;
198         }
199         ShaderCompileTreeTest::SetUp();
200     }
201 
202   protected:
getShaderSpec() const203     ShShaderSpec getShaderSpec() const override { return SH_WEBGL2_SPEC; }
204 };
205 
206 class InitOutputVariablesWebGL2VertexShaderTest : public InitOutputVariablesWebGL2Test
207 {
208   protected:
getShaderType() const209     ::GLenum getShaderType() const override { return GL_VERTEX_SHADER; }
210 };
211 
212 class InitOutputVariablesWebGL2FragmentShaderTest : public InitOutputVariablesWebGL2Test
213 {
214   protected:
getShaderType() const215     ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
initResources(ShBuiltInResources * resources)216     void initResources(ShBuiltInResources *resources) override
217     {
218         resources->EXT_draw_buffers = 1;
219         resources->MaxDrawBuffers   = 2;
220     }
221 };
222 
223 class InitOutputVariablesWebGL1FragmentShaderTest : public ShaderCompileTreeTest
224 {
225   public:
InitOutputVariablesWebGL1FragmentShaderTest()226     InitOutputVariablesWebGL1FragmentShaderTest() { mCompileOptions.initOutputVariables = true; }
227 
228   protected:
getShaderType() const229     ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
getShaderSpec() const230     ShShaderSpec getShaderSpec() const override { return SH_WEBGL_SPEC; }
initResources(ShBuiltInResources * resources)231     void initResources(ShBuiltInResources *resources) override
232     {
233         resources->EXT_draw_buffers = 1;
234         resources->MaxDrawBuffers   = 2;
235     }
236 };
237 
238 class InitOutputVariablesVertexShaderClipDistanceTest : public ShaderCompileTreeTest
239 {
240   public:
InitOutputVariablesVertexShaderClipDistanceTest()241     InitOutputVariablesVertexShaderClipDistanceTest()
242     {
243         mCompileOptions.initOutputVariables = true;
244         mCompileOptions.validateAST         = true;
245     }
246 
247   protected:
getShaderType() const248     ::GLenum getShaderType() const override { return GL_VERTEX_SHADER; }
getShaderSpec() const249     ShShaderSpec getShaderSpec() const override { return SH_GLES2_SPEC; }
initResources(ShBuiltInResources * resources)250     void initResources(ShBuiltInResources *resources) override
251     {
252         resources->APPLE_clip_distance = 1;
253         resources->MaxClipDistances    = 8;
254     }
255 };
256 
257 // Test the initialization of output variables with various qualifiers in a vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputAllQualifiers)258 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputAllQualifiers)
259 {
260     const std::string &shaderString =
261         "#version 300 es\n"
262         "precision mediump float;\n"
263         "precision lowp int;\n"
264         "out vec4 out1;\n"
265         "flat out int out2;\n"
266         "centroid out float out3;\n"
267         "smooth out float out4;\n"
268         "void main() {\n"
269         "}\n";
270     compileAssumeSuccess(shaderString);
271     VerifyOutputVariableInitializers verifier(mASTRoot);
272 
273     ExpectedLValues expectedLValues = {
274         CreateLValueNode(ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqVertexOut, 4)),
275         CreateLValueNode(ImmutableString("out2"), TType(EbtInt, EbpLow, EvqFlatOut)),
276         CreateLValueNode(ImmutableString("out3"), TType(EbtFloat, EbpMedium, EvqCentroidOut)),
277         CreateLValueNode(ImmutableString("out4"), TType(EbtFloat, EbpMedium, EvqSmoothOut))};
278     EXPECT_TRUE(verifier.areAllExpectedLValuesFound(expectedLValues));
279 }
280 
281 // Test the initialization of an output array in a vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputArray)282 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputArray)
283 {
284     const std::string &shaderString =
285         "#version 300 es\n"
286         "precision mediump float;\n"
287         "out float out1[2];\n"
288         "void main() {\n"
289         "}\n";
290     compileAssumeSuccess(shaderString);
291     VerifyOutputVariableInitializers verifier(mASTRoot);
292 
293     ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
294         ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqVertexOut), 2);
295     EXPECT_TRUE(verifier.areAllExpectedLValuesFound(expectedLValues));
296 }
297 
298 // Test the initialization of a struct output variable in a vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputStruct)299 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputStruct)
300 {
301     const std::string &shaderString =
302         "#version 300 es\n"
303         "precision mediump float;\n"
304         "struct MyS{\n"
305         "   float a;\n"
306         "   float b;\n"
307         "};\n"
308         "out MyS out1;\n"
309         "void main() {\n"
310         "}\n";
311     compileAssumeSuccess(shaderString);
312     VerifyOutputVariableInitializers verifier(mASTRoot);
313 
314     FindStructByName findStruct(ImmutableString("MyS"));
315     mASTRoot->traverse(&findStruct);
316     ASSERT(findStruct.isStructureFound());
317 
318     TType type(findStruct.getStructure(), false);
319     type.setQualifier(EvqVertexOut);
320 
321     TIntermTyped *expectedLValue = CreateLValueNode(ImmutableString("out1"), type);
322     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
323     delete expectedLValue;
324 }
325 
326 // Test the initialization of a varying variable in an ESSL1 vertex shader.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,OutputFromESSL1Shader)327 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, OutputFromESSL1Shader)
328 {
329     const std::string &shaderString =
330         "precision mediump float;\n"
331         "varying vec4 out1;\n"
332         "void main() {\n"
333         "}\n";
334     compileAssumeSuccess(shaderString);
335     VerifyOutputVariableInitializers verifier(mASTRoot);
336 
337     TIntermTyped *expectedLValue =
338         CreateLValueNode(ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqVaryingOut, 4));
339     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
340     delete expectedLValue;
341 }
342 
343 // Test the initialization of output variables in a fragment shader.
TEST_F(InitOutputVariablesWebGL2FragmentShaderTest,Output)344 TEST_F(InitOutputVariablesWebGL2FragmentShaderTest, Output)
345 {
346     const std::string &shaderString =
347         "#version 300 es\n"
348         "precision mediump float;\n"
349         "out vec4 out1;\n"
350         "void main() {\n"
351         "}\n";
352     compileAssumeSuccess(shaderString);
353     VerifyOutputVariableInitializers verifier(mASTRoot);
354 
355     TIntermTyped *expectedLValue =
356         CreateLValueNode(ImmutableString("out1"), TType(EbtFloat, EbpMedium, EvqFragmentOut, 4));
357     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValue));
358     delete expectedLValue;
359 }
360 
361 // Test the initialization of gl_FragData in a WebGL2 ESSL1 fragment shader. Only writes to
362 // gl_FragData[0] should be found.
TEST_F(InitOutputVariablesWebGL2FragmentShaderTest,FragData)363 TEST_F(InitOutputVariablesWebGL2FragmentShaderTest, FragData)
364 {
365     const std::string &shaderString =
366         "precision mediump float;\n"
367         "void main() {\n"
368         "   gl_FragData[0] = vec4(1.);\n"
369         "}\n";
370     compileAssumeSuccess(shaderString);
371     VerifyOutputVariableInitializers verifier(mASTRoot);
372 
373     ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
374         ImmutableString("gl_FragData"), TType(EbtFloat, EbpMedium, EvqFragData, 4), 1);
375     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[0]));
376     EXPECT_EQ(1u, verifier.getCandidates().size());
377 }
378 
379 // Test the initialization of gl_FragData in a WebGL1 ESSL1 fragment shader. Only writes to
380 // gl_FragData[0] should be found.
TEST_F(InitOutputVariablesWebGL1FragmentShaderTest,FragData)381 TEST_F(InitOutputVariablesWebGL1FragmentShaderTest, FragData)
382 {
383     const std::string &shaderString =
384         "precision mediump float;\n"
385         "void main() {\n"
386         "   gl_FragData[0] = vec4(1.);\n"
387         "}\n";
388     compileAssumeSuccess(shaderString);
389     VerifyOutputVariableInitializers verifier(mASTRoot);
390 
391     // In the symbol table, gl_FragData array has 2 elements. However, only the 1st one should be
392     // initialized.
393     ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
394         ImmutableString("gl_FragData"), TType(EbtFloat, EbpMedium, EvqFragData, 4), 2);
395     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[0]));
396     EXPECT_EQ(1u, verifier.getCandidates().size());
397 }
398 
399 // Test the initialization of gl_FragData in a WebGL1 ESSL1 fragment shader with GL_EXT_draw_buffers
400 // enabled. All attachment slots should be initialized.
TEST_F(InitOutputVariablesWebGL1FragmentShaderTest,FragDataWithDrawBuffersExtEnabled)401 TEST_F(InitOutputVariablesWebGL1FragmentShaderTest, FragDataWithDrawBuffersExtEnabled)
402 {
403     const std::string &shaderString =
404         "#extension GL_EXT_draw_buffers : enable\n"
405         "precision mediump float;\n"
406         "void main() {\n"
407         "   gl_FragData[0] = vec4(1.);\n"
408         "}\n";
409     compileAssumeSuccess(shaderString);
410     VerifyOutputVariableInitializers verifier(mASTRoot);
411 
412     ExpectedLValues expectedLValues = CreateIndexedLValueNodeList(
413         ImmutableString("gl_FragData"), TType(EbtFloat, EbpMedium, EvqFragData, 4), 2);
414     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[0]));
415     EXPECT_TRUE(verifier.isExpectedLValueFound(expectedLValues[1]));
416     EXPECT_EQ(2u, verifier.getCandidates().size());
417 }
418 
419 // Test that gl_Position is initialized once in case it is not statically used and both
420 // SH_INIT_OUTPUT_VARIABLES and SH_INIT_GL_POSITION flags are set.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,InitGLPositionWhenNotStaticallyUsed)421 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, InitGLPositionWhenNotStaticallyUsed)
422 {
423     const std::string &shaderString =
424         "#version 300 es\n"
425         "precision highp float;\n"
426         "void main() {\n"
427         "}\n";
428     compileAssumeSuccess(shaderString);
429     VerifyOutputVariableInitializers verifier(mASTRoot);
430 
431     TIntermTyped *glPosition =
432         CreateLValueNode(ImmutableString("gl_Position"), TType(EbtFloat, EbpHigh, EvqPosition, 4));
433     EXPECT_TRUE(verifier.isExpectedLValueFound(glPosition));
434     EXPECT_EQ(1u, verifier.getCandidates().size());
435 }
436 
437 // Test that gl_Position is initialized once in case it is statically used and both
438 // SH_INIT_OUTPUT_VARIABLES and SH_INIT_GL_POSITION flags are set.
TEST_F(InitOutputVariablesWebGL2VertexShaderTest,InitGLPositionOnceWhenStaticallyUsed)439 TEST_F(InitOutputVariablesWebGL2VertexShaderTest, InitGLPositionOnceWhenStaticallyUsed)
440 {
441     const std::string &shaderString =
442         "#version 300 es\n"
443         "precision highp float;\n"
444         "void main() {\n"
445         "    gl_Position = vec4(1.0);\n"
446         "}\n";
447     compileAssumeSuccess(shaderString);
448     VerifyOutputVariableInitializers verifier(mASTRoot);
449 
450     TIntermTyped *glPosition =
451         CreateLValueNode(ImmutableString("gl_Position"), TType(EbtFloat, EbpHigh, EvqPosition, 4));
452     EXPECT_TRUE(verifier.isExpectedLValueFound(glPosition));
453     EXPECT_EQ(1u, verifier.getCandidates().size());
454 }
455 
456 // Mirrors ClipDistanceTest.ThreeClipDistancesRedeclared
TEST_F(InitOutputVariablesVertexShaderClipDistanceTest,RedeclareClipDistance)457 TEST_F(InitOutputVariablesVertexShaderClipDistanceTest, RedeclareClipDistance)
458 {
459     constexpr char shaderString[] = R"(
460 #extension GL_APPLE_clip_distance : require
461 
462 varying highp float gl_ClipDistance[3];
463 
464 void computeClipDistances(in vec4 position, in vec4 plane[3])
465 {
466     gl_ClipDistance[0] = dot(position, plane[0]);
467     gl_ClipDistance[1] = dot(position, plane[1]);
468     gl_ClipDistance[2] = dot(position, plane[2]);
469 }
470 
471 uniform vec4 u_plane[3];
472 
473 attribute vec2 a_position;
474 
475 void main()
476 {
477     gl_Position = vec4(a_position, 0.0, 1.0);
478 
479     computeClipDistances(gl_Position, u_plane);
480 })";
481 
482     compileAssumeSuccess(shaderString);
483 }
484 }  // namespace sh
485