xref: /aosp_15_r20/external/angle/src/tests/test_utils/ConstantFoldingTest.h (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
//
// Copyright 2016 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// ConstantFoldingTest.h:
//   Utilities for constant folding tests.
//
9*8975f5c5SAndroid Build Coastguard Worker 
10*8975f5c5SAndroid Build Coastguard Worker #ifndef TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
11*8975f5c5SAndroid Build Coastguard Worker #define TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
12*8975f5c5SAndroid Build Coastguard Worker 
#include <cmath>
#include <limits>
#include <string>
#include <vector>

#include "common/mathutil.h"
#include "compiler/translator/tree_util/FindMain.h"
#include "compiler/translator/tree_util/FindSymbolNode.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
#include "tests/test_utils/ShaderCompileTreeTest.h"
20*8975f5c5SAndroid Build Coastguard Worker 
21*8975f5c5SAndroid Build Coastguard Worker namespace sh
22*8975f5c5SAndroid Build Coastguard Worker {
23*8975f5c5SAndroid Build Coastguard Worker 
24*8975f5c5SAndroid Build Coastguard Worker class TranslatorESSL;
25*8975f5c5SAndroid Build Coastguard Worker 
26*8975f5c5SAndroid Build Coastguard Worker template <typename T>
27*8975f5c5SAndroid Build Coastguard Worker class ConstantFinder : public TIntermTraverser
28*8975f5c5SAndroid Build Coastguard Worker {
29*8975f5c5SAndroid Build Coastguard Worker   public:
ConstantFinder(const std::vector<T> & constantVector)30*8975f5c5SAndroid Build Coastguard Worker     ConstantFinder(const std::vector<T> &constantVector)
31*8975f5c5SAndroid Build Coastguard Worker         : TIntermTraverser(true, false, false),
32*8975f5c5SAndroid Build Coastguard Worker           mConstantVector(constantVector),
33*8975f5c5SAndroid Build Coastguard Worker           mFaultTolerance(T()),
34*8975f5c5SAndroid Build Coastguard Worker           mFound(false)
35*8975f5c5SAndroid Build Coastguard Worker     {}
36*8975f5c5SAndroid Build Coastguard Worker 
ConstantFinder(const std::vector<T> & constantVector,const T & faultTolerance)37*8975f5c5SAndroid Build Coastguard Worker     ConstantFinder(const std::vector<T> &constantVector, const T &faultTolerance)
38*8975f5c5SAndroid Build Coastguard Worker         : TIntermTraverser(true, false, false),
39*8975f5c5SAndroid Build Coastguard Worker           mConstantVector(constantVector),
40*8975f5c5SAndroid Build Coastguard Worker           mFaultTolerance(faultTolerance),
41*8975f5c5SAndroid Build Coastguard Worker           mFound(false)
42*8975f5c5SAndroid Build Coastguard Worker     {}
43*8975f5c5SAndroid Build Coastguard Worker 
ConstantFinder(const T & value)44*8975f5c5SAndroid Build Coastguard Worker     ConstantFinder(const T &value)
45*8975f5c5SAndroid Build Coastguard Worker         : TIntermTraverser(true, false, false), mFaultTolerance(T()), mFound(false)
46*8975f5c5SAndroid Build Coastguard Worker     {
47*8975f5c5SAndroid Build Coastguard Worker         mConstantVector.push_back(value);
48*8975f5c5SAndroid Build Coastguard Worker     }
49*8975f5c5SAndroid Build Coastguard Worker 
visitConstantUnion(TIntermConstantUnion * node)50*8975f5c5SAndroid Build Coastguard Worker     void visitConstantUnion(TIntermConstantUnion *node)
51*8975f5c5SAndroid Build Coastguard Worker     {
52*8975f5c5SAndroid Build Coastguard Worker         if (node->getType().getObjectSize() == mConstantVector.size())
53*8975f5c5SAndroid Build Coastguard Worker         {
54*8975f5c5SAndroid Build Coastguard Worker             bool found = true;
55*8975f5c5SAndroid Build Coastguard Worker             for (size_t i = 0; i < mConstantVector.size(); i++)
56*8975f5c5SAndroid Build Coastguard Worker             {
57*8975f5c5SAndroid Build Coastguard Worker                 if (!isEqual(node->getConstantValue()[i], mConstantVector[i]))
58*8975f5c5SAndroid Build Coastguard Worker                 {
59*8975f5c5SAndroid Build Coastguard Worker                     found = false;
60*8975f5c5SAndroid Build Coastguard Worker                     break;
61*8975f5c5SAndroid Build Coastguard Worker                 }
62*8975f5c5SAndroid Build Coastguard Worker             }
63*8975f5c5SAndroid Build Coastguard Worker             if (found)
64*8975f5c5SAndroid Build Coastguard Worker             {
65*8975f5c5SAndroid Build Coastguard Worker                 mFound = found;
66*8975f5c5SAndroid Build Coastguard Worker             }
67*8975f5c5SAndroid Build Coastguard Worker         }
68*8975f5c5SAndroid Build Coastguard Worker     }
69*8975f5c5SAndroid Build Coastguard Worker 
found()70*8975f5c5SAndroid Build Coastguard Worker     bool found() const { return mFound; }
71*8975f5c5SAndroid Build Coastguard Worker 
72*8975f5c5SAndroid Build Coastguard Worker   private:
isEqual(const TConstantUnion & node,const float & value)73*8975f5c5SAndroid Build Coastguard Worker     bool isEqual(const TConstantUnion &node, const float &value) const
74*8975f5c5SAndroid Build Coastguard Worker     {
75*8975f5c5SAndroid Build Coastguard Worker         if (node.getType() != EbtFloat)
76*8975f5c5SAndroid Build Coastguard Worker         {
77*8975f5c5SAndroid Build Coastguard Worker             return false;
78*8975f5c5SAndroid Build Coastguard Worker         }
79*8975f5c5SAndroid Build Coastguard Worker         if (value == std::numeric_limits<float>::infinity())
80*8975f5c5SAndroid Build Coastguard Worker         {
81*8975f5c5SAndroid Build Coastguard Worker             return gl::isInf(node.getFConst()) && node.getFConst() > 0;
82*8975f5c5SAndroid Build Coastguard Worker         }
83*8975f5c5SAndroid Build Coastguard Worker         else if (value == -std::numeric_limits<float>::infinity())
84*8975f5c5SAndroid Build Coastguard Worker         {
85*8975f5c5SAndroid Build Coastguard Worker             return gl::isInf(node.getFConst()) && node.getFConst() < 0;
86*8975f5c5SAndroid Build Coastguard Worker         }
87*8975f5c5SAndroid Build Coastguard Worker         else if (gl::isNaN(value))
88*8975f5c5SAndroid Build Coastguard Worker         {
89*8975f5c5SAndroid Build Coastguard Worker             // All NaNs are treated as equal.
90*8975f5c5SAndroid Build Coastguard Worker             return gl::isNaN(node.getFConst());
91*8975f5c5SAndroid Build Coastguard Worker         }
92*8975f5c5SAndroid Build Coastguard Worker         return mFaultTolerance >= fabsf(node.getFConst() - value);
93*8975f5c5SAndroid Build Coastguard Worker     }
94*8975f5c5SAndroid Build Coastguard Worker 
isEqual(const TConstantUnion & node,const int & value)95*8975f5c5SAndroid Build Coastguard Worker     bool isEqual(const TConstantUnion &node, const int &value) const
96*8975f5c5SAndroid Build Coastguard Worker     {
97*8975f5c5SAndroid Build Coastguard Worker         if (node.getType() != EbtInt)
98*8975f5c5SAndroid Build Coastguard Worker         {
99*8975f5c5SAndroid Build Coastguard Worker             return false;
100*8975f5c5SAndroid Build Coastguard Worker         }
101*8975f5c5SAndroid Build Coastguard Worker         ASSERT(mFaultTolerance < std::numeric_limits<int>::max());
102*8975f5c5SAndroid Build Coastguard Worker         // abs() returns 0 at least on some platforms when the minimum int value is passed in (it
103*8975f5c5SAndroid Build Coastguard Worker         // doesn't have a positive counterpart).
104*8975f5c5SAndroid Build Coastguard Worker         return mFaultTolerance >= abs(node.getIConst() - value) &&
105*8975f5c5SAndroid Build Coastguard Worker                (node.getIConst() - value) != std::numeric_limits<int>::min();
106*8975f5c5SAndroid Build Coastguard Worker     }
107*8975f5c5SAndroid Build Coastguard Worker 
isEqual(const TConstantUnion & node,const unsigned int & value)108*8975f5c5SAndroid Build Coastguard Worker     bool isEqual(const TConstantUnion &node, const unsigned int &value) const
109*8975f5c5SAndroid Build Coastguard Worker     {
110*8975f5c5SAndroid Build Coastguard Worker         if (node.getType() != EbtUInt)
111*8975f5c5SAndroid Build Coastguard Worker         {
112*8975f5c5SAndroid Build Coastguard Worker             return false;
113*8975f5c5SAndroid Build Coastguard Worker         }
114*8975f5c5SAndroid Build Coastguard Worker         ASSERT(mFaultTolerance < static_cast<unsigned int>(std::numeric_limits<int>::max()));
115*8975f5c5SAndroid Build Coastguard Worker         return static_cast<int>(mFaultTolerance) >=
116*8975f5c5SAndroid Build Coastguard Worker                    abs(static_cast<int>(node.getUConst() - value)) &&
117*8975f5c5SAndroid Build Coastguard Worker                static_cast<int>(node.getUConst() - value) != std::numeric_limits<int>::min();
118*8975f5c5SAndroid Build Coastguard Worker     }
119*8975f5c5SAndroid Build Coastguard Worker 
isEqual(const TConstantUnion & node,const bool & value)120*8975f5c5SAndroid Build Coastguard Worker     bool isEqual(const TConstantUnion &node, const bool &value) const
121*8975f5c5SAndroid Build Coastguard Worker     {
122*8975f5c5SAndroid Build Coastguard Worker         if (node.getType() != EbtBool)
123*8975f5c5SAndroid Build Coastguard Worker         {
124*8975f5c5SAndroid Build Coastguard Worker             return false;
125*8975f5c5SAndroid Build Coastguard Worker         }
126*8975f5c5SAndroid Build Coastguard Worker         return node.getBConst() == value;
127*8975f5c5SAndroid Build Coastguard Worker     }
128*8975f5c5SAndroid Build Coastguard Worker 
129*8975f5c5SAndroid Build Coastguard Worker     std::vector<T> mConstantVector;
130*8975f5c5SAndroid Build Coastguard Worker     T mFaultTolerance;
131*8975f5c5SAndroid Build Coastguard Worker     bool mFound;
132*8975f5c5SAndroid Build Coastguard Worker };
133*8975f5c5SAndroid Build Coastguard Worker 
134*8975f5c5SAndroid Build Coastguard Worker class ConstantFoldingTest : public ShaderCompileTreeTest
135*8975f5c5SAndroid Build Coastguard Worker {
136*8975f5c5SAndroid Build Coastguard Worker   public:
ConstantFoldingTest()137*8975f5c5SAndroid Build Coastguard Worker     ConstantFoldingTest() {}
138*8975f5c5SAndroid Build Coastguard Worker 
139*8975f5c5SAndroid Build Coastguard Worker   protected:
getShaderType()140*8975f5c5SAndroid Build Coastguard Worker     ::GLenum getShaderType() const override { return GL_FRAGMENT_SHADER; }
getShaderSpec()141*8975f5c5SAndroid Build Coastguard Worker     ShShaderSpec getShaderSpec() const override { return SH_GLES3_1_SPEC; }
142*8975f5c5SAndroid Build Coastguard Worker 
143*8975f5c5SAndroid Build Coastguard Worker     template <typename T>
constantFoundInAST(T constant)144*8975f5c5SAndroid Build Coastguard Worker     bool constantFoundInAST(T constant)
145*8975f5c5SAndroid Build Coastguard Worker     {
146*8975f5c5SAndroid Build Coastguard Worker         ConstantFinder<T> finder(constant);
147*8975f5c5SAndroid Build Coastguard Worker         mASTRoot->traverse(&finder);
148*8975f5c5SAndroid Build Coastguard Worker         return finder.found();
149*8975f5c5SAndroid Build Coastguard Worker     }
150*8975f5c5SAndroid Build Coastguard Worker 
151*8975f5c5SAndroid Build Coastguard Worker     template <typename T>
constantVectorFoundInAST(const std::vector<T> & constantVector)152*8975f5c5SAndroid Build Coastguard Worker     bool constantVectorFoundInAST(const std::vector<T> &constantVector)
153*8975f5c5SAndroid Build Coastguard Worker     {
154*8975f5c5SAndroid Build Coastguard Worker         ConstantFinder<T> finder(constantVector);
155*8975f5c5SAndroid Build Coastguard Worker         mASTRoot->traverse(&finder);
156*8975f5c5SAndroid Build Coastguard Worker         return finder.found();
157*8975f5c5SAndroid Build Coastguard Worker     }
158*8975f5c5SAndroid Build Coastguard Worker 
159*8975f5c5SAndroid Build Coastguard Worker     template <typename T>
constantColumnMajorMatrixFoundInAST(const std::vector<T> & constantMatrix)160*8975f5c5SAndroid Build Coastguard Worker     bool constantColumnMajorMatrixFoundInAST(const std::vector<T> &constantMatrix)
161*8975f5c5SAndroid Build Coastguard Worker     {
162*8975f5c5SAndroid Build Coastguard Worker         return constantVectorFoundInAST(constantMatrix);
163*8975f5c5SAndroid Build Coastguard Worker     }
164*8975f5c5SAndroid Build Coastguard Worker 
165*8975f5c5SAndroid Build Coastguard Worker     template <typename T>
constantVectorNearFoundInAST(const std::vector<T> & constantVector,const T & faultTolerance)166*8975f5c5SAndroid Build Coastguard Worker     bool constantVectorNearFoundInAST(const std::vector<T> &constantVector, const T &faultTolerance)
167*8975f5c5SAndroid Build Coastguard Worker     {
168*8975f5c5SAndroid Build Coastguard Worker         ConstantFinder<T> finder(constantVector, faultTolerance);
169*8975f5c5SAndroid Build Coastguard Worker         mASTRoot->traverse(&finder);
170*8975f5c5SAndroid Build Coastguard Worker         return finder.found();
171*8975f5c5SAndroid Build Coastguard Worker     }
172*8975f5c5SAndroid Build Coastguard Worker 
symbolFoundInAST(const char * symbolName)173*8975f5c5SAndroid Build Coastguard Worker     bool symbolFoundInAST(const char *symbolName)
174*8975f5c5SAndroid Build Coastguard Worker     {
175*8975f5c5SAndroid Build Coastguard Worker         return FindSymbolNode(mASTRoot, ImmutableString(symbolName)) != nullptr;
176*8975f5c5SAndroid Build Coastguard Worker     }
177*8975f5c5SAndroid Build Coastguard Worker 
symbolFoundInMain(const char * symbolName)178*8975f5c5SAndroid Build Coastguard Worker     bool symbolFoundInMain(const char *symbolName)
179*8975f5c5SAndroid Build Coastguard Worker     {
180*8975f5c5SAndroid Build Coastguard Worker         return FindSymbolNode(FindMain(mASTRoot), ImmutableString(symbolName)) != nullptr;
181*8975f5c5SAndroid Build Coastguard Worker     }
182*8975f5c5SAndroid Build Coastguard Worker };
183*8975f5c5SAndroid Build Coastguard Worker 
184*8975f5c5SAndroid Build Coastguard Worker class ConstantFoldingExpressionTest : public ConstantFoldingTest
185*8975f5c5SAndroid Build Coastguard Worker {
186*8975f5c5SAndroid Build Coastguard Worker   public:
ConstantFoldingExpressionTest()187*8975f5c5SAndroid Build Coastguard Worker     ConstantFoldingExpressionTest() {}
188*8975f5c5SAndroid Build Coastguard Worker 
189*8975f5c5SAndroid Build Coastguard Worker     void evaluateIvec4(const std::string &ivec4Expression);
190*8975f5c5SAndroid Build Coastguard Worker     void evaluateVec4(const std::string &vec4Expression);
191*8975f5c5SAndroid Build Coastguard Worker     void evaluateFloat(const std::string &floatExpression);
192*8975f5c5SAndroid Build Coastguard Worker     void evaluateInt(const std::string &intExpression);
193*8975f5c5SAndroid Build Coastguard Worker     void evaluateUint(const std::string &uintExpression);
194*8975f5c5SAndroid Build Coastguard Worker 
195*8975f5c5SAndroid Build Coastguard Worker   private:
196*8975f5c5SAndroid Build Coastguard Worker     void evaluate(const std::string &type, const std::string &expression);
197*8975f5c5SAndroid Build Coastguard Worker };
198*8975f5c5SAndroid Build Coastguard Worker 
199*8975f5c5SAndroid Build Coastguard Worker }  // namespace sh
200*8975f5c5SAndroid Build Coastguard Worker 
201*8975f5c5SAndroid Build Coastguard Worker #endif  // TESTS_TEST_UTILS_CONSTANTFOLDINGTEST_H_
202