xref: /aosp_15_r20/external/angle/src/compiler/translator/msl/SymbolEnv.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <algorithm>
#include <limits>
#include <string>

#include "compiler/translator/ImmutableStringBuilder.h"
#include "compiler/translator/IntermRebuild.h"
#include "compiler/translator/msl/AstHelpers.h"
#include "compiler/translator/msl/SymbolEnv.h"
#include "compiler/translator/util.h"
15 
16 using namespace sh;
17 
18 ////////////////////////////////////////////////////////////////////////////////
19 
20 constexpr AddressSpace kAddressSpaces[] = {
21     AddressSpace::Constant,
22     AddressSpace::Device,
23     AddressSpace::Thread,
24 };
25 
toString(AddressSpace space)26 char const *sh::toString(AddressSpace space)
27 {
28     switch (space)
29     {
30         case AddressSpace::Constant:
31             return "constant";
32         case AddressSpace::Device:
33             return "device";
34         case AddressSpace::Thread:
35             return "thread";
36     }
37 }
38 
39 ////////////////////////////////////////////////////////////////////////////////
40 
41 using NameToStruct = std::map<Name, const TStructure *>;
42 
43 class StructFinder : TIntermRebuild
44 {
45     NameToStruct nameToStruct;
46 
StructFinder(TCompiler & compiler)47     StructFinder(TCompiler &compiler) : TIntermRebuild(compiler, true, false) {}
48 
visitDeclarationPre(TIntermDeclaration & node)49     PreResult visitDeclarationPre(TIntermDeclaration &node) override
50     {
51         Declaration decl     = ViewDeclaration(node);
52         const TVariable &var = decl.symbol.variable();
53         const TType &type    = var.getType();
54 
55         if (var.symbolType() == SymbolType::Empty && type.isStructSpecifier())
56         {
57             const TStructure *s = type.getStruct();
58             ASSERT(s);
59             const Name name(*s);
60             const TStructure *&z = nameToStruct[name];
61             ASSERT(!z);
62             z = s;
63         }
64 
65         return node;
66     }
67 
visitFunctionDefinitionPre(TIntermFunctionDefinition & node)68     PreResult visitFunctionDefinitionPre(TIntermFunctionDefinition &node) override
69     {
70         return {node, VisitBits::Neither};
71     }
72 
73   public:
FindStructs(TCompiler & compiler,TIntermBlock & root)74     static NameToStruct FindStructs(TCompiler &compiler, TIntermBlock &root)
75     {
76         StructFinder finder(compiler);
77         if (!finder.rebuildRoot(root))
78         {
79             UNREACHABLE();
80         }
81         return std::move(finder.nameToStruct);
82     }
83 };
84 
85 ////////////////////////////////////////////////////////////////////////////////
86 
TemplateArg(bool value)87 TemplateArg::TemplateArg(bool value) : mKind(Kind::Bool), mValue(value) {}
88 
TemplateArg(int value)89 TemplateArg::TemplateArg(int value) : mKind(Kind::Int), mValue(value) {}
90 
TemplateArg(unsigned value)91 TemplateArg::TemplateArg(unsigned value) : mKind(Kind::UInt), mValue(value) {}
92 
TemplateArg(const TType & value)93 TemplateArg::TemplateArg(const TType &value) : mKind(Kind::Type), mValue(value) {}
94 
operator ==(const TemplateArg & other) const95 bool TemplateArg::operator==(const TemplateArg &other) const
96 {
97     if (mKind != other.mKind)
98     {
99         return false;
100     }
101 
102     switch (mKind)
103     {
104         case Kind::Bool:
105             return mValue.b == other.mValue.b;
106         case Kind::Int:
107             return mValue.i == other.mValue.i;
108         case Kind::UInt:
109             return mValue.u == other.mValue.u;
110         case Kind::Type:
111             return *mValue.t == *other.mValue.t;
112     }
113 }
114 
operator <(const TemplateArg & other) const115 bool TemplateArg::operator<(const TemplateArg &other) const
116 {
117     if (mKind < other.mKind)
118     {
119         return true;
120     }
121 
122     if (mKind > other.mKind)
123     {
124         return false;
125     }
126 
127     switch (mKind)
128     {
129         case Kind::Bool:
130             return mValue.b < other.mValue.b;
131         case Kind::Int:
132             return mValue.i < other.mValue.i;
133         case Kind::UInt:
134             return mValue.u < other.mValue.u;
135         case Kind::Type:
136             return *mValue.t < *other.mValue.t;
137     }
138 }
139 
140 ////////////////////////////////////////////////////////////////////////////////
141 
operator ==(const TemplateName & other) const142 bool SymbolEnv::TemplateName::operator==(const TemplateName &other) const
143 {
144     return baseName == other.baseName && templateArgs == other.templateArgs;
145 }
146 
operator <(const TemplateName & other) const147 bool SymbolEnv::TemplateName::operator<(const TemplateName &other) const
148 {
149     if (baseName < other.baseName)
150     {
151         return true;
152     }
153     if (other.baseName < baseName)
154     {
155         return false;
156     }
157     return templateArgs < other.templateArgs;
158 }
159 
empty() const160 bool SymbolEnv::TemplateName::empty() const
161 {
162     return baseName.empty() && templateArgs.empty();
163 }
164 
clear()165 void SymbolEnv::TemplateName::clear()
166 {
167     baseName = Name();
168     templateArgs.clear();
169 }
170 
fullName(std::string & buffer) const171 Name SymbolEnv::TemplateName::fullName(std::string &buffer) const
172 {
173     ASSERT(buffer.empty());
174 
175     if (templateArgs.empty())
176     {
177         return baseName;
178     }
179 
180     static constexpr size_t n = std::max({
181         std::numeric_limits<unsigned>::digits10,  //
182         std::numeric_limits<int>::digits10,       //
183         5,                                        // max_length("true", "false")
184     });
185 
186     buffer.reserve(baseName.rawName().length() + (n + 2) * templateArgs.size() + 1);
187     buffer += baseName.rawName().data();
188 
189     if (!templateArgs.empty())
190     {
191         buffer += "<";
192 
193         bool first = true;
194         char argBuffer[n + 1];
195         for (const TemplateArg &arg : templateArgs)
196         {
197             if (first)
198             {
199                 first = false;
200             }
201             else
202             {
203                 buffer += ", ";
204             }
205 
206             const TemplateArg::Value value = arg.value();
207             const TemplateArg::Kind kind   = arg.kind();
208             switch (kind)
209             {
210                 case TemplateArg::Kind::Bool:
211                     if (value.b)
212                     {
213                         buffer += "true";
214                     }
215                     else
216                     {
217                         buffer += "false";
218                     }
219                     break;
220 
221                 case TemplateArg::Kind::Int:
222                     snprintf(argBuffer, sizeof(argBuffer), "%i", value.i);
223                     buffer += argBuffer;
224                     break;
225 
226                 case TemplateArg::Kind::UInt:
227                     snprintf(argBuffer, sizeof(argBuffer), "%u", value.u);
228                     buffer += argBuffer;
229                     break;
230 
231                 case TemplateArg::Kind::Type:
232                 {
233                     const TType &type = *value.t;
234                     if (const TStructure *s = type.getStruct())
235                     {
236                         buffer += s->name().data();
237                     }
238                     else if (HasScalarBasicType(type))
239                     {
240                         ASSERT(!type.isArray());  // TODO
241                         buffer += type.getBasicString();
242                         if (type.isVector())
243                         {
244                             snprintf(argBuffer, sizeof(argBuffer), "%u", type.getNominalSize());
245                             buffer += argBuffer;
246                         }
247                         else if (type.isMatrix())
248                         {
249                             snprintf(argBuffer, sizeof(argBuffer), "%u", type.getCols());
250                             buffer += argBuffer;
251                             buffer += "x";
252                             snprintf(argBuffer, sizeof(argBuffer), "%u", type.getRows());
253                             buffer += argBuffer;
254                         }
255                     }
256                 }
257                 break;
258             }
259         }
260 
261         buffer += ">";
262     }
263 
264     const ImmutableString name(buffer);
265     buffer.clear();
266 
267     return Name(name, baseName.symbolType());
268 }
269 
assign(const Name & name,size_t argCount,const TemplateArg * args)270 void SymbolEnv::TemplateName::assign(const Name &name, size_t argCount, const TemplateArg *args)
271 {
272     baseName = name;
273     templateArgs.clear();
274     for (size_t i = 0; i < argCount; ++i)
275     {
276         templateArgs.push_back(args[i]);
277     }
278 }
279 
280 ////////////////////////////////////////////////////////////////////////////////
281 
SymbolEnv(TCompiler & compiler,TIntermBlock & root)282 SymbolEnv::SymbolEnv(TCompiler &compiler, TIntermBlock &root)
283     : mSymbolTable(compiler.getSymbolTable()),
284       mNameToStruct(StructFinder::FindStructs(compiler, root))
285 {}
286 
// Returns the struct recorded under the same name as `s`, or `s` itself when no struct with
// that name was recorded.
const TStructure &SymbolEnv::remap(const TStructure &s) const
{
    const auto found = mNameToStruct.find(Name(s));
    return found == mNameToStruct.end() ? s : *found->second;
}
298 
// Pointer overload of remap: null stays null; otherwise forwards to the reference overload.
const TStructure *SymbolEnv::remap(const TStructure *s) const
{
    return s ? &remap(*s) : nullptr;
}
307 
getFunctionOverloadImpl()308 const TFunction &SymbolEnv::getFunctionOverloadImpl()
309 {
310     ASSERT(!mReusableSigBuffer.empty());
311 
312     SigToFunc &sigToFunc = mOverloads[mReusableTemplateNameBuffer];
313     TFunction *&func     = sigToFunc[mReusableSigBuffer];
314 
315     if (!func)
316     {
317         const TType &returnType = mReusableSigBuffer.back();
318         mReusableSigBuffer.pop_back();
319 
320         const Name name = mReusableTemplateNameBuffer.fullName(mReusableStringBuffer);
321 
322         func = new TFunction(&mSymbolTable, name.rawName(), name.symbolType(), &returnType, false);
323         for (const TType &paramType : mReusableSigBuffer)
324         {
325             func->addParameter(
326                 new TVariable(&mSymbolTable, kEmptyImmutableString, &paramType, SymbolType::Empty));
327         }
328     }
329 
330     mReusableSigBuffer.clear();
331     mReusableTemplateNameBuffer.clear();
332 
333     return *func;
334 }
335 
getFunctionOverload(const Name & name,const TType & returnType,size_t paramCount,const TType ** paramTypes,size_t templateArgCount,const TemplateArg * templateArgs)336 const TFunction &SymbolEnv::getFunctionOverload(const Name &name,
337                                                 const TType &returnType,
338                                                 size_t paramCount,
339                                                 const TType **paramTypes,
340                                                 size_t templateArgCount,
341                                                 const TemplateArg *templateArgs)
342 {
343     ASSERT(mReusableSigBuffer.empty());
344     ASSERT(mReusableTemplateNameBuffer.empty());
345 
346     for (size_t i = 0; i < paramCount; ++i)
347     {
348         mReusableSigBuffer.push_back(*paramTypes[i]);
349     }
350     mReusableSigBuffer.push_back(returnType);
351     mReusableTemplateNameBuffer.assign(name, templateArgCount, templateArgs);
352     return getFunctionOverloadImpl();
353 }
354 
callFunctionOverload(const Name & name,const TType & returnType,TIntermSequence & args,size_t templateArgCount,const TemplateArg * templateArgs)355 TIntermAggregate &SymbolEnv::callFunctionOverload(const Name &name,
356                                                   const TType &returnType,
357                                                   TIntermSequence &args,
358                                                   size_t templateArgCount,
359                                                   const TemplateArg *templateArgs)
360 {
361     ASSERT(mReusableSigBuffer.empty());
362     ASSERT(mReusableTemplateNameBuffer.empty());
363 
364     for (TIntermNode *arg : args)
365     {
366         TIntermTyped *targ = arg->getAsTyped();
367         ASSERT(targ);
368         mReusableSigBuffer.push_back(targ->getType());
369     }
370     mReusableSigBuffer.push_back(returnType);
371     mReusableTemplateNameBuffer.assign(name, templateArgCount, templateArgs);
372     const TFunction &func = getFunctionOverloadImpl();
373     return *TIntermAggregate::CreateRawFunctionCall(func, &args);
374 }
375 
newStructure(const Name & name,TFieldList & fields)376 const TStructure &SymbolEnv::newStructure(const Name &name, TFieldList &fields)
377 {
378     ASSERT(name.symbolType() == SymbolType::AngleInternal);
379 
380     TStructure *&s = mAngleStructs[name.rawName()];
381     ASSERT(!s);
382     s = new TStructure(&mSymbolTable, name.rawName(), &fields, name.symbolType());
383     return *s;
384 }
385 
getTextureEnv(TBasicType samplerType)386 const TStructure &SymbolEnv::getTextureEnv(TBasicType samplerType)
387 {
388     ASSERT(IsSampler(samplerType));
389     const TStructure *&env = mTextureEnvs[samplerType];
390     if (env == nullptr)
391     {
392         auto *textureType = new TType(samplerType);
393         auto *texture =
394             new TField(textureType, ImmutableString("texture"), kNoSourceLoc, SymbolType::BuiltIn);
395         markAsPointer(*texture, AddressSpace::Thread);
396 
397         auto *sampler = new TField(new TType(&getSamplerStruct(), false),
398                                    ImmutableString("sampler"), kNoSourceLoc, SymbolType::BuiltIn);
399         markAsPointer(*sampler, AddressSpace::Thread);
400 
401         std::string envName;
402         envName += "TextureEnv<";
403         envName += GetTextureTypeName(samplerType).rawName().data();
404         envName += ">";
405 
406         env = &newStructure(Name(envName, SymbolType::AngleInternal),
407                             *new TFieldList{texture, sampler});
408     }
409     return *env;
410 }
411 
getSamplerStruct()412 const TStructure &SymbolEnv::getSamplerStruct()
413 {
414     if (!mSampler)
415     {
416         mSampler = new TStructure(&mSymbolTable, ImmutableString("metal::sampler"),
417                                   new TFieldList(), SymbolType::BuiltIn);
418     }
419     return *mSampler;
420 }
421 
markSpace(VarField x,AddressSpace space,std::unordered_map<VarField,AddressSpace> & map)422 void SymbolEnv::markSpace(VarField x,
423                           AddressSpace space,
424                           std::unordered_map<VarField, AddressSpace> &map)
425 {
426     // It is in principle permissible to have references to pointers or multiple pointers, but this
427     // is not required for now and would require code changes to get right.
428     ASSERT(!isPointer(x));
429     ASSERT(!isReference(x));
430 
431     map[x] = space;
432 }
433 
removeSpace(VarField x,std::unordered_map<VarField,AddressSpace> & map)434 void SymbolEnv::removeSpace(VarField x, std::unordered_map<VarField, AddressSpace> &map)
435 {
436     // It is in principle permissible to have references to pointers or multiple pointers, but this
437     // is not required for now and would require code changes to get right.
438     map.erase(x);
439 }
440 
isSpace(VarField x,const std::unordered_map<VarField,AddressSpace> & map) const441 const AddressSpace *SymbolEnv::isSpace(VarField x,
442                                        const std::unordered_map<VarField, AddressSpace> &map) const
443 {
444     const auto iter = map.find(x);
445     if (iter == map.end())
446     {
447         return nullptr;
448     }
449     const AddressSpace space = iter->second;
450     const auto index         = static_cast<std::underlying_type_t<AddressSpace>>(space);
451     return &kAddressSpaces[index];
452 }
453 
markAsPointer(VarField x,AddressSpace space)454 void SymbolEnv::markAsPointer(VarField x, AddressSpace space)
455 {
456     return markSpace(x, space, mPointers);
457 }
458 
removePointer(VarField x)459 void SymbolEnv::removePointer(VarField x)
460 {
461     return removeSpace(x, mPointers);
462 }
463 
markAsReference(VarField x,AddressSpace space)464 void SymbolEnv::markAsReference(VarField x, AddressSpace space)
465 {
466     return markSpace(x, space, mReferences);
467 }
468 
isPointer(VarField x) const469 const AddressSpace *SymbolEnv::isPointer(VarField x) const
470 {
471     return isSpace(x, mPointers);
472 }
473 
isReference(VarField x) const474 const AddressSpace *SymbolEnv::isReference(VarField x) const
475 {
476     return isSpace(x, mReferences);
477 }
478 
markAsPacked(const TField & field)479 void SymbolEnv::markAsPacked(const TField &field)
480 {
481     mPackedFields.insert(&field);
482 }
483 
isPacked(const TField & field) const484 bool SymbolEnv::isPacked(const TField &field) const
485 {
486     return mPackedFields.find(&field) != mPackedFields.end();
487 }
488 
markAsUBO(VarField x)489 void SymbolEnv::markAsUBO(VarField x)
490 {
491     mUboFields.insert(x);
492 }
493 
isUBO(VarField x) const494 bool SymbolEnv::isUBO(VarField x) const
495 {
496     return mUboFields.find(x) != mUboFields.end();
497 }
498 
GetTextureBasicType(TBasicType basicType)499 static TBasicType GetTextureBasicType(TBasicType basicType)
500 {
501     ASSERT(IsSampler(basicType));
502 
503     switch (basicType)
504     {
505         case EbtSampler2D:
506         case EbtSampler3D:
507         case EbtSamplerCube:
508         case EbtSampler2DArray:
509         case EbtSamplerExternalOES:
510         case EbtSamplerExternal2DY2YEXT:
511         case EbtSampler2DRect:
512         case EbtSampler2DMS:
513         case EbtSampler2DMSArray:
514         case EbtSamplerVideoWEBGL:
515         case EbtSampler2DShadow:
516         case EbtSamplerCubeShadow:
517         case EbtSampler2DArrayShadow:
518         case EbtSamplerBuffer:
519         case EbtSamplerCubeArray:
520         case EbtSamplerCubeArrayShadow:
521         case EbtSampler2DRectShadow:
522             return TBasicType::EbtFloat;
523 
524         case EbtISampler2D:
525         case EbtISampler3D:
526         case EbtISamplerCube:
527         case EbtISampler2DArray:
528         case EbtISampler2DMS:
529         case EbtISampler2DMSArray:
530         case EbtISampler2DRect:
531         case EbtISamplerBuffer:
532         case EbtISamplerCubeArray:
533             return TBasicType::EbtInt;
534 
535         case EbtUSampler2D:
536         case EbtUSampler3D:
537         case EbtUSamplerCube:
538         case EbtUSampler2DArray:
539         case EbtUSampler2DMS:
540         case EbtUSampler2DMSArray:
541         case EbtUSampler2DRect:
542         case EbtUSamplerBuffer:
543         case EbtUSamplerCubeArray:
544             return TBasicType::EbtUInt;
545 
546         default:
547             UNREACHABLE();
548             return TBasicType::EbtVoid;
549     }
550 }
551 
GetTextureTypeName(TBasicType samplerType)552 Name sh::GetTextureTypeName(TBasicType samplerType)
553 {
554     ASSERT(IsSampler(samplerType));
555 
556     const TBasicType textureType = GetTextureBasicType(samplerType);
557     const char *name;
558 
559 #define HANDLE_TEXTURE_NAME(baseName)                   \
560     do                                                  \
561     {                                                   \
562         switch (textureType)                            \
563         {                                               \
564             case TBasicType::EbtFloat:                  \
565                 name = "metal::" baseName "<float>";    \
566                 break;                                  \
567             case TBasicType::EbtInt:                    \
568                 name = "metal::" baseName "<int>";      \
569                 break;                                  \
570             case TBasicType::EbtUInt:                   \
571                 name = "metal::" baseName "<uint32_t>"; \
572                 break;                                  \
573             default:                                    \
574                 UNREACHABLE();                          \
575                 name = nullptr;                         \
576                 break;                                  \
577         }                                               \
578     } while (false)
579 
580     switch (samplerType)
581     {
582         // Buffer textures
583         case EbtSamplerBuffer:
584         case EbtISamplerBuffer:
585         case EbtUSamplerBuffer:
586             HANDLE_TEXTURE_NAME("texture_buffer");
587             break;
588 
589         // 2d textures
590         case EbtSampler2D:
591         case EbtISampler2D:
592         case EbtUSampler2D:
593         case EbtSampler2DRect:
594         case EbtUSampler2DRect:
595         case EbtISampler2DRect:
596             HANDLE_TEXTURE_NAME("texture2d");
597             break;
598 
599         // 3d textures
600         case EbtSampler3D:
601         case EbtISampler3D:
602         case EbtUSampler3D:
603             HANDLE_TEXTURE_NAME("texture3d");
604             break;
605 
606         // Cube textures
607         case EbtSamplerCube:
608         case EbtISamplerCube:
609         case EbtUSamplerCube:
610             HANDLE_TEXTURE_NAME("texturecube");
611             break;
612 
613         // 2d array textures
614         case EbtSampler2DArray:
615         case EbtUSampler2DArray:
616         case EbtISampler2DArray:
617             HANDLE_TEXTURE_NAME("texture2d_array");
618             break;
619 
620         case EbtSampler2DMS:
621         case EbtISampler2DMS:
622         case EbtUSampler2DMS:
623             HANDLE_TEXTURE_NAME("texture2d_ms");
624             break;
625 
626         case EbtSampler2DMSArray:
627         case EbtISampler2DMSArray:
628         case EbtUSampler2DMSArray:
629             HANDLE_TEXTURE_NAME("texture2d_ms_array");
630             break;
631 
632         // cube array
633         case EbtSamplerCubeArray:
634         case EbtISamplerCubeArray:
635         case EbtUSamplerCubeArray:
636             HANDLE_TEXTURE_NAME("texturecube_array");
637             break;
638 
639         // Shadow
640         case EbtSampler2DRectShadow:
641         case EbtSampler2DShadow:
642             HANDLE_TEXTURE_NAME("depth2d");
643             break;
644 
645         case EbtSamplerCubeShadow:
646             HANDLE_TEXTURE_NAME("depthcube");
647             break;
648 
649         case EbtSampler2DArrayShadow:
650             HANDLE_TEXTURE_NAME("depth2d_array");
651             break;
652 
653         case EbtSamplerCubeArrayShadow:
654             HANDLE_TEXTURE_NAME("depthcube_array");
655             break;
656 
657         // Extentions
658         case EbtSamplerExternalOES:       // Only valid if OES_EGL_image_external exists:
659         case EbtSamplerExternal2DY2YEXT:  // Only valid if GL_EXT_YUV_target exists:
660         case EbtSamplerVideoWEBGL:
661             UNIMPLEMENTED();
662             HANDLE_TEXTURE_NAME("TODO");
663             break;
664 
665         default:
666             UNREACHABLE();
667             name = nullptr;
668             break;
669     }
670 
671 #undef HANDLE_TEXTURE_NAME
672 
673     return Name(name, SymbolType::BuiltIn);
674 }
675