xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_non_uniform.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
// Validates correctness of non-uniform group SPIR-V instructions.
16 
17 #include "source/opcode.h"
18 #include "source/spirv_constant.h"
19 #include "source/spirv_target_env.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validate.h"
22 #include "source/val/validate_scopes.h"
23 #include "source/val/validation_state.h"
24 
25 namespace spvtools {
26 namespace val {
27 namespace {
28 
ValidateGroupNonUniformElect(ValidationState_t & _,const Instruction * inst)29 spv_result_t ValidateGroupNonUniformElect(ValidationState_t& _,
30                                           const Instruction* inst) {
31   if (!_.IsBoolScalarType(inst->type_id())) {
32     return _.diag(SPV_ERROR_INVALID_DATA, inst)
33            << "Result must be a boolean scalar type";
34   }
35 
36   return SPV_SUCCESS;
37 }
38 
ValidateGroupNonUniformAnyAll(ValidationState_t & _,const Instruction * inst)39 spv_result_t ValidateGroupNonUniformAnyAll(ValidationState_t& _,
40                                            const Instruction* inst) {
41   if (!_.IsBoolScalarType(inst->type_id())) {
42     return _.diag(SPV_ERROR_INVALID_DATA, inst)
43            << "Result must be a boolean scalar type";
44   }
45 
46   if (!_.IsBoolScalarType(_.GetOperandTypeId(inst, 3))) {
47     return _.diag(SPV_ERROR_INVALID_DATA, inst)
48            << "Predicate must be a boolean scalar type";
49   }
50 
51   return SPV_SUCCESS;
52 }
53 
ValidateGroupNonUniformAllEqual(ValidationState_t & _,const Instruction * inst)54 spv_result_t ValidateGroupNonUniformAllEqual(ValidationState_t& _,
55                                              const Instruction* inst) {
56   if (!_.IsBoolScalarType(inst->type_id())) {
57     return _.diag(SPV_ERROR_INVALID_DATA, inst)
58            << "Result must be a boolean scalar type";
59   }
60 
61   const auto value_type = _.GetOperandTypeId(inst, 3);
62   if (!_.IsFloatScalarOrVectorType(value_type) &&
63       !_.IsIntScalarOrVectorType(value_type) &&
64       !_.IsBoolScalarOrVectorType(value_type)) {
65     return _.diag(SPV_ERROR_INVALID_DATA, inst)
66            << "Value must be a scalar or vector of integer, floating-point, or "
67               "boolean type";
68   }
69 
70   return SPV_SUCCESS;
71 }
72 
ValidateGroupNonUniformBroadcastShuffle(ValidationState_t & _,const Instruction * inst)73 spv_result_t ValidateGroupNonUniformBroadcastShuffle(ValidationState_t& _,
74                                                      const Instruction* inst) {
75   const auto type_id = inst->type_id();
76   if (!_.IsFloatScalarOrVectorType(type_id) &&
77       !_.IsIntScalarOrVectorType(type_id) &&
78       !_.IsBoolScalarOrVectorType(type_id)) {
79     return _.diag(SPV_ERROR_INVALID_DATA, inst)
80            << "Result must be a scalar or vector of integer, floating-point, "
81               "or boolean type";
82   }
83 
84   const auto value_type_id = _.GetOperandTypeId(inst, 3);
85   if (value_type_id != type_id) {
86     return _.diag(SPV_ERROR_INVALID_DATA, inst)
87            << "The type of Value must match the Result type";
88   }
89 
90   const auto GetOperandName = [](const spv::Op opcode) {
91     std::string operand;
92     switch (opcode) {
93       case spv::Op::OpGroupNonUniformBroadcast:
94       case spv::Op::OpGroupNonUniformShuffle:
95         operand = "Id";
96         break;
97       case spv::Op::OpGroupNonUniformShuffleXor:
98         operand = "Mask";
99         break;
100       case spv::Op::OpGroupNonUniformQuadBroadcast:
101         operand = "Index";
102         break;
103       case spv::Op::OpGroupNonUniformQuadSwap:
104         operand = "Direction";
105         break;
106       case spv::Op::OpGroupNonUniformShuffleUp:
107       case spv::Op::OpGroupNonUniformShuffleDown:
108       default:
109         operand = "Delta";
110         break;
111     }
112     return operand;
113   };
114 
115   const auto id_type_id = _.GetOperandTypeId(inst, 4);
116   if (!_.IsUnsignedIntScalarType(id_type_id)) {
117     std::string operand = GetOperandName(inst->opcode());
118     return _.diag(SPV_ERROR_INVALID_DATA, inst)
119            << operand << " must be an unsigned integer scalar";
120   }
121 
122   const bool should_be_constant =
123       inst->opcode() == spv::Op::OpGroupNonUniformQuadSwap ||
124       ((inst->opcode() == spv::Op::OpGroupNonUniformBroadcast ||
125         inst->opcode() == spv::Op::OpGroupNonUniformQuadBroadcast) &&
126        _.version() < SPV_SPIRV_VERSION_WORD(1, 5));
127   if (should_be_constant) {
128     const auto id_id = inst->GetOperandAs<uint32_t>(4);
129     const auto id_op = _.GetIdOpcode(id_id);
130     if (!spvOpcodeIsConstant(id_op)) {
131       std::string operand = GetOperandName(inst->opcode());
132       return _.diag(SPV_ERROR_INVALID_DATA, inst)
133              << "Before SPIR-V 1.5, " << operand
134              << " must be a constant instruction";
135     }
136   }
137 
138   return SPV_SUCCESS;
139 }
140 
ValidateGroupNonUniformBroadcastFirst(ValidationState_t & _,const Instruction * inst)141 spv_result_t ValidateGroupNonUniformBroadcastFirst(ValidationState_t& _,
142                                                    const Instruction* inst) {
143   const auto type_id = inst->type_id();
144   if (!_.IsFloatScalarOrVectorType(type_id) &&
145       !_.IsIntScalarOrVectorType(type_id) &&
146       !_.IsBoolScalarOrVectorType(type_id)) {
147     return _.diag(SPV_ERROR_INVALID_DATA, inst)
148            << "Result must be a scalar or vector of integer, floating-point, "
149               "or boolean type";
150   }
151 
152   const auto value_type_id = _.GetOperandTypeId(inst, 3);
153   if (value_type_id != type_id) {
154     return _.diag(SPV_ERROR_INVALID_DATA, inst)
155            << "The type of Value must match the Result type";
156   }
157 
158   return SPV_SUCCESS;
159 }
160 
ValidateGroupNonUniformBallot(ValidationState_t & _,const Instruction * inst)161 spv_result_t ValidateGroupNonUniformBallot(ValidationState_t& _,
162                                            const Instruction* inst) {
163   if (!_.IsUnsignedIntVectorType(inst->type_id())) {
164     return _.diag(SPV_ERROR_INVALID_DATA, inst)
165            << "Result must be a 4-component unsigned integer vector";
166   }
167 
168   if (_.GetDimension(inst->type_id()) != 4) {
169     return _.diag(SPV_ERROR_INVALID_DATA, inst)
170            << "Result must be a 4-component unsigned integer vector";
171   }
172 
173   const auto pred_type_id = _.GetOperandTypeId(inst, 3);
174   if (!_.IsBoolScalarType(pred_type_id)) {
175     return _.diag(SPV_ERROR_INVALID_DATA, inst)
176            << "Predicate must be a boolean scalar";
177   }
178 
179   return SPV_SUCCESS;
180 }
181 
ValidateGroupNonUniformInverseBallot(ValidationState_t & _,const Instruction * inst)182 spv_result_t ValidateGroupNonUniformInverseBallot(ValidationState_t& _,
183                                                   const Instruction* inst) {
184   if (!_.IsBoolScalarType(inst->type_id())) {
185     return _.diag(SPV_ERROR_INVALID_DATA, inst)
186            << "Result must be a boolean scalar";
187   }
188 
189   const auto value_type_id = _.GetOperandTypeId(inst, 3);
190   if (!_.IsUnsignedIntVectorType(value_type_id)) {
191     return _.diag(SPV_ERROR_INVALID_DATA, inst)
192            << "Value must be a 4-component unsigned integer vector";
193   }
194 
195   if (_.GetDimension(value_type_id) != 4) {
196     return _.diag(SPV_ERROR_INVALID_DATA, inst)
197            << "Value must be a 4-component unsigned integer vector";
198   }
199 
200   return SPV_SUCCESS;
201 }
202 
ValidateGroupNonUniformBallotBitExtract(ValidationState_t & _,const Instruction * inst)203 spv_result_t ValidateGroupNonUniformBallotBitExtract(ValidationState_t& _,
204                                                      const Instruction* inst) {
205   if (!_.IsBoolScalarType(inst->type_id())) {
206     return _.diag(SPV_ERROR_INVALID_DATA, inst)
207            << "Result must be a boolean scalar";
208   }
209 
210   const auto value_type_id = _.GetOperandTypeId(inst, 3);
211   if (!_.IsUnsignedIntVectorType(value_type_id)) {
212     return _.diag(SPV_ERROR_INVALID_DATA, inst)
213            << "Value must be a 4-component unsigned integer vector";
214   }
215 
216   if (_.GetDimension(value_type_id) != 4) {
217     return _.diag(SPV_ERROR_INVALID_DATA, inst)
218            << "Value must be a 4-component unsigned integer vector";
219   }
220 
221   const auto id_type_id = _.GetOperandTypeId(inst, 4);
222   if (!_.IsUnsignedIntScalarType(id_type_id)) {
223     return _.diag(SPV_ERROR_INVALID_DATA, inst)
224            << "Id must be an unsigned integer scalar";
225   }
226 
227   return SPV_SUCCESS;
228 }
229 
ValidateGroupNonUniformBallotBitCount(ValidationState_t & _,const Instruction * inst)230 spv_result_t ValidateGroupNonUniformBallotBitCount(ValidationState_t& _,
231                                                    const Instruction* inst) {
232   // Scope is already checked by ValidateExecutionScope() above.
233 
234   const uint32_t result_type = inst->type_id();
235   if (!_.IsUnsignedIntScalarType(result_type)) {
236     return _.diag(SPV_ERROR_INVALID_DATA, inst)
237            << "Expected Result Type to be an unsigned integer type scalar.";
238   }
239 
240   const auto value = inst->GetOperandAs<uint32_t>(4);
241   const auto value_type = _.FindDef(value)->type_id();
242   if (!_.IsUnsignedIntVectorType(value_type) ||
243       _.GetDimension(value_type) != 4) {
244     return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Value to be a "
245                                                    "vector of four components "
246                                                    "of integer type scalar";
247   }
248 
249   const auto group = inst->GetOperandAs<spv::GroupOperation>(3);
250   if (spvIsVulkanEnv(_.context()->target_env)) {
251     if ((group != spv::GroupOperation::Reduce) &&
252         (group != spv::GroupOperation::InclusiveScan) &&
253         (group != spv::GroupOperation::ExclusiveScan)) {
254       return _.diag(SPV_ERROR_INVALID_DATA, inst)
255              << _.VkErrorID(4685)
256              << "In Vulkan: The OpGroupNonUniformBallotBitCount group "
257                 "operation must be only: Reduce, InclusiveScan, or "
258                 "ExclusiveScan.";
259     }
260   }
261   return SPV_SUCCESS;
262 }
263 
ValidateGroupNonUniformBallotFind(ValidationState_t & _,const Instruction * inst)264 spv_result_t ValidateGroupNonUniformBallotFind(ValidationState_t& _,
265                                                const Instruction* inst) {
266   if (!_.IsUnsignedIntScalarType(inst->type_id())) {
267     return _.diag(SPV_ERROR_INVALID_DATA, inst)
268            << "Result must be an unsigned integer scalar";
269   }
270 
271   const auto value_type_id = _.GetOperandTypeId(inst, 3);
272   if (!_.IsUnsignedIntVectorType(value_type_id)) {
273     return _.diag(SPV_ERROR_INVALID_DATA, inst)
274            << "Value must be a 4-component unsigned integer vector";
275   }
276 
277   if (_.GetDimension(value_type_id) != 4) {
278     return _.diag(SPV_ERROR_INVALID_DATA, inst)
279            << "Value must be a 4-component unsigned integer vector";
280   }
281 
282   return SPV_SUCCESS;
283 }
284 
ValidateGroupNonUniformArithmetic(ValidationState_t & _,const Instruction * inst)285 spv_result_t ValidateGroupNonUniformArithmetic(ValidationState_t& _,
286                                                const Instruction* inst) {
287   const bool is_unsigned = inst->opcode() == spv::Op::OpGroupNonUniformUMin ||
288                            inst->opcode() == spv::Op::OpGroupNonUniformUMax;
289   const bool is_float = inst->opcode() == spv::Op::OpGroupNonUniformFAdd ||
290                         inst->opcode() == spv::Op::OpGroupNonUniformFMul ||
291                         inst->opcode() == spv::Op::OpGroupNonUniformFMin ||
292                         inst->opcode() == spv::Op::OpGroupNonUniformFMax;
293   const bool is_bool = inst->opcode() == spv::Op::OpGroupNonUniformLogicalAnd ||
294                        inst->opcode() == spv::Op::OpGroupNonUniformLogicalOr ||
295                        inst->opcode() == spv::Op::OpGroupNonUniformLogicalXor;
296   if (is_float) {
297     if (!_.IsFloatScalarOrVectorType(inst->type_id())) {
298       return _.diag(SPV_ERROR_INVALID_DATA, inst)
299              << "Result must be a floating-point scalar or vector";
300     }
301   } else if (is_bool) {
302     if (!_.IsBoolScalarOrVectorType(inst->type_id())) {
303       return _.diag(SPV_ERROR_INVALID_DATA, inst)
304              << "Result must be a boolean scalar or vector";
305     }
306   } else if (is_unsigned) {
307     if (!_.IsUnsignedIntScalarOrVectorType(inst->type_id())) {
308       return _.diag(SPV_ERROR_INVALID_DATA, inst)
309              << "Result must be an unsigned integer scalar or vector";
310     }
311   } else if (!_.IsIntScalarOrVectorType(inst->type_id())) {
312     return _.diag(SPV_ERROR_INVALID_DATA, inst)
313            << "Result must be an integer scalar or vector";
314   }
315 
316   const auto value_type_id = _.GetOperandTypeId(inst, 4);
317   if (value_type_id != inst->type_id()) {
318     return _.diag(SPV_ERROR_INVALID_DATA, inst)
319            << "The type of Value must match the Result type";
320   }
321 
322   const auto group_op = inst->GetOperandAs<spv::GroupOperation>(3);
323   bool is_clustered_reduce = group_op == spv::GroupOperation::ClusteredReduce;
324   bool is_partitioned_nv =
325       group_op == spv::GroupOperation::PartitionedReduceNV ||
326       group_op == spv::GroupOperation::PartitionedInclusiveScanNV ||
327       group_op == spv::GroupOperation::PartitionedExclusiveScanNV;
328   if (inst->operands().size() <= 5) {
329     if (is_clustered_reduce) {
330       return _.diag(SPV_ERROR_INVALID_DATA, inst)
331              << "ClusterSize must be present when Operation is ClusteredReduce";
332     } else if (is_partitioned_nv) {
333       return _.diag(SPV_ERROR_INVALID_DATA, inst)
334              << "Ballot must be present when Operation is PartitionedReduceNV, "
335                 "PartitionedInclusiveScanNV, or PartitionedExclusiveScanNV";
336     }
337   } else {
338     const auto operand_id = inst->GetOperandAs<uint32_t>(5);
339     const auto* operand = _.FindDef(operand_id);
340     if (is_partitioned_nv) {
341       if (!operand || !_.IsIntScalarOrVectorType(operand->type_id())) {
342         return _.diag(SPV_ERROR_INVALID_DATA, inst)
343                << "Ballot must be a 4-component integer vector";
344       }
345 
346       if (_.GetDimension(operand->type_id()) != 4) {
347         return _.diag(SPV_ERROR_INVALID_DATA, inst)
348                << "Ballot must be a 4-component integer vector";
349       }
350     } else {
351       if (!operand || !_.IsUnsignedIntScalarType(operand->type_id())) {
352         return _.diag(SPV_ERROR_INVALID_DATA, inst)
353                << "ClusterSize must be an unsigned integer scalar";
354       }
355 
356       if (!spvOpcodeIsConstant(operand->opcode())) {
357         return _.diag(SPV_ERROR_INVALID_DATA, inst)
358                << "ClusterSize must be a constant instruction";
359       }
360     }
361   }
362   return SPV_SUCCESS;
363 }
364 
ValidateGroupNonUniformRotateKHR(ValidationState_t & _,const Instruction * inst)365 spv_result_t ValidateGroupNonUniformRotateKHR(ValidationState_t& _,
366                                               const Instruction* inst) {
367   // Scope is already checked by ValidateExecutionScope() above.
368   const uint32_t result_type = inst->type_id();
369   if (!_.IsIntScalarOrVectorType(result_type) &&
370       !_.IsFloatScalarOrVectorType(result_type) &&
371       !_.IsBoolScalarOrVectorType(result_type)) {
372     return _.diag(SPV_ERROR_INVALID_DATA, inst)
373            << "Expected Result Type to be a scalar or vector of "
374               "floating-point, integer or boolean type.";
375   }
376 
377   const uint32_t value_type = _.GetTypeId(inst->GetOperandAs<uint32_t>(3));
378   if (value_type != result_type) {
379     return _.diag(SPV_ERROR_INVALID_DATA, inst)
380            << "Result Type must be the same as the type of Value.";
381   }
382 
383   const uint32_t delta_type = _.GetTypeId(inst->GetOperandAs<uint32_t>(4));
384   if (!_.IsUnsignedIntScalarType(delta_type)) {
385     return _.diag(SPV_ERROR_INVALID_DATA, inst)
386            << "Delta must be a scalar of integer type, whose Signedness "
387               "operand is 0.";
388   }
389 
390   if (inst->words().size() > 6) {
391     const uint32_t cluster_size_op_id = inst->GetOperandAs<uint32_t>(5);
392     const Instruction* cluster_size_inst = _.FindDef(cluster_size_op_id);
393     if (!cluster_size_inst ||
394         !_.IsUnsignedIntScalarType(cluster_size_inst->type_id())) {
395       return _.diag(SPV_ERROR_INVALID_DATA, inst)
396              << "ClusterSize must be a scalar of integer type, whose "
397                 "Signedness operand is 0.";
398     }
399 
400     if (!spvOpcodeIsConstant(cluster_size_inst->opcode())) {
401       return _.diag(SPV_ERROR_INVALID_DATA, inst)
402              << "ClusterSize must come from a constant instruction.";
403     }
404 
405     uint64_t cluster_size;
406     const bool valid_const =
407         _.EvalConstantValUint64(cluster_size_op_id, &cluster_size);
408     if (valid_const &&
409         ((cluster_size == 0) || ((cluster_size & (cluster_size - 1)) != 0))) {
410       return _.diag(SPV_WARNING, inst)
411              << "Behavior is undefined unless ClusterSize is at least 1 and a "
412                 "power of 2.";
413     }
414 
415     // TODO(kpet) Warn about undefined behavior when ClusterSize is greater than
416     // the declared SubGroupSize
417   }
418 
419   return SPV_SUCCESS;
420 }
421 
422 }  // namespace
423 
424 // Validates correctness of non-uniform group instructions.
NonUniformPass(ValidationState_t & _,const Instruction * inst)425 spv_result_t NonUniformPass(ValidationState_t& _, const Instruction* inst) {
426   const spv::Op opcode = inst->opcode();
427 
428   if (spvOpcodeIsNonUniformGroupOperation(opcode)) {
429     // OpGroupNonUniformQuadAllKHR and OpGroupNonUniformQuadAnyKHR don't have
430     // scope paramter
431     if ((opcode != spv::Op::OpGroupNonUniformQuadAllKHR) &&
432         (opcode != spv::Op::OpGroupNonUniformQuadAnyKHR)) {
433       const uint32_t execution_scope = inst->GetOperandAs<uint32_t>(2);
434       if (auto error = ValidateExecutionScope(_, inst, execution_scope)) {
435         return error;
436       }
437     }
438   }
439 
440   switch (opcode) {
441     case spv::Op::OpGroupNonUniformElect:
442       return ValidateGroupNonUniformElect(_, inst);
443     case spv::Op::OpGroupNonUniformAny:
444     case spv::Op::OpGroupNonUniformAll:
445       return ValidateGroupNonUniformAnyAll(_, inst);
446     case spv::Op::OpGroupNonUniformAllEqual:
447       return ValidateGroupNonUniformAllEqual(_, inst);
448     case spv::Op::OpGroupNonUniformBroadcast:
449     case spv::Op::OpGroupNonUniformShuffle:
450     case spv::Op::OpGroupNonUniformShuffleXor:
451     case spv::Op::OpGroupNonUniformShuffleUp:
452     case spv::Op::OpGroupNonUniformShuffleDown:
453     case spv::Op::OpGroupNonUniformQuadBroadcast:
454     case spv::Op::OpGroupNonUniformQuadSwap:
455       return ValidateGroupNonUniformBroadcastShuffle(_, inst);
456     case spv::Op::OpGroupNonUniformBroadcastFirst:
457       return ValidateGroupNonUniformBroadcastFirst(_, inst);
458     case spv::Op::OpGroupNonUniformBallot:
459       return ValidateGroupNonUniformBallot(_, inst);
460     case spv::Op::OpGroupNonUniformInverseBallot:
461       return ValidateGroupNonUniformInverseBallot(_, inst);
462     case spv::Op::OpGroupNonUniformBallotBitExtract:
463       return ValidateGroupNonUniformBallotBitExtract(_, inst);
464     case spv::Op::OpGroupNonUniformBallotBitCount:
465       return ValidateGroupNonUniformBallotBitCount(_, inst);
466     case spv::Op::OpGroupNonUniformBallotFindLSB:
467     case spv::Op::OpGroupNonUniformBallotFindMSB:
468       return ValidateGroupNonUniformBallotFind(_, inst);
469     case spv::Op::OpGroupNonUniformIAdd:
470     case spv::Op::OpGroupNonUniformFAdd:
471     case spv::Op::OpGroupNonUniformIMul:
472     case spv::Op::OpGroupNonUniformFMul:
473     case spv::Op::OpGroupNonUniformSMin:
474     case spv::Op::OpGroupNonUniformUMin:
475     case spv::Op::OpGroupNonUniformFMin:
476     case spv::Op::OpGroupNonUniformSMax:
477     case spv::Op::OpGroupNonUniformUMax:
478     case spv::Op::OpGroupNonUniformFMax:
479     case spv::Op::OpGroupNonUniformBitwiseAnd:
480     case spv::Op::OpGroupNonUniformBitwiseOr:
481     case spv::Op::OpGroupNonUniformBitwiseXor:
482     case spv::Op::OpGroupNonUniformLogicalAnd:
483     case spv::Op::OpGroupNonUniformLogicalOr:
484     case spv::Op::OpGroupNonUniformLogicalXor:
485       return ValidateGroupNonUniformArithmetic(_, inst);
486     case spv::Op::OpGroupNonUniformRotateKHR:
487       return ValidateGroupNonUniformRotateKHR(_, inst);
488     default:
489       break;
490   }
491 
492   return SPV_SUCCESS;
493 }
494 
495 }  // namespace val
496 }  // namespace spvtools
497