xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_ray_query.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates ray query instructions from SPV_KHR_ray_query
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 namespace {
25 
ValidateRayQueryPointer(ValidationState_t & _,const Instruction * inst,uint32_t ray_query_index)26 spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
27                                      const Instruction* inst,
28                                      uint32_t ray_query_index) {
29   const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
30   auto variable = _.FindDef(ray_query_id);
31   const auto var_opcode = variable->opcode();
32   if (!variable || (var_opcode != spv::Op::OpVariable &&
33                     var_opcode != spv::Op::OpFunctionParameter &&
34                     var_opcode != spv::Op::OpAccessChain)) {
35     return _.diag(SPV_ERROR_INVALID_DATA, inst)
36            << "Ray Query must be a memory object declaration";
37   }
38   auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
39   if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
40     return _.diag(SPV_ERROR_INVALID_DATA, inst)
41            << "Ray Query must be a pointer";
42   }
43   auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
44   if (!type || type->opcode() != spv::Op::OpTypeRayQueryKHR) {
45     return _.diag(SPV_ERROR_INVALID_DATA, inst)
46            << "Ray Query must be a pointer to OpTypeRayQueryKHR";
47   }
48   return SPV_SUCCESS;
49 }
50 
ValidateIntersectionId(ValidationState_t & _,const Instruction * inst,uint32_t intersection_index)51 spv_result_t ValidateIntersectionId(ValidationState_t& _,
52                                     const Instruction* inst,
53                                     uint32_t intersection_index) {
54   const uint32_t intersection_id =
55       inst->GetOperandAs<uint32_t>(intersection_index);
56   const uint32_t intersection_type = _.GetTypeId(intersection_id);
57   const spv::Op intersection_opcode = _.GetIdOpcode(intersection_id);
58   if (!_.IsIntScalarType(intersection_type) ||
59       _.GetBitWidth(intersection_type) != 32 ||
60       !spvOpcodeIsConstant(intersection_opcode)) {
61     return _.diag(SPV_ERROR_INVALID_DATA, inst)
62            << "expected Intersection ID to be a constant 32-bit int scalar";
63   }
64 
65   return SPV_SUCCESS;
66 }
67 
68 }  // namespace
69 
RayQueryPass(ValidationState_t & _,const Instruction * inst)70 spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
71   const spv::Op opcode = inst->opcode();
72   const uint32_t result_type = inst->type_id();
73 
74   switch (opcode) {
75     case spv::Op::OpRayQueryInitializeKHR: {
76       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
77 
78       if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
79           spv::Op::OpTypeAccelerationStructureKHR) {
80         return _.diag(SPV_ERROR_INVALID_DATA, inst)
81                << "Expected Acceleration Structure to be of type "
82                   "OpTypeAccelerationStructureKHR";
83       }
84 
85       const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
86       if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
87         return _.diag(SPV_ERROR_INVALID_DATA, inst)
88                << "Ray Flags must be a 32-bit int scalar";
89       }
90 
91       const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
92       if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
93         return _.diag(SPV_ERROR_INVALID_DATA, inst)
94                << "Cull Mask must be a 32-bit int scalar";
95       }
96 
97       const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
98       if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
99           _.GetBitWidth(ray_origin) != 32) {
100         return _.diag(SPV_ERROR_INVALID_DATA, inst)
101                << "Ray Origin must be a 32-bit float 3-component vector";
102       }
103 
104       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
105       if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
106         return _.diag(SPV_ERROR_INVALID_DATA, inst)
107                << "Ray TMin must be a 32-bit float scalar";
108       }
109 
110       const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
111       if (!_.IsFloatVectorType(ray_direction) ||
112           _.GetDimension(ray_direction) != 3 ||
113           _.GetBitWidth(ray_direction) != 32) {
114         return _.diag(SPV_ERROR_INVALID_DATA, inst)
115                << "Ray Direction must be a 32-bit float 3-component vector";
116       }
117 
118       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
119       if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
120         return _.diag(SPV_ERROR_INVALID_DATA, inst)
121                << "Ray TMax must be a 32-bit float scalar";
122       }
123       break;
124     }
125 
126     case spv::Op::OpRayQueryTerminateKHR:
127     case spv::Op::OpRayQueryConfirmIntersectionKHR: {
128       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
129       break;
130     }
131 
132     case spv::Op::OpRayQueryGenerateIntersectionKHR: {
133       if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
134 
135       const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
136       if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
137         return _.diag(SPV_ERROR_INVALID_DATA, inst)
138                << "Hit T must be a 32-bit float scalar";
139       }
140 
141       break;
142     }
143 
144     case spv::Op::OpRayQueryGetIntersectionFrontFaceKHR:
145     case spv::Op::OpRayQueryProceedKHR:
146     case spv::Op::OpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
147       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
148 
149       if (!_.IsBoolScalarType(result_type)) {
150         return _.diag(SPV_ERROR_INVALID_DATA, inst)
151                << "expected Result Type to be bool scalar type";
152       }
153 
154       if (opcode == spv::Op::OpRayQueryGetIntersectionFrontFaceKHR) {
155         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
156       }
157 
158       break;
159     }
160 
161     case spv::Op::OpRayQueryGetIntersectionTKHR:
162     case spv::Op::OpRayQueryGetRayTMinKHR: {
163       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
164 
165       if (!_.IsFloatScalarType(result_type) ||
166           _.GetBitWidth(result_type) != 32) {
167         return _.diag(SPV_ERROR_INVALID_DATA, inst)
168                << "expected Result Type to be 32-bit float scalar type";
169       }
170 
171       if (opcode == spv::Op::OpRayQueryGetIntersectionTKHR) {
172         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
173       }
174 
175       break;
176     }
177 
178     case spv::Op::OpRayQueryGetIntersectionTypeKHR:
179     case spv::Op::OpRayQueryGetIntersectionInstanceCustomIndexKHR:
180     case spv::Op::OpRayQueryGetIntersectionInstanceIdKHR:
181     case spv::Op::
182         OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
183     case spv::Op::OpRayQueryGetIntersectionGeometryIndexKHR:
184     case spv::Op::OpRayQueryGetIntersectionPrimitiveIndexKHR:
185     case spv::Op::OpRayQueryGetRayFlagsKHR: {
186       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
187 
188       if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
189         return _.diag(SPV_ERROR_INVALID_DATA, inst)
190                << "expected Result Type to be 32-bit int scalar type";
191       }
192 
193       if (opcode != spv::Op::OpRayQueryGetRayFlagsKHR) {
194         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
195       }
196 
197       break;
198     }
199 
200     case spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR:
201     case spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR:
202     case spv::Op::OpRayQueryGetWorldRayDirectionKHR:
203     case spv::Op::OpRayQueryGetWorldRayOriginKHR: {
204       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
205 
206       if (!_.IsFloatVectorType(result_type) ||
207           _.GetDimension(result_type) != 3 ||
208           _.GetBitWidth(result_type) != 32) {
209         return _.diag(SPV_ERROR_INVALID_DATA, inst)
210                << "expected Result Type to be 32-bit float 3-component "
211                   "vector type";
212       }
213 
214       if (opcode == spv::Op::OpRayQueryGetIntersectionObjectRayDirectionKHR ||
215           opcode == spv::Op::OpRayQueryGetIntersectionObjectRayOriginKHR) {
216         if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
217       }
218 
219       break;
220     }
221 
222     case spv::Op::OpRayQueryGetIntersectionBarycentricsKHR: {
223       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
224       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
225 
226       if (!_.IsFloatVectorType(result_type) ||
227           _.GetDimension(result_type) != 2 ||
228           _.GetBitWidth(result_type) != 32) {
229         return _.diag(SPV_ERROR_INVALID_DATA, inst)
230                << "expected Result Type to be 32-bit float 2-component "
231                   "vector type";
232       }
233 
234       break;
235     }
236 
237     case spv::Op::OpRayQueryGetIntersectionObjectToWorldKHR:
238     case spv::Op::OpRayQueryGetIntersectionWorldToObjectKHR: {
239       if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
240       if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
241 
242       uint32_t num_rows = 0;
243       uint32_t num_cols = 0;
244       uint32_t col_type = 0;
245       uint32_t component_type = 0;
246       if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
247                                &component_type)) {
248         return _.diag(SPV_ERROR_INVALID_DATA, inst)
249                << "expected matrix type as Result Type";
250       }
251 
252       if (num_cols != 4) {
253         return _.diag(SPV_ERROR_INVALID_DATA, inst)
254                << "expected Result Type matrix to have a Column Count of 4";
255       }
256 
257       if (!_.IsFloatScalarType(component_type) ||
258           _.GetBitWidth(result_type) != 32 || num_rows != 3) {
259         return _.diag(SPV_ERROR_INVALID_DATA, inst)
260                << "expected Result Type matrix to have a Column Type of "
261                   "3-component 32-bit float vectors";
262       }
263       break;
264     }
265 
266     default:
267       break;
268   }
269 
270   return SPV_SUCCESS;
271 }
272 
273 }  // namespace val
274 }  // namespace spvtools
275