1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates correctness of bitwise instructions.
16 
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22 
23 namespace spvtools {
24 namespace val {
25 
26 // Validates when base and result need to be the same type
ValidateBaseType(ValidationState_t & _,const Instruction * inst,const uint32_t base_type)27 spv_result_t ValidateBaseType(ValidationState_t& _, const Instruction* inst,
28                               const uint32_t base_type) {
29   const spv::Op opcode = inst->opcode();
30 
31   if (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)) {
32     return _.diag(SPV_ERROR_INVALID_DATA, inst)
33            << _.VkErrorID(4781)
34            << "Expected int scalar or vector type for Base operand: "
35            << spvOpcodeString(opcode);
36   }
37 
38   // Vulkan has a restriction to 32 bit for base
39   if (spvIsVulkanEnv(_.context()->target_env)) {
40     if (_.GetBitWidth(base_type) != 32) {
41       return _.diag(SPV_ERROR_INVALID_DATA, inst)
42              << _.VkErrorID(4781)
43              << "Expected 32-bit int type for Base operand: "
44              << spvOpcodeString(opcode);
45     }
46   }
47 
48   // OpBitCount just needs same number of components
49   if (base_type != inst->type_id() && opcode != spv::Op::OpBitCount) {
50     return _.diag(SPV_ERROR_INVALID_DATA, inst)
51            << "Expected Base Type to be equal to Result Type: "
52            << spvOpcodeString(opcode);
53   }
54 
55   return SPV_SUCCESS;
56 }
57 
58 // Validates correctness of bitwise instructions.
BitwisePass(ValidationState_t & _,const Instruction * inst)59 spv_result_t BitwisePass(ValidationState_t& _, const Instruction* inst) {
60   const spv::Op opcode = inst->opcode();
61   const uint32_t result_type = inst->type_id();
62 
63   switch (opcode) {
64     case spv::Op::OpShiftRightLogical:
65     case spv::Op::OpShiftRightArithmetic:
66     case spv::Op::OpShiftLeftLogical: {
67       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
68         return _.diag(SPV_ERROR_INVALID_DATA, inst)
69                << "Expected int scalar or vector type as Result Type: "
70                << spvOpcodeString(opcode);
71 
72       const uint32_t result_dimension = _.GetDimension(result_type);
73       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
74       const uint32_t shift_type = _.GetOperandTypeId(inst, 3);
75 
76       if (!base_type ||
77           (!_.IsIntScalarType(base_type) && !_.IsIntVectorType(base_type)))
78         return _.diag(SPV_ERROR_INVALID_DATA, inst)
79                << "Expected Base to be int scalar or vector: "
80                << spvOpcodeString(opcode);
81 
82       if (_.GetDimension(base_type) != result_dimension)
83         return _.diag(SPV_ERROR_INVALID_DATA, inst)
84                << "Expected Base to have the same dimension "
85                << "as Result Type: " << spvOpcodeString(opcode);
86 
87       if (_.GetBitWidth(base_type) != _.GetBitWidth(result_type))
88         return _.diag(SPV_ERROR_INVALID_DATA, inst)
89                << "Expected Base to have the same bit width "
90                << "as Result Type: " << spvOpcodeString(opcode);
91 
92       if (!shift_type ||
93           (!_.IsIntScalarType(shift_type) && !_.IsIntVectorType(shift_type)))
94         return _.diag(SPV_ERROR_INVALID_DATA, inst)
95                << "Expected Shift to be int scalar or vector: "
96                << spvOpcodeString(opcode);
97 
98       if (_.GetDimension(shift_type) != result_dimension)
99         return _.diag(SPV_ERROR_INVALID_DATA, inst)
100                << "Expected Shift to have the same dimension "
101                << "as Result Type: " << spvOpcodeString(opcode);
102       break;
103     }
104 
105     case spv::Op::OpBitwiseOr:
106     case spv::Op::OpBitwiseXor:
107     case spv::Op::OpBitwiseAnd:
108     case spv::Op::OpNot: {
109       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
110         return _.diag(SPV_ERROR_INVALID_DATA, inst)
111                << "Expected int scalar or vector type as Result Type: "
112                << spvOpcodeString(opcode);
113 
114       const uint32_t result_dimension = _.GetDimension(result_type);
115       const uint32_t result_bit_width = _.GetBitWidth(result_type);
116 
117       for (size_t operand_index = 2; operand_index < inst->operands().size();
118            ++operand_index) {
119         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
120         if (!type_id ||
121             (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id)))
122           return _.diag(SPV_ERROR_INVALID_DATA, inst)
123                  << "Expected int scalar or vector as operand: "
124                  << spvOpcodeString(opcode) << " operand index "
125                  << operand_index;
126 
127         if (_.GetDimension(type_id) != result_dimension)
128           return _.diag(SPV_ERROR_INVALID_DATA, inst)
129                  << "Expected operands to have the same dimension "
130                  << "as Result Type: " << spvOpcodeString(opcode)
131                  << " operand index " << operand_index;
132 
133         if (_.GetBitWidth(type_id) != result_bit_width)
134           return _.diag(SPV_ERROR_INVALID_DATA, inst)
135                  << "Expected operands to have the same bit width "
136                  << "as Result Type: " << spvOpcodeString(opcode)
137                  << " operand index " << operand_index;
138       }
139       break;
140     }
141 
142     case spv::Op::OpBitFieldInsert: {
143       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
144       const uint32_t insert_type = _.GetOperandTypeId(inst, 3);
145       const uint32_t offset_type = _.GetOperandTypeId(inst, 4);
146       const uint32_t count_type = _.GetOperandTypeId(inst, 5);
147 
148       if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
149         return error;
150       }
151 
152       if (insert_type != result_type)
153         return _.diag(SPV_ERROR_INVALID_DATA, inst)
154                << "Expected Insert Type to be equal to Result Type: "
155                << spvOpcodeString(opcode);
156 
157       if (!offset_type || !_.IsIntScalarType(offset_type))
158         return _.diag(SPV_ERROR_INVALID_DATA, inst)
159                << "Expected Offset Type to be int scalar: "
160                << spvOpcodeString(opcode);
161 
162       if (!count_type || !_.IsIntScalarType(count_type))
163         return _.diag(SPV_ERROR_INVALID_DATA, inst)
164                << "Expected Count Type to be int scalar: "
165                << spvOpcodeString(opcode);
166       break;
167     }
168 
169     case spv::Op::OpBitFieldSExtract:
170     case spv::Op::OpBitFieldUExtract: {
171       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
172       const uint32_t offset_type = _.GetOperandTypeId(inst, 3);
173       const uint32_t count_type = _.GetOperandTypeId(inst, 4);
174 
175       if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
176         return error;
177       }
178 
179       if (!offset_type || !_.IsIntScalarType(offset_type))
180         return _.diag(SPV_ERROR_INVALID_DATA, inst)
181                << "Expected Offset Type to be int scalar: "
182                << spvOpcodeString(opcode);
183 
184       if (!count_type || !_.IsIntScalarType(count_type))
185         return _.diag(SPV_ERROR_INVALID_DATA, inst)
186                << "Expected Count Type to be int scalar: "
187                << spvOpcodeString(opcode);
188       break;
189     }
190 
191     case spv::Op::OpBitReverse: {
192       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
193 
194       if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
195         return error;
196       }
197 
198       break;
199     }
200 
201     case spv::Op::OpBitCount: {
202       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type))
203         return _.diag(SPV_ERROR_INVALID_DATA, inst)
204                << "Expected int scalar or vector type as Result Type: "
205                << spvOpcodeString(opcode);
206 
207       const uint32_t base_type = _.GetOperandTypeId(inst, 2);
208 
209       if (spv_result_t error = ValidateBaseType(_, inst, base_type)) {
210         return error;
211       }
212 
213       const uint32_t base_dimension = _.GetDimension(base_type);
214       const uint32_t result_dimension = _.GetDimension(result_type);
215 
216       if (base_dimension != result_dimension)
217         return _.diag(SPV_ERROR_INVALID_DATA, inst)
218                << "Expected Base dimension to be equal to Result Type "
219                   "dimension: "
220                << spvOpcodeString(opcode);
221       break;
222     }
223 
224     default:
225       break;
226   }
227 
228   return SPV_SUCCESS;
229 }
230 
231 }  // namespace val
232 }  // namespace spvtools
233