1 //
2 // Copyright 2023 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5
6 #include "compiler/translator/tree_ops/PreTransformTextureCubeGradDerivatives.h"
7
8 #include "compiler/translator/StaticType.h"
9 #include "compiler/translator/SymbolTable.h"
10 #include "compiler/translator/tree_util/FindFunction.h"
11 #include "compiler/translator/tree_util/IntermNode_util.h"
12 #include "compiler/translator/tree_util/IntermTraverse.h"
13
14 namespace sh
15 {
16
17 namespace
18 {
19
20 constexpr ImmutableString kFunctionAGX("ANGLE_textureGradAGX");
21
22 const TType *kBoolType = StaticType::GetTemporary<EbtBool, EbpUndefined>();
23 const TType *kVec3Type = StaticType::GetTemporary<EbtFloat, EbpMedium, 3>();
24 const TType *kVec4Type = StaticType::GetTemporary<EbtFloat, EbpMedium, 4>();
25 const TType *kVec3InType = StaticType::GetQualified<EbtFloat, EbpMedium, EvqParamIn, 3>();
26 const TType *kVec4InType = StaticType::GetQualified<EbtFloat, EbpMedium, EvqParamIn, 4>();
27
28 class PreTransformTextureCubeGradTraverser : public TIntermTraverser
29 {
30 public:
PreTransformTextureCubeGradTraverser(TSymbolTable * symbolTable,int shaderVersion)31 PreTransformTextureCubeGradTraverser(TSymbolTable *symbolTable, int shaderVersion)
32 : TIntermTraverser(true, false, false, symbolTable), mShaderVersion(shaderVersion)
33 {}
34
getSwizzledVariable(const TVariable * source,const TVariable * xMajor,const TVariable * yMajor,TIntermBlock * body)35 const TVariable *getSwizzledVariable(const TVariable *source,
36 const TVariable *xMajor,
37 const TVariable *yMajor,
38 TIntermBlock *body)
39 {
40 TIntermSwizzle *sYZX = new TIntermSwizzle(new TIntermSymbol(source), {1, 2, 0});
41 TIntermSwizzle *sXZY = new TIntermSwizzle(new TIntermSymbol(source), {0, 2, 1});
42 TIntermSwizzle *sXYZ = new TIntermSwizzle(new TIntermSymbol(source), {0, 1, 2});
43 TIntermTernary *secondRule = new TIntermTernary(new TIntermSymbol(yMajor), sXZY, sXYZ);
44 const TVariable *var = CreateTempVariable(mSymbolTable, kVec3Type);
45 body->appendStatement(CreateTempInitDeclarationNode(
46 var, new TIntermTernary(new TIntermSymbol(xMajor), sYZX, secondRule)));
47 return var;
48 }
49
getReplacementFunction(const TType & textureType,const TType & returnType)50 const TFunction *getReplacementFunction(const TType &textureType, const TType &returnType)
51 {
52 const TBasicType samplerType = textureType.getBasicType();
53 ASSERT(IsSamplerCube(samplerType));
54 if (mReplacementFunctions[samplerType] != nullptr)
55 {
56 return mReplacementFunctions[samplerType]->getFunction();
57 }
58
59 // Sampler
60 TType *texType = new TType(textureType);
61 texType->setQualifier(EvqParamIn);
62 const TVariable *texture =
63 new TVariable(mSymbolTable, kEmptyImmutableString, texType, SymbolType::AngleInternal);
64
65 // Direction vector
66 const TType *directionType =
67 samplerType == EbtSamplerCubeShadow ? kVec4InType : kVec3InType;
68 const TVariable *direction = new TVariable(mSymbolTable, kEmptyImmutableString,
69 directionType, SymbolType::AngleInternal);
70
71 // Derivatives
72 const TVariable *dPdx = new TVariable(mSymbolTable, kEmptyImmutableString, kVec3InType,
73 SymbolType::AngleInternal);
74 const TVariable *dPdy = new TVariable(mSymbolTable, kEmptyImmutableString, kVec3InType,
75 SymbolType::AngleInternal);
76
77 TFunction *function =
78 new TFunction(mSymbolTable, kFunctionAGX, SymbolType::AngleInternal, &returnType, true);
79 function->addParameter(texture);
80 function->addParameter(direction);
81 function->addParameter(dPdx);
82 function->addParameter(dPdy);
83
84 TIntermBlock *body = new TIntermBlock;
85
86 // Select major axis. Apple GPUs have the following rules:
87 // * X wins over Y and Z
88 // * Y wins over Z
89
90 // vec3 absDirection = abs(direction.xyz);
91 const TVariable *absDirection = CreateTempVariable(mSymbolTable, kVec3Type);
92 body->appendStatement(CreateTempInitDeclarationNode(
93 absDirection, CreateBuiltInFunctionCallNode(
94 "abs", {new TIntermSwizzle(new TIntermSymbol(direction), {0, 1, 2})},
95 *mSymbolTable, mShaderVersion)));
96
97 TIntermSwizzle *absDirectionX = new TIntermSwizzle(new TIntermSymbol(absDirection), {0});
98 TIntermSwizzle *absDirectionY = new TIntermSwizzle(new TIntermSymbol(absDirection), {1});
99 TIntermSwizzle *absDirectionZ = new TIntermSwizzle(new TIntermSymbol(absDirection), {2});
100
101 // bool xMajor = absDirection.x >= max(absDirection.y, absDirection.z);
102 const TVariable *xMajor = CreateTempVariable(mSymbolTable, kBoolType);
103 body->appendStatement(CreateTempInitDeclarationNode(
104 xMajor,
105 new TIntermBinary(EOpGreaterThanEqual, absDirectionX,
106 CreateBuiltInFunctionCallNode("max", {absDirectionY, absDirectionZ},
107 *mSymbolTable, mShaderVersion))));
108
109 // bool yMajor = absDirection.y >= absDirection.z;
110 const TVariable *yMajor = CreateTempVariable(mSymbolTable, kBoolType);
111 body->appendStatement(CreateTempInitDeclarationNode(
112 yMajor, new TIntermBinary(EOpGreaterThanEqual, absDirectionY->deepCopy(),
113 absDirectionZ->deepCopy())));
114
115 // Prepare input vectors
116
117 // vec3 faceDirection = xMajor ? direction.yzx : (yMajor ? direction.xzy : direction.xyz);
118 const TVariable *faceDirection = getSwizzledVariable(direction, xMajor, yMajor, body);
119
120 // vec3 dQdx = xMajor ? dPdx.yzx : (yMajor ? dPdx.xzy : dPdx);
121 const TVariable *dQdx = getSwizzledVariable(dPdx, xMajor, yMajor, body);
122
123 // vec3 dQdy = xMajor ? dPdy.yzx : (yMajor ? dPdy.xzy : dPdy);
124 const TVariable *dQdy = getSwizzledVariable(dPdy, xMajor, yMajor, body);
125
126 // Transform all derivatives; Q = faceDirection
127 // vec4 d = vec4(dQdx.xy, dQdy.xy) - (Q.xy / Q.z).xyxy * vec4(dQdx.zz, dQdy.zz);
128 TIntermAggregate *packXY = TIntermAggregate::CreateConstructor(
129 *kVec4Type, {new TIntermSwizzle(new TIntermSymbol(dQdx), {0, 1}),
130 new TIntermSwizzle(new TIntermSymbol(dQdy), {0, 1})});
131 TIntermAggregate *packZZ = TIntermAggregate::CreateConstructor(
132 *kVec4Type, {new TIntermSwizzle(new TIntermSymbol(dQdx), {2, 2}),
133 new TIntermSwizzle(new TIntermSymbol(dQdy), {2, 2})});
134 TIntermSwizzle *division = new TIntermSwizzle(
135 new TIntermBinary(EOpDiv, new TIntermSwizzle(new TIntermSymbol(faceDirection), {0, 1}),
136 new TIntermSwizzle(new TIntermSymbol(faceDirection), {2})),
137 {0, 1, 0, 1});
138 const TVariable *d = CreateTempVariable(mSymbolTable, kVec4Type);
139 body->appendStatement(CreateTempInitDeclarationNode(
140 d, new TIntermBinary(EOpSub, packXY, new TIntermBinary(EOpMul, division, packZZ))));
141
142 // Final swizzle to put the transformed values into target components
143 // X major: X and Z; Y major: X and Y; Z major: Y and Z
144 TIntermTernary *transformedX = new TIntermTernary(
145 new TIntermSymbol(xMajor), new TIntermSwizzle(new TIntermSymbol(d), {0, 0, 1}),
146 new TIntermSwizzle(new TIntermSymbol(d), {0, 1, 0}));
147 TIntermTernary *transformedY = new TIntermTernary(
148 new TIntermSymbol(xMajor), new TIntermSwizzle(new TIntermSymbol(d), {2, 2, 3}),
149 new TIntermSwizzle(new TIntermSymbol(d), {2, 3, 2}));
150
151 TIntermTyped *nativeCall = CreateBuiltInFunctionCallNode(
152 mShaderVersion == 100 ? "textureCubeGradEXT" : "textureGrad",
153 {new TIntermSymbol(texture), new TIntermSymbol(direction), transformedX, transformedY},
154 *mSymbolTable, mShaderVersion);
155 body->appendStatement(new TIntermBranch(EOpReturn, nativeCall));
156
157 mReplacementFunctions[samplerType] =
158 new TIntermFunctionDefinition(new TIntermFunctionPrototype(function), body);
159 mNewFunctionType = samplerType;
160 return function;
161 }
162
visitFunctionDefinition(Visit visit,TIntermFunctionDefinition * node)163 bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override
164 {
165 // Do not traverse the wrapper function
166 return node->getFunction()->name() != kFunctionAGX;
167 }
168
visitAggregate(Visit visit,TIntermAggregate * node)169 bool visitAggregate(Visit visit, TIntermAggregate *node) override
170 {
171 if (mFound)
172 {
173 return false;
174 }
175
176 switch (node->getOp())
177 {
178 case EOpTextureCubeGradEXT:
179 case EOpTextureGrad:
180 break;
181 default:
182 return true;
183 }
184
185 TIntermSequence *parameters = node->getSequence();
186 TIntermTyped *tex = parameters->at(0)->getAsTyped();
187 if (!IsSamplerCube(tex->getBasicType()))
188 {
189 return true;
190 }
191
192 queueReplacement(TIntermAggregate::CreateFunctionCall(
193 *getReplacementFunction(tex->getType(), node->getType()), parameters),
194 OriginalNode::IS_DROPPED);
195 mFound = true;
196 return false;
197 }
198
nextIteration()199 void nextIteration()
200 {
201 mNewFunctionType = EbtVoid;
202 mFound = false;
203 }
204
getNewReplacementFunction()205 TIntermFunctionDefinition *getNewReplacementFunction()
206 {
207 return mNewFunctionType != EbtVoid ? mReplacementFunctions[mNewFunctionType] : nullptr;
208 }
209
found() const210 bool found() const { return mFound; }
211
212 private:
213 const int mShaderVersion;
214 std::map<TBasicType, TIntermFunctionDefinition *> mReplacementFunctions;
215 TBasicType mNewFunctionType = EbtVoid;
216 bool mFound = false;
217 };
218
219 } // anonymous namespace
220
PreTransformTextureCubeGradDerivatives(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,int shaderVersion)221 bool PreTransformTextureCubeGradDerivatives(TCompiler *compiler,
222 TIntermBlock *root,
223 TSymbolTable *symbolTable,
224 int shaderVersion)
225 {
226 PreTransformTextureCubeGradTraverser traverser(symbolTable, shaderVersion);
227 do
228 {
229 traverser.nextIteration();
230 root->traverse(&traverser);
231 if (traverser.found())
232 {
233 TIntermFunctionDefinition *newFunction = traverser.getNewReplacementFunction();
234 if (newFunction != nullptr)
235 {
236 root->insertStatement(FindFirstFunctionDefinitionIndex(root), newFunction);
237 }
238
239 if (!traverser.updateTree(compiler, root))
240 {
241 return false;
242 }
243 }
244 } while (traverser.found());
245
246 return true;
247 }
248
249 } // namespace sh
250