1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validates correctness of bitwise instructions.
16
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22
23 namespace spvtools {
24 namespace val {
25
26 // Validates when base and result need to be the same type
ValidateBaseType(ValidationState_t & _,const Instruction * inst,const uint32_t base_type)27 spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
28 const uint32_t base_type) {
29 const spv::Op opcode = inst->opcode();
30
31 if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
32 return _.diag(SPV_ERROR_INVALID_DATA, inst)
33 << _.VkErrorID(4781)
34 << "Expected int scalar or vector type for Base operand: "
35 << spvOpcodeString(opcode);
36 }
37
38 // Vulkan has a restriction to 32 bit for base
39 if (spvIsVulkanEnv(_.context()->target_env)) {
40 if (_.GetBitWidth(base_type) != 32) {
41 return _.diag(SPV_ERROR_INVALID_DATA, inst)
42 << _.VkErrorID(4781)
43 << "Expected 32-bit int type for Base operand: "
44 << spvOpcodeString(opcode);
45 }
46 }
47
48 // OpBitCount just needs same number of components
49 if (base_type != inst->type_id() && opcode != spv::Op::OpBitCount) {
50 return _.diag(SPV_ERROR_INVALID_DATA, inst)
51 << "Expected Base Type to be equal to Result Type: "
52 << spvOpcodeString(opcode);
53 }
54
55 return SPV_SUCCESS;
56 }
57
58 // Validates correctness of bitwise instructions.
BitwisePass(ValidationState_t & _,const Instruction * inst)59 spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
60 const spv::Op opcode = inst->opcode();
61 const uint32_t result_type = inst->type_id();
62
63 switch (opcode) {
64 case spv::Op::OpShiftRightLogical:
65 case spv::Op::OpShiftRightArithmetic:
66 case spv::Op::OpShiftLeftLogical: {
67 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
68 return _.diag(SPV_ERROR_INVALID_DATA, inst)
69 << "Expected int scalar or vector type as Result Type: "
70 << spvOpcodeString(opcode);
71
72 const uint32_t result_dimension = _.GetDimension(result_type);
73 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
74 const uint32_t shift_type = _.GetOperandTypeId(inst, 3);
75
76 if (!base_type ||
77 (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
78 return _.diag(SPV_ERROR_INVALID_DATA, inst)
79 << "Expected Base to be int scalar or vector: "
80 << spvOpcodeString(opcode);
81
82 if (_.GetDimension(base_type) != result_dimension)
83 return _.diag(SPV_ERROR_INVALID_DATA, inst)
84 << "Expected Base to have the same dimension "
85 << "as Result Type: " << spvOpcodeString(opcode);
86
87 if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type))
88 return _.diag(SPV_ERROR_INVALID_DATA, inst)
89 << "Expected Base to have the same bit width "
90 << "as Result Type: " << spvOpcodeString(opcode);
91
92 if (!shift_type ||
93 (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type)))
94 return _.diag(SPV_ERROR_INVALID_DATA, inst)
95 << "Expected Shift to be int scalar or vector: "
96 << spvOpcodeString(opcode);
97
98 if (_.GetDimension(shift_type) != result_dimension)
99 return _.diag(SPV_ERROR_INVALID_DATA, inst)
100 << "Expected Shift to have the same dimension "
101 << "as Result Type: " << spvOpcodeString(opcode);
102 break;
103 }
104
105 case spv::Op::OpBitwiseOr:
106 case spv::Op::OpBitwiseXor:
107 case spv::Op::OpBitwiseAnd:
108 case spv::Op::OpNot: {
109 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
110 return _.diag(SPV_ERROR_INVALID_DATA, inst)
111 << "Expected int scalar or vector type as Result Type: "
112 << spvOpcodeString(opcode);
113
114 const uint32_t result_dimension = _.GetDimension(result_type);
115 const uint32_t result_bit_width = _.GetBitWidth(result_type);
116
117 for (size_t operand_index = 2; operand_index < inst->operands().size();
118 ++operand_index) {
119 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
120 if (!type_id ||
121 (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
122 return _.diag(SPV_ERROR_INVALID_DATA, inst)
123 << "Expected int scalar or vector as operand: "
124 << spvOpcodeString(opcode) << " operand index "
125 << operand_index;
126
127 if (_.GetDimension(type_id) != result_dimension)
128 return _.diag(SPV_ERROR_INVALID_DATA, inst)
129 << "Expected operands to have the same dimension "
130 << "as Result Type: " << spvOpcodeString(opcode)
131 << " operand index " << operand_index;
132
133 if (_.GetBitWidth(type_id) != result_bit_width)
134 return _.diag(SPV_ERROR_INVALID_DATA, inst)
135 << "Expected operands to have the same bit width "
136 << "as Result Type: " << spvOpcodeString(opcode)
137 << " operand index " << operand_index;
138 }
139 break;
140 }
141
142 case spv::Op::OpBitFieldInsert: {
143 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
144 const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
145 const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
146 const uint32_t count_type = _.GetOperandTypeId(inst, 5);
147
148 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
149 return error;
150 }
151
152 if (insert_type != result_type)
153 return _.diag(SPV_ERROR_INVALID_DATA, inst)
154 << "Expected Insert Type to be equal to Result Type: "
155 << spvOpcodeString(opcode);
156
157 if (!offset_type || !_.IsIntScalarType(offset_type))
158 return _.diag(SPV_ERROR_INVALID_DATA, inst)
159 << "Expected Offset Type to be int scalar: "
160 << spvOpcodeString(opcode);
161
162 if (!count_type || !_.IsIntScalarType(count_type))
163 return _.diag(SPV_ERROR_INVALID_DATA, inst)
164 << "Expected Count Type to be int scalar: "
165 << spvOpcodeString(opcode);
166 break;
167 }
168
169 case spv::Op::OpBitFieldSExtract:
170 case spv::Op::OpBitFieldUExtract: {
171 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
172 const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
173 const uint32_t count_type = _.GetOperandTypeId(inst, 4);
174
175 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
176 return error;
177 }
178
179 if (!offset_type || !_.IsIntScalarType(offset_type))
180 return _.diag(SPV_ERROR_INVALID_DATA, inst)
181 << "Expected Offset Type to be int scalar: "
182 << spvOpcodeString(opcode);
183
184 if (!count_type || !_.IsIntScalarType(count_type))
185 return _.diag(SPV_ERROR_INVALID_DATA, inst)
186 << "Expected Count Type to be int scalar: "
187 << spvOpcodeString(opcode);
188 break;
189 }
190
191 case spv::Op::OpBitReverse: {
192 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
193
194 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
195 return error;
196 }
197
198 break;
199 }
200
201 case spv::Op::OpBitCount: {
202 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
203 return _.diag(SPV_ERROR_INVALID_DATA, inst)
204 << "Expected int scalar or vector type as Result Type: "
205 << spvOpcodeString(opcode);
206
207 const uint32_t base_type = _.GetOperandTypeId(inst, 2);
208
209 if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
210 return error;
211 }
212
213 const uint32_t base_dimension = _.GetDimension(base_type);
214 const uint32_t result_dimension = _.GetDimension(result_type);
215
216 if (base_dimension != result_dimension)
217 return _.diag(SPV_ERROR_INVALID_DATA, inst)
218 << "Expected Base dimension to be equal to Result Type "
219 "dimension: "
220 << spvOpcodeString(opcode);
221 break;
222 }
223
224 default:
225 break;
226 }
227
228 return SPV_SUCCESS;
229 }
230
231 } // namespace val
232 } // namespace spvtools
233