xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/TranslatorMSL.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/msl/TranslatorMSL.h"
8 
9 #include "angle_gl.h"
10 #include "common/utilities.h"
11 #include "compiler/translator/ImmutableStringBuilder.h"
12 #include "compiler/translator/Name.h"
13 #include "compiler/translator/StaticType.h"
14 #include "compiler/translator/msl/AstHelpers.h"
15 #include "compiler/translator/msl/DriverUniformMetal.h"
16 #include "compiler/translator/msl/EmitMetal.h"
17 #include "compiler/translator/msl/RewritePipelines.h"
18 #include "compiler/translator/msl/SymbolEnv.h"
19 #include "compiler/translator/msl/ToposortStructs.h"
20 #include "compiler/translator/msl/UtilsMSL.h"
21 #include "compiler/translator/tree_ops/InitializeVariables.h"
22 #include "compiler/translator/tree_ops/MonomorphizeUnsupportedFunctions.h"
23 #include "compiler/translator/tree_ops/PreTransformTextureCubeGradDerivatives.h"
24 #include "compiler/translator/tree_ops/RemoveAtomicCounterBuiltins.h"
25 #include "compiler/translator/tree_ops/RemoveInactiveInterfaceVariables.h"
26 #include "compiler/translator/tree_ops/RewriteArrayOfArrayOfOpaqueUniforms.h"
27 #include "compiler/translator/tree_ops/RewriteAtomicCounters.h"
28 #include "compiler/translator/tree_ops/RewriteDfdy.h"
29 #include "compiler/translator/tree_ops/RewriteStructSamplers.h"
30 #include "compiler/translator/tree_ops/SeparateStructFromUniformDeclarations.h"
31 #include "compiler/translator/tree_ops/msl/AddExplicitTypeCasts.h"
32 #include "compiler/translator/tree_ops/msl/ConvertUnsupportedConstructorsToFunctionCalls.h"
33 #include "compiler/translator/tree_ops/msl/FixTypeConstructors.h"
34 #include "compiler/translator/tree_ops/msl/HoistConstants.h"
35 #include "compiler/translator/tree_ops/msl/IntroduceVertexIndexID.h"
36 #include "compiler/translator/tree_ops/msl/ReduceInterfaceBlocks.h"
37 #include "compiler/translator/tree_ops/msl/RewriteCaseDeclarations.h"
38 #include "compiler/translator/tree_ops/msl/RewriteInterpolants.h"
39 #include "compiler/translator/tree_ops/msl/RewriteOutArgs.h"
40 #include "compiler/translator/tree_ops/msl/RewriteUnaddressableReferences.h"
41 #include "compiler/translator/tree_ops/msl/SeparateCompoundExpressions.h"
42 #include "compiler/translator/tree_ops/msl/WrapMain.h"
43 #include "compiler/translator/tree_util/BuiltIn.h"
44 #include "compiler/translator/tree_util/DriverUniform.h"
45 #include "compiler/translator/tree_util/FindFunction.h"
46 #include "compiler/translator/tree_util/FindMain.h"
47 #include "compiler/translator/tree_util/FindSymbolNode.h"
48 #include "compiler/translator/tree_util/IntermNode_util.h"
49 #include "compiler/translator/tree_util/ReplaceClipCullDistanceVariable.h"
50 #include "compiler/translator/tree_util/ReplaceVariable.h"
51 #include "compiler/translator/tree_util/RunAtTheBeginningOfShader.h"
52 #include "compiler/translator/tree_util/RunAtTheEndOfShader.h"
53 #include "compiler/translator/tree_util/SpecializationConstant.h"
54 #include "compiler/translator/util.h"
55 
56 namespace sh
57 {
58 
59 namespace
60 {
61 
// Names for ANGLE-internal globals that hold flipped copies of built-ins.
// kFlippedFragCoordName is the name InsertFragCoordCorrection() hands to FlipBuiltinVariable()
// for the gl_FragCoord replacement. kFlippedPointCoordName is presumably used the same way for
// gl_PointCoord elsewhere in this file -- TODO confirm.
constexpr Name kFlippedPointCoordName("flippedPointCoord", SymbolType::AngleInternal);
constexpr Name kFlippedFragCoordName("flippedFragCoord", SymbolType::AngleInternal);
64 
65 class DeclareStructTypesTraverser : public TIntermTraverser
66 {
67   public:
DeclareStructTypesTraverser(TOutputMSL * outputMSL)68     explicit DeclareStructTypesTraverser(TOutputMSL *outputMSL)
69         : TIntermTraverser(true, false, false), mOutputMSL(outputMSL)
70     {}
71 
visitDeclaration(Visit visit,TIntermDeclaration * node)72     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
73     {
74         ASSERT(visit == PreVisit);
75         if (!mInGlobalScope)
76         {
77             return false;
78         }
79 
80         const TIntermSequence &sequence = *(node->getSequence());
81         TIntermTyped *declarator        = sequence.front()->getAsTyped();
82         const TType &type               = declarator->getType();
83 
84         if (type.isStructSpecifier())
85         {
86             const TStructure *structure = type.getStruct();
87 
88             // Embedded structs should be parsed away by now.
89             ASSERT(structure->symbolType() != SymbolType::Empty);
90             // outputMSL->writeStructType(structure);
91 
92             TIntermSymbol *symbolNode = declarator->getAsSymbolNode();
93             if (symbolNode && symbolNode->variable().symbolType() == SymbolType::Empty)
94             {
95                 // Remove the struct specifier declaration from the tree so it isn't parsed again.
96                 TIntermSequence emptyReplacement;
97                 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
98                                                 std::move(emptyReplacement));
99             }
100         }
101         // TODO: REMOVE, used to remove 'unsued' warning
102         mOutputMSL = nullptr;
103 
104         return false;
105     }
106 
107   private:
108     TOutputMSL *mOutputMSL;
109 };
110 
111 class DeclareDefaultUniformsTraverser : public TIntermTraverser
112 {
113   public:
DeclareDefaultUniformsTraverser(TInfoSinkBase * sink,ShHashFunction64 hashFunction,NameMap * nameMap)114     DeclareDefaultUniformsTraverser(TInfoSinkBase *sink,
115                                     ShHashFunction64 hashFunction,
116                                     NameMap *nameMap)
117         : TIntermTraverser(true, true, true),
118           mSink(sink),
119           mHashFunction(hashFunction),
120           mNameMap(nameMap),
121           mInDefaultUniform(false)
122     {}
123 
visitDeclaration(Visit visit,TIntermDeclaration * node)124     bool visitDeclaration(Visit visit, TIntermDeclaration *node) override
125     {
126         const TIntermSequence &sequence = *(node->getSequence());
127 
128         // TODO(jmadill): Compound declarations.
129         ASSERT(sequence.size() == 1);
130 
131         TIntermTyped *variable = sequence.front()->getAsTyped();
132         const TType &type      = variable->getType();
133         bool isUniform         = type.getQualifier() == EvqUniform && !type.isInterfaceBlock() &&
134                          !IsOpaqueType(type.getBasicType());
135 
136         if (visit == PreVisit)
137         {
138             if (isUniform)
139             {
140                 (*mSink) << "    " << GetTypeName(type, mHashFunction, mNameMap) << " ";
141                 mInDefaultUniform = true;
142             }
143         }
144         else if (visit == InVisit)
145         {
146             mInDefaultUniform = isUniform;
147         }
148         else if (visit == PostVisit)
149         {
150             if (isUniform)
151             {
152                 (*mSink) << ";\n";
153 
154                 // Remove the uniform declaration from the tree so it isn't parsed again.
155                 TIntermSequence emptyReplacement;
156                 mMultiReplacements.emplace_back(getParentNode()->getAsBlock(), node,
157                                                 std::move(emptyReplacement));
158             }
159 
160             mInDefaultUniform = false;
161         }
162         return true;
163     }
164 
visitSymbol(TIntermSymbol * symbol)165     void visitSymbol(TIntermSymbol *symbol) override
166     {
167         if (mInDefaultUniform)
168         {
169             const ImmutableString &name = symbol->variable().name();
170             ASSERT(!gl::IsBuiltInName(name.data()));
171             (*mSink) << HashName(&symbol->variable(), mHashFunction, mNameMap)
172                      << ArrayString(symbol->getType());
173         }
174     }
175 
176   private:
177     TInfoSinkBase *mSink;
178     ShHashFunction64 mHashFunction;
179     NameMap *mNameMap;
180     bool mInDefaultUniform;
181 };
182 
183 // Declares a new variable to replace gl_DepthRange, its values are fed from a driver uniform.
ReplaceGLDepthRangeWithDriverUniform(TCompiler * compiler,TIntermBlock * root,const DriverUniformMetal * driverUniforms,TSymbolTable * symbolTable)184 [[nodiscard]] bool ReplaceGLDepthRangeWithDriverUniform(TCompiler *compiler,
185                                                         TIntermBlock *root,
186                                                         const DriverUniformMetal *driverUniforms,
187                                                         TSymbolTable *symbolTable)
188 {
189     // Create a symbol reference to "gl_DepthRange"
190     const TVariable *depthRangeVar = static_cast<const TVariable *>(
191         symbolTable->findBuiltIn(ImmutableString("gl_DepthRange"), 0));
192 
193     // ANGLEUniforms.depthRange
194     TIntermTyped *angleEmulatedDepthRangeRef = driverUniforms->getDepthRange();
195 
196     // Use this variable instead of gl_DepthRange everywhere.
197     return ReplaceVariableWithTyped(compiler, root, depthRangeVar, angleEmulatedDepthRangeRef);
198 }
199 
GetMainSequence(TIntermBlock * root)200 TIntermSequence *GetMainSequence(TIntermBlock *root)
201 {
202     TIntermFunctionDefinition *main = FindMain(root);
203     return main->getBody()->getSequence();
204 }
205 
206 // Replaces a builtin variable with a version that is rotated and corrects the X and Y coordinates.
FlipBuiltinVariable(TCompiler * compiler,TIntermBlock * root,TIntermSequence * insertSequence,TIntermTyped * flipXY,TSymbolTable * symbolTable,const TVariable * builtin,const Name & flippedVariableName,TIntermTyped * pivot)207 [[nodiscard]] bool FlipBuiltinVariable(TCompiler *compiler,
208                                        TIntermBlock *root,
209                                        TIntermSequence *insertSequence,
210                                        TIntermTyped *flipXY,
211                                        TSymbolTable *symbolTable,
212                                        const TVariable *builtin,
213                                        const Name &flippedVariableName,
214                                        TIntermTyped *pivot)
215 {
216     // Create a symbol reference to 'builtin'.
217     TIntermSymbol *builtinRef = new TIntermSymbol(builtin);
218 
219     // Create a swizzle to "builtin.xy"
220     TVector<int> swizzleOffsetXY = {0, 1};
221     TIntermSwizzle *builtinXY    = new TIntermSwizzle(builtinRef, swizzleOffsetXY);
222 
223     // Create a symbol reference to our new variable that will hold the modified builtin.
224     const TType *type =
225         StaticType::GetForVec<EbtFloat, EbpHigh>(EvqGlobal, builtin->getType().getNominalSize());
226     TVariable *replacementVar =
227         new TVariable(symbolTable, flippedVariableName.rawName(), type, SymbolType::AngleInternal);
228     DeclareGlobalVariable(root, replacementVar);
229     TIntermSymbol *flippedBuiltinRef = new TIntermSymbol(replacementVar);
230 
231     // Use this new variable instead of 'builtin' everywhere.
232     if (!ReplaceVariable(compiler, root, builtin, replacementVar))
233     {
234         return false;
235     }
236 
237     // Create the expression "(builtin.xy - pivot) * flipXY + pivot
238     TIntermBinary *removePivot = new TIntermBinary(EOpSub, builtinXY, pivot);
239     TIntermBinary *inverseXY   = new TIntermBinary(EOpMul, removePivot, flipXY);
240     TIntermBinary *plusPivot   = new TIntermBinary(EOpAdd, inverseXY, pivot->deepCopy());
241 
242     // Create the corrected variable and copy the value of the original builtin.
243     TIntermSequence sequence;
244     sequence.push_back(builtinRef->deepCopy());
245     TIntermAggregate *aggregate =
246         TIntermAggregate::CreateConstructor(builtin->getType(), &sequence);
247     TIntermBinary *assignment = new TIntermBinary(EOpAssign, flippedBuiltinRef, aggregate);
248 
249     // Create an assignment to the replaced variable's .xy.
250     TIntermSwizzle *correctedXY =
251         new TIntermSwizzle(flippedBuiltinRef->deepCopy(), swizzleOffsetXY);
252     TIntermBinary *assignToY = new TIntermBinary(EOpAssign, correctedXY, plusPivot);
253 
254     // Add this assigment at the beginning of the main function
255     insertSequence->insert(insertSequence->begin(), assignToY);
256     insertSequence->insert(insertSequence->begin(), assignment);
257 
258     return compiler->validateAST(root);
259 }
260 
InsertFragCoordCorrection(TCompiler * compiler,const ShCompileOptions & compileOptions,TIntermBlock * root,TIntermSequence * insertSequence,TSymbolTable * symbolTable,const DriverUniformMetal * driverUniforms)261 [[nodiscard]] bool InsertFragCoordCorrection(TCompiler *compiler,
262                                              const ShCompileOptions &compileOptions,
263                                              TIntermBlock *root,
264                                              TIntermSequence *insertSequence,
265                                              TSymbolTable *symbolTable,
266                                              const DriverUniformMetal *driverUniforms)
267 {
268     TIntermTyped *flipXY = driverUniforms->getFlipXY(symbolTable, DriverUniformFlip::Fragment);
269     TIntermTyped *pivot  = driverUniforms->getHalfRenderArea();
270 
271     const TVariable *fragCoord = static_cast<const TVariable *>(
272         symbolTable->findBuiltIn(ImmutableString("gl_FragCoord"), compiler->getShaderVersion()));
273     return FlipBuiltinVariable(compiler, root, insertSequence, flipXY, symbolTable, fragCoord,
274                                kFlippedFragCoordName, pivot);
275 }
276 
DeclareRightBeforeMain(TIntermBlock & root,const TVariable & var)277 void DeclareRightBeforeMain(TIntermBlock &root, const TVariable &var)
278 {
279     root.insertChildNodes(FindMainIndex(&root), {new TIntermDeclaration{&var}});
280 }
281 
AddFragColorDeclaration(TIntermBlock & root,TSymbolTable & symbolTable,const TVariable & var)282 void AddFragColorDeclaration(TIntermBlock &root, TSymbolTable &symbolTable, const TVariable &var)
283 {
284     root.insertChildNodes(FindMainIndex(&root), TIntermSequence{new TIntermDeclaration{&var}});
285 }
286 
AddBuiltInDeclaration(TIntermBlock & root,TSymbolTable & symbolTable,const TVariable & builtIn)287 void AddBuiltInDeclaration(TIntermBlock &root, TSymbolTable &symbolTable, const TVariable &builtIn)
288 {
289     // Check if the variable has been already declared.
290     const TIntermSymbol *builtInSymbol = new TIntermSymbol(&builtIn);
291     const TIntermSymbol *foundSymbol   = FindSymbolNode(&root, builtIn.name());
292     if (foundSymbol && foundSymbol->uniqueId() != builtInSymbol->uniqueId())
293     {
294         return;
295     }
296     root.insertChildNodes(FindMainIndex(&root), TIntermSequence{new TIntermDeclaration{&builtIn}});
297 }
298 
AddFragDepthEXTDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable)299 void AddFragDepthEXTDeclaration(TCompiler &compiler, TIntermBlock &root, TSymbolTable &symbolTable)
300 {
301     const TIntermSymbol *glFragDepthExt = FindSymbolNode(&root, ImmutableString("gl_FragDepthEXT"));
302     ASSERT(glFragDepthExt);
303     // Replace gl_FragData with our globally defined fragdata.
304     if (!ReplaceVariable(&compiler, &root, &(glFragDepthExt->variable()),
305                          BuiltInVariable::gl_FragDepth()))
306     {
307         return;
308     }
309     AddBuiltInDeclaration(root, symbolTable, *BuiltInVariable::gl_FragDepth());
310 }
311 
AddNumSamplesDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable)312 [[nodiscard]] bool AddNumSamplesDeclaration(TCompiler &compiler,
313                                             TIntermBlock &root,
314                                             TSymbolTable &symbolTable)
315 {
316     const TVariable *glNumSamples = BuiltInVariable::gl_NumSamples();
317     DeclareRightBeforeMain(root, *glNumSamples);
318 
319     // gl_NumSamples = metal::get_num_samples();
320     TIntermBinary *assignment = new TIntermBinary(
321         TOperator::EOpAssign, new TIntermSymbol(glNumSamples),
322         CreateBuiltInFunctionCallNode("numSamples", {}, symbolTable, kESSLInternalBackendBuiltIns));
323     return RunAtTheBeginningOfShader(&compiler, &root, assignment);
324 }
325 
AddSamplePositionDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable,const DriverUniformMetal * driverUniforms)326 [[nodiscard]] bool AddSamplePositionDeclaration(TCompiler &compiler,
327                                                 TIntermBlock &root,
328                                                 TSymbolTable &symbolTable,
329                                                 const DriverUniformMetal *driverUniforms)
330 {
331     const TVariable *glSamplePosition = BuiltInVariable::gl_SamplePosition();
332     DeclareRightBeforeMain(root, *glSamplePosition);
333 
334     // When rendering to a default FBO, gl_SamplePosition should
335     // be Y-flipped to match the actual sample location
336     // gl_SamplePosition = metal::get_sample_position(uint(gl_SampleID));
337     // gl_SamplePosition -= 0.5;
338     // gl_SamplePosition *= flipXY;
339     // gl_SamplePosition += 0.5;
340     TIntermBlock *block = new TIntermBlock;
341     block->appendStatement(new TIntermBinary(
342         TOperator::EOpAssign, new TIntermSymbol(glSamplePosition),
343         CreateBuiltInFunctionCallNode("samplePosition",
344                                       {TIntermAggregate::CreateConstructor(
345                                           *StaticType::GetBasic<EbtUInt, EbpHigh>(),
346                                           {new TIntermSymbol(BuiltInVariable::gl_SampleID())})},
347                                       symbolTable, kESSLInternalBackendBuiltIns)));
348     block->appendStatement(new TIntermBinary(TOperator::EOpSubAssign,
349                                              new TIntermSymbol(glSamplePosition),
350                                              CreateFloatNode(0.5f, EbpHigh)));
351     block->appendStatement(
352         new TIntermBinary(EOpMulAssign, new TIntermSymbol(glSamplePosition),
353                           driverUniforms->getFlipXY(&symbolTable, DriverUniformFlip::Fragment)));
354     block->appendStatement(new TIntermBinary(TOperator::EOpAddAssign,
355                                              new TIntermSymbol(glSamplePosition),
356                                              CreateFloatNode(0.5f, EbpHigh)));
357     return RunAtTheBeginningOfShader(&compiler, &root, block);
358 }
359 
AddSampleMaskInDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable,const DriverUniformMetal * driverUniforms,bool perSampleShading)360 [[nodiscard]] bool AddSampleMaskInDeclaration(TCompiler &compiler,
361                                               TIntermBlock &root,
362                                               TSymbolTable &symbolTable,
363                                               const DriverUniformMetal *driverUniforms,
364                                               bool perSampleShading)
365 {
366     // in highp int gl_SampleMaskIn[1]
367     const TVariable *glSampleMaskIn = static_cast<const TVariable *>(
368         symbolTable.findBuiltIn(ImmutableString("gl_SampleMaskIn"), compiler.getShaderVersion()));
369     DeclareRightBeforeMain(root, *glSampleMaskIn);
370 
371     // Reference to gl_SampleMaskIn[0]
372     TIntermBinary *glSampleMaskIn0 =
373         new TIntermBinary(EOpIndexDirect, new TIntermSymbol(glSampleMaskIn), CreateIndexNode(0));
374 
375     // When per-sample shading is active due to the use of a fragment input qualified
376     // by sample or due to the use of the gl_SampleID or gl_SamplePosition variables,
377     // only the bit for the current sample is set in gl_SampleMaskIn.
378     TIntermBlock *block = new TIntermBlock;
379     if (perSampleShading)
380     {
381         // gl_SampleMaskIn[0] = 1 << gl_SampleID;
382         block->appendStatement(new TIntermBinary(
383             EOpAssign, glSampleMaskIn0,
384             new TIntermBinary(EOpBitShiftLeft, CreateUIntNode(1),
385                               new TIntermSymbol(BuiltInVariable::gl_SampleID()))));
386     }
387     else
388     {
389         // uint32_t ANGLE_metal_SampleMaskIn [[sample_mask]]
390         TVariable *angleSampleMaskIn = new TVariable(
391             &symbolTable, ImmutableString("metal_SampleMaskIn"),
392             new TType(EbtUInt, EbpHigh, EvqSampleMaskIn, 1), SymbolType::AngleInternal);
393         DeclareRightBeforeMain(root, *angleSampleMaskIn);
394 
395         // gl_SampleMaskIn[0] = ANGLE_metal_SampleMaskIn;
396         block->appendStatement(
397             new TIntermBinary(EOpAssign, glSampleMaskIn0, new TIntermSymbol(angleSampleMaskIn)));
398     }
399 
400     // Bits in the sample mask corresponding to covered samples
401     // that will be unset due to SAMPLE_COVERAGE or SAMPLE_MASK
402     // will not be set (section 4.1.3).
403     // if (ANGLEMultisampledRendering)
404     // {
405     //      gl_SampleMaskIn[0] &= ANGLE_angleUniforms.coverageMask;
406     // }
407     TIntermBlock *coverageBlock = new TIntermBlock;
408     coverageBlock->appendStatement(new TIntermBinary(
409         EOpBitwiseAndAssign, glSampleMaskIn0->deepCopy(), driverUniforms->getCoverageMaskField()));
410 
411     TVariable *sampleMaskEnabledVar = new TVariable(
412         &symbolTable, sh::ImmutableString(mtl::kMultisampledRenderingConstName),
413         StaticType::Get<EbtBool, EbpUndefined, EvqSpecConst, 1, 1>(), SymbolType::AngleInternal);
414     block->appendStatement(
415         new TIntermIfElse(new TIntermSymbol(sampleMaskEnabledVar), coverageBlock, nullptr));
416 
417     return RunAtTheBeginningOfShader(&compiler, &root, block);
418 }
419 
AddSampleMaskDeclaration(TCompiler & compiler,TIntermBlock & root,TSymbolTable & symbolTable,const DriverUniformMetal * driverUniforms,bool includeEmulateAlphaToCoverage,bool usesSampleMask)420 [[nodiscard]] bool AddSampleMaskDeclaration(TCompiler &compiler,
421                                             TIntermBlock &root,
422                                             TSymbolTable &symbolTable,
423                                             const DriverUniformMetal *driverUniforms,
424                                             bool includeEmulateAlphaToCoverage,
425                                             bool usesSampleMask)
426 {
427     // uint32_t ANGLE_metal_SampleMask [[sample_mask]]
428     TVariable *angleSampleMask =
429         new TVariable(&symbolTable, ImmutableString("metal_SampleMask"),
430                       new TType(EbtUInt, EbpHigh, EvqSampleMask, 1), SymbolType::AngleInternal);
431     DeclareRightBeforeMain(root, *angleSampleMask);
432 
433     // Write all-enabled sample mask even for single-sampled rendering
434     // when the shader uses derivatives to workaround a driver bug.
435     if (compiler.usesDerivatives())
436     {
437         TIntermBlock *helperAssignBlock = new TIntermBlock;
438         helperAssignBlock->appendStatement(new TIntermBinary(
439             EOpAssign, new TIntermSymbol(angleSampleMask), CreateUIntNode(0xFFFFFFFFu)));
440 
441         TVariable *writeHelperSampleMaskVar =
442             new TVariable(&symbolTable, sh::ImmutableString(mtl::kWriteHelperSampleMaskConstName),
443                           StaticType::Get<EbtBool, EbpUndefined, EvqSpecConst, 1, 1>(),
444                           SymbolType::AngleInternal);
445 
446         if (!RunAtTheBeginningOfShader(
447                 &compiler, &root,
448                 new TIntermIfElse(new TIntermSymbol(writeHelperSampleMaskVar), helperAssignBlock,
449                                   nullptr)))
450         {
451             return false;
452         }
453     }
454 
455     // ANGLE_metal_SampleMask = ANGLE_angleUniforms.coverageMask;
456     TIntermBlock *block = new TIntermBlock;
457     block->appendStatement(new TIntermBinary(EOpAssign, new TIntermSymbol(angleSampleMask),
458                                              driverUniforms->getCoverageMaskField()));
459     if (usesSampleMask)
460     {
461         // out highp int gl_SampleMask[1];
462         const TVariable *glSampleMask = static_cast<const TVariable *>(
463             symbolTable.findBuiltIn(ImmutableString("gl_SampleMask"), compiler.getShaderVersion()));
464         DeclareRightBeforeMain(root, *glSampleMask);
465 
466         // ANGLE_metal_SampleMask &= gl_SampleMask[0];
467         TIntermBinary *glSampleMask0 =
468             new TIntermBinary(EOpIndexDirect, new TIntermSymbol(glSampleMask), CreateIndexNode(0));
469         block->appendStatement(new TIntermBinary(
470             EOpBitwiseAndAssign, new TIntermSymbol(angleSampleMask), glSampleMask0));
471     }
472 
473     if (includeEmulateAlphaToCoverage)
474     {
475         // Some Metal drivers ignore alpha-to-coverage state when a fragment
476         // shader writes to [[sample_mask]]. Moreover, Metal pipeline state
477         // does not support setting a global coverage mask, which would be used
478         // for emulating GL_SAMPLE_COVERAGE, so [[sample_mask]] is used instead.
479         // To support alpha-to-coverage regardless of the [[sample_mask]] usage,
480         // the former is always emulated on such drivers.
481         TIntermBlock *alphaBlock = new TIntermBlock;
482 
483         // To reduce image artifacts due to regular coverage sample locations,
484         // alpha value thresholds that toggle individual samples are slightly
485         // different within 2x2 pixel blocks. Consider MSAAx4, for example.
486         // Instead of always enabling samples on evenly distributed alpha
487         // values like {51, 102, 153, 204} these thresholds may vary as follows
488         //
489         //    Sample 0       Sample 1       Sample 2       Sample 3
490         //   ----- -----    ----- -----    ----- -----    ----- -----
491         //  |  7.5| 39.5|  | 71.5|103.5|  |135.5|167.5|  |199.5|231.5|
492         //  |----- -----|  |----- -----|  |----- -----|  |----- -----|
493         //  | 55.5| 23.5|  |119.5| 87.5|  |183.5|151.5|  |247.5|215.5|
494         //   ----- -----    ----- -----    ----- -----    ----- -----
495         // These threshold values may be expressed as
496         //    7.5 + P * 16 + 64 * sampleID
497         // where P is
498         //    ((x << 1) - (y & 1)) & 3
499         // and constant values depend on the number of samples used.
500         TVariable *p = CreateTempVariable(&symbolTable, StaticType::GetBasic<EbtInt, EbpHigh>());
501         TVariable *y = CreateTempVariable(&symbolTable, StaticType::GetBasic<EbtInt, EbpHigh>());
502         alphaBlock->appendStatement(CreateTempInitDeclarationNode(
503             p, new TIntermSwizzle(new TIntermSymbol(BuiltInVariable::gl_FragCoord()), {0})));
504         alphaBlock->appendStatement(CreateTempInitDeclarationNode(
505             y, new TIntermSwizzle(new TIntermSymbol(BuiltInVariable::gl_FragCoord()), {1})));
506         alphaBlock->appendStatement(
507             new TIntermBinary(EOpBitShiftLeftAssign, new TIntermSymbol(p), CreateIndexNode(1)));
508         alphaBlock->appendStatement(
509             new TIntermBinary(EOpBitwiseAndAssign, new TIntermSymbol(y), CreateIndexNode(1)));
510         alphaBlock->appendStatement(
511             new TIntermBinary(EOpSubAssign, new TIntermSymbol(p), new TIntermSymbol(y)));
512         alphaBlock->appendStatement(
513             new TIntermBinary(EOpBitwiseAndAssign, new TIntermSymbol(p), CreateIndexNode(3)));
514 
515         // This internal variable, defined in-text in the function constants section,
516         // will point to the alpha channel of the color zero output. Due to potential
517         // EXT_blend_func_extended usage, the exact variable may be unknown until the
518         // program is linked.
519         TVariable *alpha0 =
520             new TVariable(&symbolTable, sh::ImmutableString("ALPHA0"),
521                           StaticType::Get<EbtFloat, EbpUndefined, EvqSpecConst, 1, 1>(),
522                           SymbolType::AngleInternal);
523 
524         // Use metal::saturate to clamp the alpha value to [0.0, 1.0] and scale it
525         // to [0.0, 510.0] since further operations expect an integer alpha value.
526         TVariable *alphaScaled =
527             CreateTempVariable(&symbolTable, StaticType::GetBasic<EbtFloat, EbpHigh>());
528         alphaBlock->appendStatement(CreateTempInitDeclarationNode(
529             alphaScaled, CreateBuiltInFunctionCallNode("saturate", {new TIntermSymbol(alpha0)},
530                                                        symbolTable, kESSLInternalBackendBuiltIns)));
531         alphaBlock->appendStatement(new TIntermBinary(EOpMulAssign, new TIntermSymbol(alphaScaled),
532                                                       CreateFloatNode(510.0, EbpUndefined)));
533         // int alphaMask = int(alphaScaled);
534         TVariable *alphaMask =
535             CreateTempVariable(&symbolTable, StaticType::GetBasic<EbtInt, EbpHigh>());
536         alphaBlock->appendStatement(CreateTempInitDeclarationNode(
537             alphaMask, TIntermAggregate::CreateConstructor(*StaticType::GetBasic<EbtInt, EbpHigh>(),
538                                                            {new TIntermSymbol(alphaScaled)})));
539 
540         // Next operations depend on the number of samples in the current render target.
541         TIntermBlock *switchBlock = new TIntermBlock();
542 
543         auto computeNumberOfSamples = [&](int step, int bias, int scale) {
544             switchBlock->appendStatement(new TIntermBinary(
545                 EOpBitShiftLeftAssign, new TIntermSymbol(p), CreateIndexNode(step)));
546             switchBlock->appendStatement(new TIntermBinary(
547                 EOpAddAssign, new TIntermSymbol(alphaMask), CreateIndexNode(bias)));
548             switchBlock->appendStatement(new TIntermBinary(
549                 EOpSubAssign, new TIntermSymbol(alphaMask), new TIntermSymbol(p)));
550             switchBlock->appendStatement(new TIntermBinary(
551                 EOpBitShiftRightAssign, new TIntermSymbol(alphaMask), CreateIndexNode(scale)));
552         };
553 
554         // MSAAx2
555         switchBlock->appendStatement(new TIntermCase(CreateIndexNode(2)));
556 
557         // Canonical threshold values are
558         //     15.5 + P * 32 + 128 * sampleID
559         // With alpha values scaled to [0, 510], the number of covered samples is
560         //     (alphaScaled + 256 - (31 + P * 64)) / 256
561         // which could be simplified to
562         //     (alphaScaled + 225 - (P << 6)) >> 8
563         computeNumberOfSamples(6, 225, 8);
564 
565         // In a case of only two samples, the coverage mask is
566         //     mask = (num_covered_samples * 3) >> 1
567         switchBlock->appendStatement(
568             new TIntermBinary(EOpMulAssign, new TIntermSymbol(alphaMask), CreateIndexNode(3)));
569         switchBlock->appendStatement(new TIntermBinary(
570             EOpBitShiftRightAssign, new TIntermSymbol(alphaMask), CreateIndexNode(1)));
571 
572         switchBlock->appendStatement(new TIntermBranch(EOpBreak, nullptr));
573 
574         // MSAAx4
575         switchBlock->appendStatement(new TIntermCase(CreateIndexNode(4)));
576 
577         // Canonical threshold values are
578         //     7.5 + P * 16 + 64 * sampleID
579         // With alpha values scaled to [0, 510], the number of covered samples is
580         //     (alphaScaled + 128 - (15 + P * 32)) / 128
581         // which could be simplified to
582         //     (alphaScaled + 113 - (P << 5)) >> 7
583         computeNumberOfSamples(5, 113, 7);
584 
585         // When two out of four samples should be covered, prioritize
586         // those that are located in the opposite corners of a pixel.
587         // 0: 0000, 1: 0001, 2: 1001, 3: 1011, 4: 1111
588         //     mask = (0xFB910 >> (num_covered_samples * 4)) & 0xF
589         // The final AND may be omitted because the rasterizer output
590         // is limited to four samples.
591         switchBlock->appendStatement(new TIntermBinary(
592             EOpBitShiftLeftAssign, new TIntermSymbol(alphaMask), CreateIndexNode(2)));
593         switchBlock->appendStatement(
594             new TIntermBinary(EOpAssign, new TIntermSymbol(alphaMask),
595                               new TIntermBinary(EOpBitShiftRight, CreateIndexNode(0xFB910),
596                                                 new TIntermSymbol(alphaMask))));
597 
598         switchBlock->appendStatement(new TIntermBranch(EOpBreak, nullptr));
599 
600         // MSAAx8
601         switchBlock->appendStatement(new TIntermCase(CreateIndexNode(8)));
602 
603         // Canonical threshold values are
604         //     3.5 + P * 8 + 32 * sampleID
605         // With alpha values scaled to [0, 510], the number of covered samples is
606         //     (alphaScaled + 64 - (7 + P * 16)) / 64
607         // which could be simplified to
608         //     (alphaScaled + 57 - (P << 4)) >> 6
609         computeNumberOfSamples(4, 57, 6);
610 
611         // When eight samples are used, they could be enabled one by one
612         //     mask = ~(0xFFFFFFFF << num_covered_samples)
613         switchBlock->appendStatement(
614             new TIntermBinary(EOpAssign, new TIntermSymbol(alphaMask),
615                               new TIntermBinary(EOpBitShiftLeft, CreateUIntNode(0xFFFFFFFFu),
616                                                 new TIntermSymbol(alphaMask))));
617         switchBlock->appendStatement(new TIntermBinary(
618             EOpAssign, new TIntermSymbol(alphaMask),
619             new TIntermUnary(EOpBitwiseNot, new TIntermSymbol(alphaMask), nullptr)));
620 
621         switchBlock->appendStatement(new TIntermBranch(EOpBreak, nullptr));
622 
623         alphaBlock->getSequence()->push_back(
624             new TIntermSwitch(CreateBuiltInFunctionCallNode("numSamples", {}, symbolTable,
625                                                             kESSLInternalBackendBuiltIns),
626                               switchBlock));
627 
628         alphaBlock->appendStatement(new TIntermBinary(
629             EOpBitwiseAndAssign, new TIntermSymbol(angleSampleMask), new TIntermSymbol(alphaMask)));
630 
631         TIntermBlock *emulateAlphaToCoverageEnabledBlock = new TIntermBlock;
632         emulateAlphaToCoverageEnabledBlock->appendStatement(
633             new TIntermIfElse(driverUniforms->getAlphaToCoverage(), alphaBlock, nullptr));
634 
635         TVariable *emulateAlphaToCoverageVar =
636             new TVariable(&symbolTable, sh::ImmutableString(mtl::kEmulateAlphaToCoverageConstName),
637                           StaticType::Get<EbtBool, EbpUndefined, EvqSpecConst, 1, 1>(),
638                           SymbolType::AngleInternal);
639         TIntermIfElse *useAlphaToCoverage =
640             new TIntermIfElse(new TIntermSymbol(emulateAlphaToCoverageVar),
641                               emulateAlphaToCoverageEnabledBlock, nullptr);
642 
643         block->appendStatement(useAlphaToCoverage);
644     }
645 
646     // Sample mask assignment is guarded by ANGLEMultisampledRendering specialization constant
647     TVariable *multisampledRenderingVar = new TVariable(
648         &symbolTable, sh::ImmutableString(mtl::kMultisampledRenderingConstName),
649         StaticType::Get<EbtBool, EbpUndefined, EvqSpecConst, 1, 1>(), SymbolType::AngleInternal);
650     return RunAtTheEndOfShader(
651         &compiler, &root,
652         new TIntermIfElse(new TIntermSymbol(multisampledRenderingVar), block, nullptr),
653         &symbolTable);
654 }
655 
AddFragDataDeclaration(TCompiler & compiler,TIntermBlock & root,bool usesSecondary,bool secondary)656 [[nodiscard]] bool AddFragDataDeclaration(TCompiler &compiler,
657                                           TIntermBlock &root,
658                                           bool usesSecondary,
659                                           bool secondary)
660 {
661     TSymbolTable &symbolTable = compiler.getSymbolTable();
662     const int maxDrawBuffers  = usesSecondary ? compiler.getResources().MaxDualSourceDrawBuffers
663                                               : compiler.getResources().MaxDrawBuffers;
664     TType *gl_FragDataType =
665         new TType(EbtFloat, EbpMedium, secondary ? EvqSecondaryFragDataEXT : EvqFragData, 4, 1);
666     std::vector<const TVariable *> glFragDataSlots;
667     TIntermSequence declareGLFragdataSequence;
668 
669     // Create gl_FragData_i or gl_SecondaryFragDataEXT_i
670     const char *fragData             = "gl_FragData";
671     const char *secondaryFragDataEXT = "gl_SecondaryFragDataEXT";
672     const char *name                 = secondary ? secondaryFragDataEXT : fragData;
673     for (int i = 0; i < maxDrawBuffers; i++)
674     {
675         ImmutableString varName = BuildConcatenatedImmutableString(name, '_', i);
676         const TVariable *glFragData =
677             new TVariable(&symbolTable, varName, gl_FragDataType, SymbolType::AngleInternal,
678                           TExtension::UNDEFINED);
679         glFragDataSlots.push_back(glFragData);
680         declareGLFragdataSequence.push_back(new TIntermDeclaration{glFragData});
681     }
682     root.insertChildNodes(FindMainIndex(&root), declareGLFragdataSequence);
683 
684     // Create an internal gl_FragData array type, compatible with indexing syntax.
685     TType *gl_FragDataTypeArray = new TType(EbtFloat, EbpMedium, EvqGlobal, 4, 1);
686     gl_FragDataTypeArray->makeArray(maxDrawBuffers);
687     const TVariable *glFragDataGlobal = new TVariable(&symbolTable, ImmutableString(name),
688                                                       gl_FragDataTypeArray, SymbolType::BuiltIn);
689 
690     DeclareGlobalVariable(&root, glFragDataGlobal);
691     const TIntermSymbol *originalGLFragData = FindSymbolNode(&root, ImmutableString(name));
692     ASSERT(originalGLFragData);
693 
694     // Replace gl_FragData[] or gl_SecondaryFragDataEXT[] with our globally defined variable
695     if (!ReplaceVariable(&compiler, &root, &(originalGLFragData->variable()), glFragDataGlobal))
696     {
697         return false;
698     }
699 
700     // Assign each array attribute to an output
701     TIntermBlock *insertSequence = new TIntermBlock();
702     for (int i = 0; i < maxDrawBuffers; i++)
703     {
704         TIntermTyped *glFragDataSlot         = new TIntermSymbol(glFragDataSlots[i]);
705         TIntermTyped *glFragDataGlobalSymbol = new TIntermSymbol(glFragDataGlobal);
706         auto &access                         = AccessIndex(*glFragDataGlobalSymbol, i);
707         TIntermBinary *assignment =
708             new TIntermBinary(TOperator::EOpAssign, glFragDataSlot, &access);
709         insertSequence->appendStatement(assignment);
710     }
711     return RunAtTheEndOfShader(&compiler, &root, insertSequence, &symbolTable);
712 }
713 
AppendVertexShaderTransformFeedbackOutputToMain(TCompiler & compiler,SymbolEnv & mSymbolEnv,TIntermBlock & root)714 [[nodiscard]] bool AppendVertexShaderTransformFeedbackOutputToMain(TCompiler &compiler,
715                                                                    SymbolEnv &mSymbolEnv,
716                                                                    TIntermBlock &root)
717 {
718     TSymbolTable &symbolTable = compiler.getSymbolTable();
719 
720     // Append the assignment as a statement at the end of the shader.
721     return RunAtTheEndOfShader(&compiler, &root,
722                                &(mSymbolEnv.callFunctionOverload(Name("@@XFB-OUT@@"), *new TType(),
723                                                                  *new TIntermSequence())),
724                                &symbolTable);
725 }
726 
727 // Unlike Vulkan having auto viewport flipping extension, in Metal we have to flip gl_Position.y
728 // manually.
729 // This operation performs flipping the gl_Position.y using this expression:
730 // gl_Position.y = gl_Position.y * negViewportScaleY
AppendVertexShaderPositionYCorrectionToMain(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,TIntermTyped * negFlipY)731 [[nodiscard]] bool AppendVertexShaderPositionYCorrectionToMain(TCompiler *compiler,
732                                                                TIntermBlock *root,
733                                                                TSymbolTable *symbolTable,
734                                                                TIntermTyped *negFlipY)
735 {
736     // Create a symbol reference to "gl_Position"
737     const TVariable *position  = BuiltInVariable::gl_Position();
738     TIntermSymbol *positionRef = new TIntermSymbol(position);
739 
740     // Create a swizzle to "gl_Position.y"
741     TVector<int> swizzleOffsetY;
742     swizzleOffsetY.push_back(1);
743     TIntermSwizzle *positionY = new TIntermSwizzle(positionRef, swizzleOffsetY);
744 
745     // Create the expression "gl_Position.y * negFlipY"
746     TIntermBinary *inverseY = new TIntermBinary(EOpMul, positionY->deepCopy(), negFlipY);
747 
748     // Create the assignment "gl_Position.y = gl_Position.y * negViewportScaleY
749     TIntermTyped *positionYLHS = positionY->deepCopy();
750     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionYLHS, inverseY);
751 
752     // Append the assignment as a statement at the end of the shader.
753     return RunAtTheEndOfShader(compiler, root, assignment, symbolTable);
754 }
755 
EmulateClipDistanceVaryings(TCompiler * compiler,TIntermBlock * root,TSymbolTable * symbolTable,const GLenum shaderType)756 [[nodiscard]] bool EmulateClipDistanceVaryings(TCompiler *compiler,
757                                                TIntermBlock *root,
758                                                TSymbolTable *symbolTable,
759                                                const GLenum shaderType)
760 {
761     ASSERT(shaderType == GL_VERTEX_SHADER || shaderType == GL_FRAGMENT_SHADER);
762 
763     const TIntermSymbol *symbolNode = FindSymbolNode(root, ImmutableString("gl_ClipDistance"));
764     ASSERT(symbolNode != nullptr);
765     const TVariable *clipDistanceVar = &symbolNode->variable();
766 
767     const bool fragment = shaderType == GL_FRAGMENT_SHADER;
768     if (fragment)
769     {
770         TType *globalType = new TType(EbtFloat, EbpHigh, EvqGlobal, 1, 1);
771         globalType->toArrayBaseType();
772         globalType->makeArray(compiler->getClipDistanceArraySize());
773 
774         const TVariable *globalVar = new TVariable(symbolTable, ImmutableString("ClipDistance"),
775                                                    globalType, SymbolType::AngleInternal);
776 
777         if (!ReplaceVariable(compiler, root, clipDistanceVar, globalVar))
778         {
779             return false;
780         }
781         clipDistanceVar = globalVar;
782     }
783 
784     TIntermBlock *assignBlock = new TIntermBlock();
785     size_t index              = FindMainIndex(root);
786     TType *type = new TType(EbtFloat, EbpHigh, fragment ? EvqFragmentIn : EvqVertexOut, 1, 1);
787     for (int i = 0; i < compiler->getClipDistanceArraySize(); i++)
788     {
789         TVariable *varyingVar =
790             new TVariable(symbolTable, BuildConcatenatedImmutableString("ClipDistance_", i), type,
791                           SymbolType::AngleInternal);
792         TIntermDeclaration *varyingDecl = new TIntermDeclaration();
793         varyingDecl->appendDeclarator(new TIntermSymbol(varyingVar));
794         root->insertStatement(index++, varyingDecl);
795         TIntermSymbol *varyingSym = new TIntermSymbol(varyingVar);
796         TIntermTyped *arrayAccess = new TIntermBinary(
797             EOpIndexDirect, new TIntermSymbol(clipDistanceVar), CreateIndexNode(i));
798         assignBlock->appendStatement(new TIntermBinary(
799             EOpAssign, fragment ? arrayAccess : varyingSym, fragment ? varyingSym : arrayAccess));
800     }
801 
802     return fragment ? RunAtTheBeginningOfShader(compiler, root, assignBlock)
803                     : RunAtTheEndOfShader(compiler, root, assignBlock, symbolTable);
804 }
805 }  // namespace
806 
807 namespace mtl
808 {
getTranslatorMetalReflection(const TCompiler * compiler)809 TranslatorMetalReflection *getTranslatorMetalReflection(const TCompiler *compiler)
810 {
811     return ((TranslatorMSL *)compiler)->getTranslatorMetalReflection();
812 }
813 }  // namespace mtl
TranslatorMSL(sh::GLenum type,ShShaderSpec spec,ShShaderOutput output)814 TranslatorMSL::TranslatorMSL(sh::GLenum type, ShShaderSpec spec, ShShaderOutput output)
815     : TCompiler(type, spec, output)
816 {}
817 
insertRasterizationDiscardLogic(TIntermBlock & root)818 [[nodiscard]] bool TranslatorMSL::insertRasterizationDiscardLogic(TIntermBlock &root)
819 {
820     // This transformation leaves the tree in an inconsistent state by using a variable that's
821     // defined in text, outside of the knowledge of the AST.
822     mValidateASTOptions.validateVariableReferences = false;
823 
824     TSymbolTable *symbolTable = &getSymbolTable();
825 
826     TType *boolType = new TType(EbtBool);
827     boolType->setQualifier(EvqConst);
828     TVariable *discardEnabledVar =
829         new TVariable(symbolTable, sh::ImmutableString(sh::mtl::kRasterizerDiscardEnabledConstName),
830                       boolType, SymbolType::AngleInternal);
831 
832     const TVariable *position  = BuiltInVariable::gl_Position();
833     TIntermSymbol *positionRef = new TIntermSymbol(position);
834 
835     // Create vec4(-3, -3, -3, 1):
836     auto vec4Type            = new TType(EbtFloat, 4);
837     TIntermSequence vec4Args = {
838         CreateFloatNode(-3.0f, EbpMedium),
839         CreateFloatNode(-3.0f, EbpMedium),
840         CreateFloatNode(-3.0f, EbpMedium),
841         CreateFloatNode(1.0f, EbpMedium),
842     };
843     TIntermAggregate *constVarConstructor =
844         TIntermAggregate::CreateConstructor(*vec4Type, &vec4Args);
845 
846     // Create the assignment "gl_Position = vec4(-3, -3, -3, 1)"
847     TIntermBinary *assignment =
848         new TIntermBinary(TOperator::EOpAssign, positionRef->deepCopy(), constVarConstructor);
849 
850     TIntermBlock *discardBlock = new TIntermBlock;
851     discardBlock->appendStatement(assignment);
852 
853     TIntermSymbol *discardEnabled = new TIntermSymbol(discardEnabledVar);
854     TIntermIfElse *ifCall         = new TIntermIfElse(discardEnabled, discardBlock, nullptr);
855 
856     return RunAtTheEndOfShader(this, &root, ifCall, symbolTable);
857 }
858 
859 // Metal needs to inverse the depth if depthRange is is reverse order, i.e. depth near > depth far
860 // This is achieved by multiply the depth value with scale value stored in
861 // driver uniform's depthRange.reserved
transformDepthBeforeCorrection(TIntermBlock * root,const DriverUniformMetal * driverUniforms)862 bool TranslatorMSL::transformDepthBeforeCorrection(TIntermBlock *root,
863                                                    const DriverUniformMetal *driverUniforms)
864 {
865     // Create a symbol reference to "gl_Position"
866     const TVariable *position  = BuiltInVariable::gl_Position();
867     TIntermSymbol *positionRef = new TIntermSymbol(position);
868 
869     // Create a swizzle to "gl_Position.z"
870     TVector<int> swizzleOffsetZ = {2};
871     TIntermSwizzle *positionZ   = new TIntermSwizzle(positionRef, swizzleOffsetZ);
872 
873     // Create a ref to "zscale"
874     TIntermTyped *viewportZScale = driverUniforms->getViewportZScale();
875 
876     // Create the expression "gl_Position.z * zscale".
877     TIntermBinary *zScale = new TIntermBinary(EOpMul, positionZ->deepCopy(), viewportZScale);
878 
879     // Create the assignment "gl_Position.z = gl_Position.z * zscale"
880     TIntermTyped *positionZLHS = positionZ->deepCopy();
881     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionZLHS, zScale);
882 
883     // Append the assignment as a statement at the end of the shader.
884     return RunAtTheEndOfShader(this, root, assignment, &getSymbolTable());
885 }
886 
887 // This operation performs the viewport depth translation needed by Metal. GL uses a
888 // clip space z range of -1 to +1 where as Metal uses 0 to 1. The translation becomes
889 // this expression
890 //
891 //     z_metal = 0.5 * (w_gl + z_gl)
892 //
893 // where z_metal is the depth output of a Metal vertex shader and z_gl is the same for GL.
894 // This operation is skipped when GL_CLIP_DEPTH_MODE_EXT is set to GL_ZERO_TO_ONE_EXT.
appendVertexShaderDepthCorrectionToMain(TIntermBlock * root,const DriverUniformMetal * driverUniforms)895 bool TranslatorMSL::appendVertexShaderDepthCorrectionToMain(
896     TIntermBlock *root,
897     const DriverUniformMetal *driverUniforms)
898 {
899     const TVariable *position  = BuiltInVariable::gl_Position();
900     TIntermSymbol *positionRef = new TIntermSymbol(position);
901 
902     TVector<int> swizzleOffsetZ = {2};
903     TIntermSwizzle *positionZ   = new TIntermSwizzle(positionRef, swizzleOffsetZ);
904 
905     TIntermConstantUnion *oneHalf = CreateFloatNode(0.5f, EbpMedium);
906 
907     TVector<int> swizzleOffsetW = {3};
908     TIntermSwizzle *positionW   = new TIntermSwizzle(positionRef->deepCopy(), swizzleOffsetW);
909 
910     // Create the expression "(gl_Position.z + gl_Position.w) * 0.5".
911     TIntermBinary *zPlusW = new TIntermBinary(EOpAdd, positionZ->deepCopy(), positionW->deepCopy());
912     TIntermBinary *halfZPlusW = new TIntermBinary(EOpMul, zPlusW, oneHalf->deepCopy());
913 
914     // Create the assignment "gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5"
915     TIntermTyped *positionZLHS = positionZ->deepCopy();
916     TIntermBinary *assignment  = new TIntermBinary(TOperator::EOpAssign, positionZLHS, halfZPlusW);
917 
918     // Apply depth correction if needed
919     TIntermBlock *block = new TIntermBlock;
920     block->appendStatement(assignment);
921     TIntermIfElse *ifCall = new TIntermIfElse(driverUniforms->getTransformDepth(), block, nullptr);
922 
923     // Append the assignment as a statement at the end of the shader.
924     return RunAtTheEndOfShader(this, root, ifCall, &getSymbolTable());
925 }
926 
metalShaderTypeFromGLSL(sh::GLenum shaderType)927 static inline MetalShaderType metalShaderTypeFromGLSL(sh::GLenum shaderType)
928 {
929     switch (shaderType)
930     {
931         case GL_VERTEX_SHADER:
932             return MetalShaderType::Vertex;
933         case GL_FRAGMENT_SHADER:
934             return MetalShaderType::Fragment;
935         case GL_COMPUTE_SHADER:
936             ASSERT(0 && "compute shaders not currently supported");
937             return MetalShaderType::Compute;
938         default:
939             ASSERT(0 && "Invalid shader type.");
940             return MetalShaderType::None;
941     }
942 }
943 
translateImpl(TInfoSinkBase & sink,TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics *,SpecConst * specConst,DriverUniformMetal * driverUniforms)944 bool TranslatorMSL::translateImpl(TInfoSinkBase &sink,
945                                   TIntermBlock *root,
946                                   const ShCompileOptions &compileOptions,
947                                   PerformanceDiagnostics * /*perfDiagnostics*/,
948                                   SpecConst *specConst,
949                                   DriverUniformMetal *driverUniforms)
950 {
951     TSymbolTable &symbolTable = getSymbolTable();
952     IdGen idGen;
953     ProgramPreludeConfig ppc(metalShaderTypeFromGLSL(getShaderType()));
954     ppc.usesDerivatives = usesDerivatives();
955 
956     if (!WrapMain(*this, idGen, *root))
957     {
958         return false;
959     }
960 
961     // Remove declarations of inactive shader interface variables so glslang wrapper doesn't need to
962     // replace them.  Note: this is done before extracting samplers from structs, as removing such
963     // inactive samplers is not yet supported.  Note also that currently, CollectVariables marks
964     // every field of an active uniform that's of struct type as active, i.e. no extracted sampler
965     // is inactive.
966     if (!RemoveInactiveInterfaceVariables(this, root, &getSymbolTable(), getAttributes(),
967                                           getInputVaryings(), getOutputVariables(), getUniforms(),
968                                           getInterfaceBlocks(), false))
969     {
970         return false;
971     }
972 
973     // Write out default uniforms into a uniform block assigned to a specific set/binding.
974     int aggregateTypesUsedForUniforms = 0;
975     int atomicCounterCount            = 0;
976     for (const auto &uniform : getUniforms())
977     {
978         if (uniform.isStruct() || uniform.isArrayOfArrays())
979         {
980             ++aggregateTypesUsedForUniforms;
981         }
982 
983         if (uniform.active && gl::IsAtomicCounterType(uniform.type))
984         {
985             ++atomicCounterCount;
986         }
987     }
988 
989     // If there are any function calls that take array-of-array of opaque uniform parameters, or
990     // other opaque uniforms that need special handling in Vulkan, such as atomic counters,
991     // monomorphize the functions by removing said parameters and replacing them in the function
992     // body with the call arguments.
993     //
994     // This has a few benefits:
995     //
996     // - It dramatically simplifies future transformations w.r.t to samplers in structs, array of
997     //   arrays of opaque types, atomic counters etc.
998     // - Avoids the need for shader*ArrayDynamicIndexing Vulkan features.
999     UnsupportedFunctionArgsBitSet args{UnsupportedFunctionArgs::StructContainingSamplers,
1000                                        UnsupportedFunctionArgs::ArrayOfArrayOfSamplerOrImage,
1001                                        UnsupportedFunctionArgs::AtomicCounter,
1002                                        UnsupportedFunctionArgs::Image};
1003     if (!MonomorphizeUnsupportedFunctions(this, root, &getSymbolTable(), args))
1004     {
1005         return false;
1006     }
1007 
1008     if (aggregateTypesUsedForUniforms > 0)
1009     {
1010         int removedUniformsCount;
1011         if (!RewriteStructSamplers(this, root, &getSymbolTable(), &removedUniformsCount))
1012         {
1013             return false;
1014         }
1015     }
1016 
1017     // Replace array of array of opaque uniforms with a flattened array.  This is run after
1018     // MonomorphizeUnsupportedFunctions and RewriteStructSamplers so that it's not possible for an
1019     // array of array of opaque type to be partially subscripted and passed to a function.
1020     if (!RewriteArrayOfArrayOfOpaqueUniforms(this, root, &getSymbolTable()))
1021     {
1022         return false;
1023     }
1024 
1025     if (getShaderVersion() >= 300 ||
1026         IsExtensionEnabled(getExtensionBehavior(), TExtension::EXT_shader_texture_lod))
1027     {
1028         if (compileOptions.preTransformTextureCubeGradDerivatives)
1029         {
1030             if (!PreTransformTextureCubeGradDerivatives(this, root, &symbolTable,
1031                                                         getShaderVersion()))
1032             {
1033                 return false;
1034             }
1035         }
1036     }
1037 
1038     if (getShaderType() == GL_COMPUTE_SHADER)
1039     {
1040         driverUniforms->addComputeDriverUniformsToShader(root, &getSymbolTable());
1041     }
1042     else
1043     {
1044         driverUniforms->addGraphicsDriverUniformsToShader(root, &getSymbolTable());
1045     }
1046 
1047     if (atomicCounterCount > 0)
1048     {
1049         const TIntermTyped *acbBufferOffsets = driverUniforms->getAcbBufferOffsets();
1050         if (!RewriteAtomicCounters(this, root, &symbolTable, acbBufferOffsets, nullptr))
1051         {
1052             return false;
1053         }
1054     }
1055     else if (getShaderVersion() >= 310)
1056     {
1057         // Vulkan doesn't support Atomic Storage as a Storage Class, but we've seen
1058         // cases where builtins are using it even with no active atomic counters.
1059         // This pass simply removes those builtins in that scenario.
1060         if (!RemoveAtomicCounterBuiltins(this, root))
1061         {
1062             return false;
1063         }
1064     }
1065 
1066     if (getShaderType() != GL_COMPUTE_SHADER)
1067     {
1068         if (!ReplaceGLDepthRangeWithDriverUniform(this, root, driverUniforms, &getSymbolTable()))
1069         {
1070             return false;
1071         }
1072     }
1073 
1074     {
1075         bool usesInstanceId = false;
1076         bool usesVertexId   = false;
1077         for (const ShaderVariable &var : mAttributes)
1078         {
1079             if (var.isBuiltIn())
1080             {
1081                 if (var.name == "gl_InstanceID")
1082                 {
1083                     usesInstanceId = true;
1084                 }
1085                 if (var.name == "gl_VertexID")
1086                 {
1087                     usesVertexId = true;
1088                 }
1089             }
1090         }
1091 
1092         if (usesInstanceId)
1093         {
1094             root->insertChildNodes(
1095                 FindMainIndex(root),
1096                 TIntermSequence{new TIntermDeclaration{BuiltInVariable::gl_InstanceID()}});
1097         }
1098         if (usesVertexId)
1099         {
1100             AddBuiltInDeclaration(*root, symbolTable, *BuiltInVariable::gl_VertexID());
1101         }
1102     }
1103     SymbolEnv symbolEnv(*this, *root);
1104 
1105     bool usesSampleMask = false;
1106     if (getShaderType() == GL_FRAGMENT_SHADER)
1107     {
1108         bool usesPointCoord     = false;
1109         bool usesFragCoord      = false;
1110         bool usesFrontFacing    = false;
1111         bool usesSampleID       = false;
1112         bool usesSamplePosition = false;
1113         bool usesSampleMaskIn   = false;
1114         for (const ShaderVariable &inputVarying : mInputVaryings)
1115         {
1116             if (inputVarying.isBuiltIn())
1117             {
1118                 if (inputVarying.name == "gl_PointCoord")
1119                 {
1120                     usesPointCoord = true;
1121                 }
1122                 else if (inputVarying.name == "gl_FragCoord")
1123                 {
1124                     usesFragCoord = true;
1125                 }
1126                 else if (inputVarying.name == "gl_FrontFacing")
1127                 {
1128                     usesFrontFacing = true;
1129                 }
1130                 else if (inputVarying.name == "gl_SampleID")
1131                 {
1132                     usesSampleID = true;
1133                 }
1134                 else if (inputVarying.name == "gl_SamplePosition")
1135                 {
1136                     usesSampleID       = true;
1137                     usesSamplePosition = true;
1138                 }
1139                 else if (inputVarying.name == "gl_SampleMaskIn")
1140                 {
1141                     usesSampleMaskIn = true;
1142                 }
1143             }
1144         }
1145 
1146         bool usesFragColor             = false;
1147         bool usesFragData              = false;
1148         bool usesFragDepth             = false;
1149         bool usesFragDepthEXT          = false;
1150         bool usesSecondaryFragColorEXT = false;
1151         bool usesSecondaryFragDataEXT  = false;
1152         for (const ShaderVariable &outputVarying : mOutputVariables)
1153         {
1154             if (outputVarying.isBuiltIn())
1155             {
1156                 if (outputVarying.name == "gl_FragColor")
1157                 {
1158                     usesFragColor = true;
1159                 }
1160                 else if (outputVarying.name == "gl_FragData")
1161                 {
1162                     usesFragData = true;
1163                 }
1164                 else if (outputVarying.name == "gl_FragDepth")
1165                 {
1166                     usesFragDepth = true;
1167                 }
1168                 else if (outputVarying.name == "gl_FragDepthEXT")
1169                 {
1170                     usesFragDepthEXT = true;
1171                 }
1172                 else if (outputVarying.name == "gl_SecondaryFragColorEXT")
1173                 {
1174                     usesSecondaryFragColorEXT = true;
1175                 }
1176                 else if (outputVarying.name == "gl_SecondaryFragDataEXT")
1177                 {
1178                     usesSecondaryFragDataEXT = true;
1179                 }
1180                 else if (outputVarying.name == "gl_SampleMask")
1181                 {
1182                     usesSampleMask = true;
1183                 }
1184             }
1185         }
1186 
1187         // A shader may assign values to either the set of gl_FragColor and gl_SecondaryFragColorEXT
1188         // or the set of gl_FragData and gl_SecondaryFragDataEXT, but not both.
1189         ASSERT((!usesFragColor && !usesSecondaryFragColorEXT) ||
1190                (!usesFragData && !usesSecondaryFragDataEXT));
1191 
1192         if (usesFragColor)
1193         {
1194             AddFragColorDeclaration(*root, symbolTable, *BuiltInVariable::gl_FragColor());
1195         }
1196         else if (usesFragData)
1197         {
1198             if (!AddFragDataDeclaration(*this, *root, usesSecondaryFragDataEXT, false))
1199             {
1200                 return false;
1201             }
1202         }
1203 
1204         if (usesFragDepth)
1205         {
1206             AddBuiltInDeclaration(*root, symbolTable, *BuiltInVariable::gl_FragDepth());
1207         }
1208         else if (usesFragDepthEXT)
1209         {
1210             AddFragDepthEXTDeclaration(*this, *root, symbolTable);
1211         }
1212 
1213         if (usesSecondaryFragColorEXT)
1214         {
1215             AddFragColorDeclaration(*root, symbolTable,
1216                                     *BuiltInVariable::gl_SecondaryFragColorEXT());
1217         }
1218         else if (usesSecondaryFragDataEXT)
1219         {
1220             if (!AddFragDataDeclaration(*this, *root, usesSecondaryFragDataEXT, true))
1221             {
1222                 return false;
1223             }
1224         }
1225 
1226         bool usesSampleInterpolation = false;
1227         bool usesSampleInterpolant   = false;
1228         if ((getShaderVersion() >= 320 ||
1229              IsExtensionEnabled(getExtensionBehavior(),
1230                                 TExtension::OES_shader_multisample_interpolation)) &&
1231             !RewriteInterpolants(*this, *root, symbolTable, driverUniforms,
1232                                  &usesSampleInterpolation, &usesSampleInterpolant))
1233         {
1234             return false;
1235         }
1236 
1237         if (usesSampleID || (usesSampleMaskIn && usesSampleInterpolation) || usesSampleInterpolant)
1238         {
1239             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_SampleID());
1240         }
1241 
1242         if (usesSamplePosition)
1243         {
1244             if (!AddSamplePositionDeclaration(*this, *root, symbolTable, driverUniforms))
1245             {
1246                 return false;
1247             }
1248         }
1249 
1250         if (usesSampleMaskIn)
1251         {
1252             if (!AddSampleMaskInDeclaration(*this, *root, symbolTable, driverUniforms,
1253                                             usesSampleID || usesSampleInterpolation))
1254             {
1255                 return false;
1256             }
1257         }
1258 
1259         if (usesPointCoord)
1260         {
1261             TIntermTyped *flipNegXY =
1262                 driverUniforms->getNegFlipXY(&getSymbolTable(), DriverUniformFlip::Fragment);
1263             TIntermConstantUnion *pivot = CreateFloatNode(0.5f, EbpMedium);
1264             if (!FlipBuiltinVariable(this, root, GetMainSequence(root), flipNegXY,
1265                                      &getSymbolTable(), BuiltInVariable::gl_PointCoord(),
1266                                      kFlippedPointCoordName, pivot))
1267             {
1268                 return false;
1269             }
1270             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_PointCoord());
1271         }
1272 
1273         if (usesFragCoord || compileOptions.emulateAlphaToCoverage ||
1274             compileOptions.metal.generateShareableShaders)
1275         {
1276             if (!InsertFragCoordCorrection(this, compileOptions, root, GetMainSequence(root),
1277                                            &getSymbolTable(), driverUniforms))
1278             {
1279                 return false;
1280             }
1281             const TVariable *fragCoord = static_cast<const TVariable *>(
1282                 getSymbolTable().findBuiltIn(ImmutableString("gl_FragCoord"), getShaderVersion()));
1283             DeclareRightBeforeMain(*root, *fragCoord);
1284         }
1285 
1286         if (!RewriteDfdy(this, root, &getSymbolTable(), getShaderVersion(), specConst,
1287                          driverUniforms))
1288         {
1289             return false;
1290         }
1291 
1292         if (getClipDistanceArraySize())
1293         {
1294             if (!EmulateClipDistanceVaryings(this, root, &getSymbolTable(), getShaderType()))
1295             {
1296                 return false;
1297             }
1298         }
1299 
1300         if (usesFrontFacing)
1301         {
1302             DeclareRightBeforeMain(*root, *BuiltInVariable::gl_FrontFacing());
1303         }
1304 
1305         bool usesNumSamples = false;
1306         for (const ShaderVariable &uniform : mUniforms)
1307         {
1308             if (uniform.name == "gl_NumSamples")
1309             {
1310                 usesNumSamples = true;
1311                 break;
1312             }
1313         }
1314 
1315         if (usesNumSamples)
1316         {
1317             if (!AddNumSamplesDeclaration(*this, *root, symbolTable))
1318             {
1319                 return false;
1320             }
1321         }
1322     }
1323     else if (getShaderType() == GL_VERTEX_SHADER)
1324     {
1325         DeclareRightBeforeMain(*root, *BuiltInVariable::gl_Position());
1326 
1327         if (FindSymbolNode(root, BuiltInVariable::gl_PointSize()->name()))
1328         {
1329             const TVariable *pointSize = static_cast<const TVariable *>(
1330                 getSymbolTable().findBuiltIn(ImmutableString("gl_PointSize"), getShaderVersion()));
1331             DeclareRightBeforeMain(*root, *pointSize);
1332         }
1333 
1334         // Append a macro for transform feedback substitution prior to modifying depth.
1335         if (!AppendVertexShaderTransformFeedbackOutputToMain(*this, symbolEnv, *root))
1336         {
1337             return false;
1338         }
1339 
1340         if (getClipDistanceArraySize())
1341         {
1342             if (!ZeroDisabledClipDistanceAssignments(this, root, &getSymbolTable(), getShaderType(),
1343                                                      driverUniforms->getClipDistancesEnabled()))
1344             {
1345                 return false;
1346             }
1347 
1348             if (IsExtensionEnabled(getExtensionBehavior(), TExtension::ANGLE_clip_cull_distance) &&
1349                 !EmulateClipDistanceVaryings(this, root, &getSymbolTable(), getShaderType()))
1350             {
1351                 return false;
1352             }
1353         }
1354 
1355         if (!transformDepthBeforeCorrection(root, driverUniforms))
1356         {
1357             return false;
1358         }
1359 
1360         if (!appendVertexShaderDepthCorrectionToMain(root, driverUniforms))
1361         {
1362             return false;
1363         }
1364     }
1365 
1366     if (getShaderType() == GL_VERTEX_SHADER)
1367     {
1368         TIntermTyped *flipNegY =
1369             driverUniforms->getFlipXY(&getSymbolTable(), DriverUniformFlip::PreFragment);
1370         flipNegY = (new TIntermSwizzle(flipNegY, {1}))->fold(nullptr);
1371 
1372         if (!AppendVertexShaderPositionYCorrectionToMain(this, root, &getSymbolTable(), flipNegY))
1373         {
1374             return false;
1375         }
1376         if (!insertRasterizationDiscardLogic(*root))
1377         {
1378             return false;
1379         }
1380     }
1381     else if (getShaderType() == GL_FRAGMENT_SHADER)
1382     {
1383         mValidateASTOptions.validateVariableReferences = false;
1384         if (!AddSampleMaskDeclaration(*this, *root, symbolTable, driverUniforms,
1385                                       compileOptions.emulateAlphaToCoverage ||
1386                                           compileOptions.metal.generateShareableShaders,
1387                                       usesSampleMask))
1388         {
1389             return false;
1390         }
1391     }
1392 
1393     if (!validateAST(root))
1394     {
1395         return false;
1396     }
1397 
1398     // This is the largest size required to pass all the tests in
1399     // (dEQP-GLES3.functional.shaders.large_constant_arrays)
1400     // This value could in principle be smaller.
1401     const size_t hoistThresholdSize = 256;
1402     if (!HoistConstants(*this, *root, idGen, hoistThresholdSize))
1403     {
1404         return false;
1405     }
1406 
1407     if (!ConvertUnsupportedConstructorsToFunctionCalls(*this, *root))
1408     {
1409         return false;
1410     }
1411 
1412     const bool needsExplicitBoolCasts = compileOptions.addExplicitBoolCasts;
1413     if (!AddExplicitTypeCasts(*this, *root, symbolEnv, needsExplicitBoolCasts))
1414     {
1415         return false;
1416     }
1417 
1418     if (!SeparateCompoundExpressions(*this, symbolEnv, idGen, *root))
1419     {
1420         return false;
1421     }
1422 
1423     if (!ReduceInterfaceBlocks(*this, *root, idGen))
1424     {
1425         return false;
1426     }
1427 
1428     // The RewritePipelines phase leaves the tree in an inconsistent state by inserting
1429     // references to structures like "ANGLE_TextureEnv<metal::texture2d<float>>" which are
1430     // defined in text (in ProgramPrelude), outside of the knowledge of the AST.
1431     mValidateASTOptions.validateStructUsage = false;
1432     // The RewritePipelines phase also generates incoming arguments to synthesized
1433     // functions that are missing qualifiers - for example, angleUniforms isn't marked
1434     // as an incoming argument.
1435     mValidateASTOptions.validateQualifiers = false;
1436 
1437     PipelineStructs pipelineStructs;
1438     if (!RewritePipelines(*this, *root, getInputVaryings(), getOutputVaryings(), idGen,
1439                           *driverUniforms, symbolEnv, pipelineStructs))
1440     {
1441         return false;
1442     }
1443     if (getShaderType() == GL_VERTEX_SHADER)
1444     {
1445         // This has to happen after RewritePipelines.
1446         if (!IntroduceVertexAndInstanceIndex(*this, *root))
1447         {
1448             return false;
1449         }
1450     }
1451 
1452     if (!RewriteCaseDeclarations(*this, *root))
1453     {
1454         return false;
1455     }
1456 
1457     if (!RewriteUnaddressableReferences(*this, *root, symbolEnv))
1458     {
1459         return false;
1460     }
1461 
1462     if (!RewriteOutArgs(*this, *root, symbolEnv))
1463     {
1464         return false;
1465     }
1466     if (!FixTypeConstructors(*this, symbolEnv, *root))
1467     {
1468         return false;
1469     }
1470     if (!ToposortStructs(*this, symbolEnv, *root, ppc))
1471     {
1472         return false;
1473     }
1474     if (!EmitMetal(*this, *root, idGen, pipelineStructs, symbolEnv, ppc, compileOptions))
1475     {
1476         return false;
1477     }
1478 
1479     ASSERT(validateAST(root));
1480 
1481     return true;
1482 }
1483 
translate(TIntermBlock * root,const ShCompileOptions & compileOptions,PerformanceDiagnostics * perfDiagnostics)1484 bool TranslatorMSL::translate(TIntermBlock *root,
1485                               const ShCompileOptions &compileOptions,
1486                               PerformanceDiagnostics *perfDiagnostics)
1487 {
1488     if (!root)
1489     {
1490         return false;
1491     }
1492 
1493     // TODO: refactor the code in TranslatorMSL to not issue raw function calls.
1494     // http://anglebug.com/42264589#comment3
1495     mValidateASTOptions.validateNoRawFunctionCalls = false;
1496     // A validation error is generated in this backend due to bool uniforms.
1497     mValidateASTOptions.validatePrecision = false;
1498 
1499     TInfoSinkBase &sink = getInfoSink().obj;
1500     SpecConst specConst(&getSymbolTable(), compileOptions, getShaderType());
1501     DriverUniformMetal driverUniforms(DriverUniformMode::Structure);
1502     if (!translateImpl(sink, root, compileOptions, perfDiagnostics, &specConst, &driverUniforms))
1503     {
1504         return false;
1505     }
1506 
1507     return true;
1508 }
shouldFlattenPragmaStdglInvariantAll()1509 bool TranslatorMSL::shouldFlattenPragmaStdglInvariantAll()
1510 {
1511     // Not neccesary for MSL transformation.
1512     return false;
1513 }
1514 
1515 }  // namespace sh
1516