1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6
7 #include "compiler/translator/tree_ops/msl/AddExplicitTypeCasts.h"
8 #include "compiler/translator/IntermRebuild.h"
9 #include "compiler/translator/msl/AstHelpers.h"
10
11 using namespace sh;
12
13 namespace
14 {
15
16 class Rewriter : public TIntermRebuild
17 {
18 SymbolEnv &mSymbolEnv;
19 bool mNeedsExplicitBoolCasts = false;
20
21 public:
Rewriter(TCompiler & compiler,SymbolEnv & symbolEnv,bool needsExplicitBoolCasts)22 Rewriter(TCompiler &compiler, SymbolEnv &symbolEnv, bool needsExplicitBoolCasts)
23 : TIntermRebuild(compiler, false, true),
24 mSymbolEnv(symbolEnv),
25 mNeedsExplicitBoolCasts(needsExplicitBoolCasts)
26 {}
27
visitAggregatePost(TIntermAggregate & callNode)28 PostResult visitAggregatePost(TIntermAggregate &callNode) override
29 {
30 const size_t argCount = callNode.getChildCount();
31 const TType &retType = callNode.getType();
32
33 if (callNode.isConstructor())
34 {
35 if (IsScalarBasicType(retType))
36 {
37 if (argCount == 1)
38 {
39 TIntermTyped &arg = GetArg(callNode, 0);
40 const TType argType = arg.getType();
41 if (argType.isVector())
42 {
43 return CoerceSimple(retType, SubVector(arg, 0, 1), mNeedsExplicitBoolCasts);
44 }
45 }
46 }
47 else if (retType.isVector())
48 {
49 // 1 element arrays need to be accounted for.
50 if (argCount == 1 && !retType.isArray())
51 {
52 TIntermTyped &arg = GetArg(callNode, 0);
53 const TType argType = arg.getType();
54 if (argType.isVector())
55 {
56 return CoerceSimple(retType, SubVector(arg, 0, retType.getNominalSize()),
57 mNeedsExplicitBoolCasts);
58 }
59 }
60 for (size_t i = 0; i < argCount; ++i)
61 {
62 TIntermTyped &arg = GetArg(callNode, i);
63 SetArg(callNode, i,
64 CoerceSimple(retType.getBasicType(), arg, mNeedsExplicitBoolCasts));
65 }
66 }
67 else if (retType.isMatrix())
68 {
69 if (argCount == 1)
70 {
71 TIntermTyped &arg = GetArg(callNode, 0);
72 const TType argType = arg.getType();
73 if (argType.isMatrix())
74 {
75 if (retType.getCols() != argType.getCols() ||
76 retType.getRows() != argType.getRows())
77 {
78 TemplateArg templateArgs[] = {retType.getCols(), retType.getRows()};
79 return mSymbolEnv.callFunctionOverload(
80 Name("cast"), retType, *new TIntermSequence{&arg}, 2, templateArgs);
81 }
82 }
83 }
84 }
85 }
86
87 return callNode;
88 }
89 };
90
91 } // anonymous namespace
92
AddExplicitTypeCasts(TCompiler & compiler,TIntermBlock & root,SymbolEnv & symbolEnv,bool needsExplicitBoolCasts)93 bool sh::AddExplicitTypeCasts(TCompiler &compiler,
94 TIntermBlock &root,
95 SymbolEnv &symbolEnv,
96 bool needsExplicitBoolCasts)
97 {
98 Rewriter rewriter(compiler, symbolEnv, needsExplicitBoolCasts);
99 if (!rewriter.rebuildRoot(root))
100 {
101 return false;
102 }
103 return true;
104 }
105