1 //
2 // Copyright 2023 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 
6 #include "compiler/translator/tree_ops/PreTransformTextureCubeGradDerivatives.h"
7 
8 #include "compiler/translator/StaticType.h"
9 #include "compiler/translator/SymbolTable.h"
10 #include "compiler/translator/tree_util/FindFunction.h"
11 #include "compiler/translator/tree_util/IntermNode_util.h"
12 #include "compiler/translator/tree_util/IntermTraverse.h"
13 
14 namespace sh
15 {
16 
17 namespace
18 {
19 
20 constexpr ImmutableString kFunctionAGX("ANGLE_textureGradAGX");
21 
22 const TType *kBoolType   = StaticType::GetTemporary<EbtBool, EbpUndefined>();
23 const TType *kVec3Type   = StaticType::GetTemporary<EbtFloat, EbpMedium, 3>();
24 const TType *kVec4Type   = StaticType::GetTemporary<EbtFloat, EbpMedium, 4>();
25 const TType *kVec3InType = StaticType::GetQualified<EbtFloat, EbpMedium, EvqParamIn, 3>();
26 const TType *kVec4InType = StaticType::GetQualified<EbtFloat, EbpMedium, EvqParamIn, 4>();
27 
28 class PreTransformTextureCubeGradTraverser : public TIntermTraverser
29 {
30   public:
PreTransformTextureCubeGradTraverser(TSymbolTable * symbolTable,int shaderVersion)31     PreTransformTextureCubeGradTraverser(TSymbolTable *symbolTable, int shaderVersion)
32         : TIntermTraverser(true, false, false, symbolTable), mShaderVersion(shaderVersion)
33     {}
34 
getSwizzledVariable(const TVariable * source,const TVariable * xMajor,const TVariable * yMajor,TIntermBlock * body)35     const TVariable *getSwizzledVariable(const TVariable *source,
36                                          const TVariable *xMajor,
37                                          const TVariable *yMajor,
38                                          TIntermBlock *body)
39     {
40         TIntermSwizzle *sYZX       = new TIntermSwizzle(new TIntermSymbol(source), {1, 2, 0});
41         TIntermSwizzle *sXZY       = new TIntermSwizzle(new TIntermSymbol(source), {0, 2, 1});
42         TIntermSwizzle *sXYZ       = new TIntermSwizzle(new TIntermSymbol(source), {0, 1, 2});
43         TIntermTernary *secondRule = new TIntermTernary(new TIntermSymbol(yMajor), sXZY, sXYZ);
44         const TVariable *var       = CreateTempVariable(mSymbolTable, kVec3Type);
45         body->appendStatement(CreateTempInitDeclarationNode(
46             var, new TIntermTernary(new TIntermSymbol(xMajor), sYZX, secondRule)));
47         return var;
48     }
49 
getReplacementFunction(const TType & textureType,const TType & returnType)50     const TFunction *getReplacementFunction(const TType &textureType, const TType &returnType)
51     {
52         const TBasicType samplerType = textureType.getBasicType();
53         ASSERT(IsSamplerCube(samplerType));
54         if (mReplacementFunctions[samplerType] != nullptr)
55         {
56             return mReplacementFunctions[samplerType]->getFunction();
57         }
58 
59         // Sampler
60         TType *texType = new TType(textureType);
61         texType->setQualifier(EvqParamIn);
62         const TVariable *texture =
63             new TVariable(mSymbolTable, kEmptyImmutableString, texType, SymbolType::AngleInternal);
64 
65         // Direction vector
66         const TType *directionType =
67             samplerType == EbtSamplerCubeShadow ? kVec4InType : kVec3InType;
68         const TVariable *direction = new TVariable(mSymbolTable, kEmptyImmutableString,
69                                                    directionType, SymbolType::AngleInternal);
70 
71         // Derivatives
72         const TVariable *dPdx = new TVariable(mSymbolTable, kEmptyImmutableString, kVec3InType,
73                                               SymbolType::AngleInternal);
74         const TVariable *dPdy = new TVariable(mSymbolTable, kEmptyImmutableString, kVec3InType,
75                                               SymbolType::AngleInternal);
76 
77         TFunction *function =
78             new TFunction(mSymbolTable, kFunctionAGX, SymbolType::AngleInternal, &returnType, true);
79         function->addParameter(texture);
80         function->addParameter(direction);
81         function->addParameter(dPdx);
82         function->addParameter(dPdy);
83 
84         TIntermBlock *body = new TIntermBlock;
85 
86         // Select major axis. Apple GPUs have the following rules:
87         // * X wins over Y and Z
88         // * Y wins over Z
89 
90         // vec3 absDirection = abs(direction.xyz);
91         const TVariable *absDirection = CreateTempVariable(mSymbolTable, kVec3Type);
92         body->appendStatement(CreateTempInitDeclarationNode(
93             absDirection, CreateBuiltInFunctionCallNode(
94                               "abs", {new TIntermSwizzle(new TIntermSymbol(direction), {0, 1, 2})},
95                               *mSymbolTable, mShaderVersion)));
96 
97         TIntermSwizzle *absDirectionX = new TIntermSwizzle(new TIntermSymbol(absDirection), {0});
98         TIntermSwizzle *absDirectionY = new TIntermSwizzle(new TIntermSymbol(absDirection), {1});
99         TIntermSwizzle *absDirectionZ = new TIntermSwizzle(new TIntermSymbol(absDirection), {2});
100 
101         // bool xMajor = absDirection.x >= max(absDirection.y, absDirection.z);
102         const TVariable *xMajor = CreateTempVariable(mSymbolTable, kBoolType);
103         body->appendStatement(CreateTempInitDeclarationNode(
104             xMajor,
105             new TIntermBinary(EOpGreaterThanEqual, absDirectionX,
106                               CreateBuiltInFunctionCallNode("max", {absDirectionY, absDirectionZ},
107                                                             *mSymbolTable, mShaderVersion))));
108 
109         // bool yMajor = absDirection.y >= absDirection.z;
110         const TVariable *yMajor = CreateTempVariable(mSymbolTable, kBoolType);
111         body->appendStatement(CreateTempInitDeclarationNode(
112             yMajor, new TIntermBinary(EOpGreaterThanEqual, absDirectionY->deepCopy(),
113                                       absDirectionZ->deepCopy())));
114 
115         // Prepare input vectors
116 
117         // vec3 faceDirection = xMajor ? direction.yzx : (yMajor ? direction.xzy : direction.xyz);
118         const TVariable *faceDirection = getSwizzledVariable(direction, xMajor, yMajor, body);
119 
120         // vec3 dQdx = xMajor ? dPdx.yzx : (yMajor ? dPdx.xzy : dPdx);
121         const TVariable *dQdx = getSwizzledVariable(dPdx, xMajor, yMajor, body);
122 
123         // vec3 dQdy = xMajor ? dPdy.yzx : (yMajor ? dPdy.xzy : dPdy);
124         const TVariable *dQdy = getSwizzledVariable(dPdy, xMajor, yMajor, body);
125 
126         // Transform all derivatives; Q = faceDirection
127         // vec4 d = vec4(dQdx.xy, dQdy.xy) - (Q.xy / Q.z).xyxy * vec4(dQdx.zz, dQdy.zz);
128         TIntermAggregate *packXY = TIntermAggregate::CreateConstructor(
129             *kVec4Type, {new TIntermSwizzle(new TIntermSymbol(dQdx), {0, 1}),
130                          new TIntermSwizzle(new TIntermSymbol(dQdy), {0, 1})});
131         TIntermAggregate *packZZ = TIntermAggregate::CreateConstructor(
132             *kVec4Type, {new TIntermSwizzle(new TIntermSymbol(dQdx), {2, 2}),
133                          new TIntermSwizzle(new TIntermSymbol(dQdy), {2, 2})});
134         TIntermSwizzle *division = new TIntermSwizzle(
135             new TIntermBinary(EOpDiv, new TIntermSwizzle(new TIntermSymbol(faceDirection), {0, 1}),
136                               new TIntermSwizzle(new TIntermSymbol(faceDirection), {2})),
137             {0, 1, 0, 1});
138         const TVariable *d = CreateTempVariable(mSymbolTable, kVec4Type);
139         body->appendStatement(CreateTempInitDeclarationNode(
140             d, new TIntermBinary(EOpSub, packXY, new TIntermBinary(EOpMul, division, packZZ))));
141 
142         // Final swizzle to put the transformed values into target components
143         // X major: X and Z; Y major: X and Y; Z major: Y and Z
144         TIntermTernary *transformedX = new TIntermTernary(
145             new TIntermSymbol(xMajor), new TIntermSwizzle(new TIntermSymbol(d), {0, 0, 1}),
146             new TIntermSwizzle(new TIntermSymbol(d), {0, 1, 0}));
147         TIntermTernary *transformedY = new TIntermTernary(
148             new TIntermSymbol(xMajor), new TIntermSwizzle(new TIntermSymbol(d), {2, 2, 3}),
149             new TIntermSwizzle(new TIntermSymbol(d), {2, 3, 2}));
150 
151         TIntermTyped *nativeCall = CreateBuiltInFunctionCallNode(
152             mShaderVersion == 100 ? "textureCubeGradEXT" : "textureGrad",
153             {new TIntermSymbol(texture), new TIntermSymbol(direction), transformedX, transformedY},
154             *mSymbolTable, mShaderVersion);
155         body->appendStatement(new TIntermBranch(EOpReturn, nativeCall));
156 
157         mReplacementFunctions[samplerType] =
158             new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
159         mNewFunctionType = samplerType;
160         return function;
161     }
162 
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)163     bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
164     {
165         // Do not traverse the wrapper function
166         return node->getFunction()->name() != kFunctionAGX;
167     }
168 
visitAggregate(Visit visit,TIntermAggregate * node)169     bool visitAggregate(Visit visit, TIntermAggregate *node) override
170     {
171         if (mFound)
172         {
173             return false;
174         }
175 
176         switch (node->getOp())
177         {
178             case EOpTextureCubeGradEXT:
179             case EOpTextureGrad:
180                 break;
181             default:
182                 return true;
183         }
184 
185         TIntermSequence *parameters = node->getSequence();
186         TIntermTyped *tex           = parameters->at(0)->getAsTyped();
187         if (!IsSamplerCube(tex->getBasicType()))
188         {
189             return true;
190         }
191 
192         queueReplacement(TIntermAggregate::CreateFunctionCall(
193                              *getReplacementFunction(tex->getType(), node->getType()), parameters),
194                          OriginalNode::IS_DROPPED);
195         mFound = true;
196         return false;
197     }
198 
nextIteration()199     void nextIteration()
200     {
201         mNewFunctionType = EbtVoid;
202         mFound           = false;
203     }
204 
getNewReplacementFunction()205     TIntermFunctionDefinition *getNewReplacementFunction()
206     {
207         return mNewFunctionType != EbtVoid ? mReplacementFunctions[mNewFunctionType] : nullptr;
208     }
209 
found() const210     bool found() const { return mFound; }
211 
212   private:
213     const int mShaderVersion;
214     std::map<TBasicType, TIntermFunctionDefinition *> mReplacementFunctions;
215     TBasicType mNewFunctionType = EbtVoid;
216     bool mFound                 = false;
217 };
218 
219 }  // anonymous namespace
220 
PreTransformTextureCubeGradDerivatives(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int shaderVersion)221 bool PreTransformTextureCubeGradDerivatives(TCompiler *compiler,
222                                             TIntermBlock *root,
223                                             TSymbolTable *symbolTable,
224                                             int shaderVersion)
225 {
226     PreTransformTextureCubeGradTraverser traverser(symbolTable, shaderVersion);
227     do
228     {
229         traverser.nextIteration();
230         root->traverse(&traverser);
231         if (traverser.found())
232         {
233             TIntermFunctionDefinition *newFunction = traverser.getNewReplacementFunction();
234             if (newFunction != nullptr)
235             {
236                 root->insertStatement(FindFirstFunctionDefinitionIndex(root), newFunction);
237             }
238 
239             if (!traverser.updateTree(compiler, root))
240             {
241                 return false;
242             }
243         }
244     } while (traverser.found());
245 
246     return true;
247 }
248 
249 }  // namespace sh
250