xref: /aosp_15_r20/external/angle/src/compiler/translator/tree_ops/msl/AddExplicitTypeCasts.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include "compiler/translator/tree_ops/msl/AddExplicitTypeCasts.h"
8 #include "compiler/translator/IntermRebuild.h"
9 #include "compiler/translator/msl/AstHelpers.h"
10 
11 using namespace sh;
12 
13 namespace
14 {
15 
16 class Rewriter : public TIntermRebuild
17 {
18     SymbolEnv &mSymbolEnv;
19     bool mNeedsExplicitBoolCasts = false;
20 
21   public:
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv,bool needsExplicitBoolCasts)22     Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv, bool needsExplicitBoolCasts)
23         : TIntermRebuild(compiler, false, true),
24           mSymbolEnv(symbolEnv),
25           mNeedsExplicitBoolCasts(needsExplicitBoolCasts)
26     {}
27 
visitAggregatePost(TIntermAggregate & callNode)28     PostResult visitAggregatePost(TIntermAggregate &callNode) override
29     {
30         const size_t argCount = callNode.getChildCount();
31         const TType &retType  = callNode.getType();
32 
33         if (callNode.isConstructor())
34         {
35             if (IsScalarBasicType(retType))
36             {
37                 if (argCount == 1)
38                 {
39                     TIntermTyped &arg   = GetArg(callNode, 0);
40                     const TType argType = arg.getType();
41                     if (argType.isVector())
42                     {
43                         return CoerceSimple(retType, SubVector(arg, 0, 1), mNeedsExplicitBoolCasts);
44                     }
45                 }
46             }
47             else if (retType.isVector())
48             {
49                 // 1 element arrays need to be accounted for.
50                 if (argCount == 1 && !retType.isArray())
51                 {
52                     TIntermTyped &arg   = GetArg(callNode, 0);
53                     const TType argType = arg.getType();
54                     if (argType.isVector())
55                     {
56                         return CoerceSimple(retType, SubVector(arg, 0, retType.getNominalSize()),
57                                             mNeedsExplicitBoolCasts);
58                     }
59                 }
60                 for (size_t i = 0; i < argCount; ++i)
61                 {
62                     TIntermTyped &arg = GetArg(callNode, i);
63                     SetArg(callNode, i,
64                            CoerceSimple(retType.getBasicType(), arg, mNeedsExplicitBoolCasts));
65                 }
66             }
67             else if (retType.isMatrix())
68             {
69                 if (argCount == 1)
70                 {
71                     TIntermTyped &arg   = GetArg(callNode, 0);
72                     const TType argType = arg.getType();
73                     if (argType.isMatrix())
74                     {
75                         if (retType.getCols() != argType.getCols() ||
76                             retType.getRows() != argType.getRows())
77                         {
78                             TemplateArg templateArgs[] = {retType.getCols(), retType.getRows()};
79                             return mSymbolEnv.callFunctionOverload(
80                                 Name("cast"), retType, *new TIntermSequence{&arg}, 2, templateArgs);
81                         }
82                     }
83                 }
84             }
85         }
86 
87         return callNode;
88     }
89 };
90 
91 }  // anonymous namespace
92 
AddExplicitTypeCasts(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,bool needsExplicitBoolCasts)93 bool sh::AddExplicitTypeCasts(TCompiler &compiler,
94                               TIntermBlock &root,
95                               SymbolEnv &symbolEnv,
96                               bool needsExplicitBoolCasts)
97 {
98     Rewriter rewriter(compiler, symbolEnv, needsExplicitBoolCasts);
99     if (!rewriter.rebuildRoot(root))
100     {
101         return false;
102     }
103     return true;
104 }
105