xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_ray_tracing_reorder.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2022 The Khronos Group Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates ray tracing instructions from SPV_NV_shader_execution_reorder
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 #include <limits>
23 
24 namespace spvtools {
25 namespace val {
26 
// Sentinel operand index meaning "this instruction has no such operand".
// Passed to ValidateHitObjectInstructionCommonParameters to skip a check.
static constexpr uint32_t KRayParamInvalidId =
    std::numeric_limits<uint32_t>::max();
28 
ValidateHitObjectPointer(ValidationState_t & _,const Instruction * inst,uint32_t hit_object_index)29 spv_result_t ValidateHitObjectPointer(ValidationState_t& _,
30                                       const Instruction* inst,
31                                       uint32_t hit_object_index) {
32   const uint32_t hit_object_id = inst->GetOperandAs<uint32_t>(hit_object_index);
33   auto variable = _.FindDef(hit_object_id);
34   const auto var_opcode = variable->opcode();
35   if (!variable || (var_opcode != spv::Op::OpVariable &&
36                     var_opcode != spv::Op::OpFunctionParameter &&
37                     var_opcode != spv::Op::OpAccessChain)) {
38     return _.diag(SPV_ERROR_INVALID_DATA, inst)
39            << "Hit Object must be a memory object declaration";
40   }
41   auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
42   if (!pointer || pointer->opcode() != spv::Op::OpTypePointer) {
43     return _.diag(SPV_ERROR_INVALID_DATA, inst)
44            << "Hit Object must be a pointer";
45   }
46   auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
47   if (!type || type->opcode() != spv::Op::OpTypeHitObjectNV) {
48     return _.diag(SPV_ERROR_INVALID_DATA, inst)
49            << "Type must be OpTypeHitObjectNV";
50   }
51   return SPV_SUCCESS;
52 }
53 
ValidateHitObjectInstructionCommonParameters(ValidationState_t & _,const Instruction * inst,uint32_t acceleration_struct_index,uint32_t instance_id_index,uint32_t primtive_id_index,uint32_t geometry_index,uint32_t ray_flags_index,uint32_t cull_mask_index,uint32_t hit_kind_index,uint32_t sbt_index,uint32_t sbt_offset_index,uint32_t sbt_stride_index,uint32_t sbt_record_offset_index,uint32_t sbt_record_stride_index,uint32_t miss_index,uint32_t ray_origin_index,uint32_t ray_tmin_index,uint32_t ray_direction_index,uint32_t ray_tmax_index,uint32_t payload_index,uint32_t hit_object_attr_index)54 spv_result_t ValidateHitObjectInstructionCommonParameters(
55     ValidationState_t& _, const Instruction* inst,
56     uint32_t acceleration_struct_index, uint32_t instance_id_index,
57     uint32_t primtive_id_index, uint32_t geometry_index,
58     uint32_t ray_flags_index, uint32_t cull_mask_index, uint32_t hit_kind_index,
59     uint32_t sbt_index, uint32_t sbt_offset_index, uint32_t sbt_stride_index,
60     uint32_t sbt_record_offset_index, uint32_t sbt_record_stride_index,
61     uint32_t miss_index, uint32_t ray_origin_index, uint32_t ray_tmin_index,
62     uint32_t ray_direction_index, uint32_t ray_tmax_index,
63     uint32_t payload_index, uint32_t hit_object_attr_index) {
64   auto isValidId = [](uint32_t spvid) { return spvid < KRayParamInvalidId; };
65   if (isValidId(acceleration_struct_index) &&
66       _.GetIdOpcode(_.GetOperandTypeId(inst, acceleration_struct_index)) !=
67           spv::Op::OpTypeAccelerationStructureKHR) {
68     return _.diag(SPV_ERROR_INVALID_DATA, inst)
69            << "Expected Acceleration Structure to be of type "
70               "OpTypeAccelerationStructureKHR";
71   }
72 
73   if (isValidId(instance_id_index)) {
74     const uint32_t instance_id = _.GetOperandTypeId(inst, instance_id_index);
75     if (!_.IsIntScalarType(instance_id) || _.GetBitWidth(instance_id) != 32) {
76       return _.diag(SPV_ERROR_INVALID_DATA, inst)
77              << "Instance Id must be a 32-bit int scalar";
78     }
79   }
80 
81   if (isValidId(primtive_id_index)) {
82     const uint32_t primitive_id = _.GetOperandTypeId(inst, primtive_id_index);
83     if (!_.IsIntScalarType(primitive_id) || _.GetBitWidth(primitive_id) != 32) {
84       return _.diag(SPV_ERROR_INVALID_DATA, inst)
85              << "Primitive Id must be a 32-bit int scalar";
86     }
87   }
88 
89   if (isValidId(geometry_index)) {
90     const uint32_t geometry_index_id = _.GetOperandTypeId(inst, geometry_index);
91     if (!_.IsIntScalarType(geometry_index_id) ||
92         _.GetBitWidth(geometry_index_id) != 32) {
93       return _.diag(SPV_ERROR_INVALID_DATA, inst)
94              << "Geometry Index must be a 32-bit int scalar";
95     }
96   }
97 
98   if (isValidId(miss_index)) {
99     const uint32_t miss_index_id = _.GetOperandTypeId(inst, miss_index);
100     if (!_.IsUnsignedIntScalarType(miss_index_id) ||
101         _.GetBitWidth(miss_index_id) != 32) {
102       return _.diag(SPV_ERROR_INVALID_DATA, inst)
103              << "Miss Index must be a 32-bit int scalar";
104     }
105   }
106 
107   if (isValidId(cull_mask_index)) {
108     const uint32_t cull_mask_id = _.GetOperandTypeId(inst, cull_mask_index);
109     if (!_.IsUnsignedIntScalarType(cull_mask_id) ||
110         _.GetBitWidth(cull_mask_id) != 32) {
111       return _.diag(SPV_ERROR_INVALID_DATA, inst)
112              << "Cull mask must be a 32-bit int scalar";
113     }
114   }
115 
116   if (isValidId(sbt_index)) {
117     const uint32_t sbt_index_id = _.GetOperandTypeId(inst, sbt_index);
118     if (!_.IsUnsignedIntScalarType(sbt_index_id) ||
119         _.GetBitWidth(sbt_index_id) != 32) {
120       return _.diag(SPV_ERROR_INVALID_DATA, inst)
121              << "SBT Index must be a 32-bit unsigned int scalar";
122     }
123   }
124 
125   if (isValidId(sbt_offset_index)) {
126     const uint32_t sbt_offset_id = _.GetOperandTypeId(inst, sbt_offset_index);
127     if (!_.IsUnsignedIntScalarType(sbt_offset_id) ||
128         _.GetBitWidth(sbt_offset_id) != 32) {
129       return _.diag(SPV_ERROR_INVALID_DATA, inst)
130              << "SBT Offset must be a 32-bit unsigned int scalar";
131     }
132   }
133 
134   if (isValidId(sbt_stride_index)) {
135     const uint32_t sbt_stride_index_id =
136         _.GetOperandTypeId(inst, sbt_stride_index);
137     if (!_.IsUnsignedIntScalarType(sbt_stride_index_id) ||
138         _.GetBitWidth(sbt_stride_index_id) != 32) {
139       return _.diag(SPV_ERROR_INVALID_DATA, inst)
140              << "SBT Stride must be a 32-bit unsigned int scalar";
141     }
142   }
143 
144   if (isValidId(sbt_record_offset_index)) {
145     const uint32_t sbt_record_offset_index_id =
146         _.GetOperandTypeId(inst, sbt_record_offset_index);
147     if (!_.IsUnsignedIntScalarType(sbt_record_offset_index_id) ||
148         _.GetBitWidth(sbt_record_offset_index_id) != 32) {
149       return _.diag(SPV_ERROR_INVALID_DATA, inst)
150              << "SBT record offset must be a 32-bit unsigned int scalar";
151     }
152   }
153 
154   if (isValidId(sbt_record_stride_index)) {
155     const uint32_t sbt_record_stride_index_id =
156         _.GetOperandTypeId(inst, sbt_record_stride_index);
157     if (!_.IsUnsignedIntScalarType(sbt_record_stride_index_id) ||
158         _.GetBitWidth(sbt_record_stride_index_id) != 32) {
159       return _.diag(SPV_ERROR_INVALID_DATA, inst)
160              << "SBT record stride must be a 32-bit unsigned int scalar";
161     }
162   }
163 
164   if (isValidId(ray_origin_index)) {
165     const uint32_t ray_origin_id = _.GetOperandTypeId(inst, ray_origin_index);
166     if (!_.IsFloatVectorType(ray_origin_id) ||
167         _.GetDimension(ray_origin_id) != 3 ||
168         _.GetBitWidth(ray_origin_id) != 32) {
169       return _.diag(SPV_ERROR_INVALID_DATA, inst)
170              << "Ray Origin must be a 32-bit float 3-component vector";
171     }
172   }
173 
174   if (isValidId(ray_tmin_index)) {
175     const uint32_t ray_tmin_id = _.GetOperandTypeId(inst, ray_tmin_index);
176     if (!_.IsFloatScalarType(ray_tmin_id) || _.GetBitWidth(ray_tmin_id) != 32) {
177       return _.diag(SPV_ERROR_INVALID_DATA, inst)
178              << "Ray TMin must be a 32-bit float scalar";
179     }
180   }
181 
182   if (isValidId(ray_direction_index)) {
183     const uint32_t ray_direction_id =
184         _.GetOperandTypeId(inst, ray_direction_index);
185     if (!_.IsFloatVectorType(ray_direction_id) ||
186         _.GetDimension(ray_direction_id) != 3 ||
187         _.GetBitWidth(ray_direction_id) != 32) {
188       return _.diag(SPV_ERROR_INVALID_DATA, inst)
189              << "Ray Direction must be a 32-bit float 3-component vector";
190     }
191   }
192 
193   if (isValidId(ray_tmax_index)) {
194     const uint32_t ray_tmax_id = _.GetOperandTypeId(inst, ray_tmax_index);
195     if (!_.IsFloatScalarType(ray_tmax_id) || _.GetBitWidth(ray_tmax_id) != 32) {
196       return _.diag(SPV_ERROR_INVALID_DATA, inst)
197              << "Ray TMax must be a 32-bit float scalar";
198     }
199   }
200 
201   if (isValidId(ray_flags_index)) {
202     const uint32_t ray_flags_id = _.GetOperandTypeId(inst, ray_flags_index);
203     if (!_.IsIntScalarType(ray_flags_id) || _.GetBitWidth(ray_flags_id) != 32) {
204       return _.diag(SPV_ERROR_INVALID_DATA, inst)
205              << "Ray Flags must be a 32-bit int scalar";
206     }
207   }
208 
209   if (isValidId(payload_index)) {
210     const uint32_t payload_id = inst->GetOperandAs<uint32_t>(payload_index);
211     auto variable = _.FindDef(payload_id);
212     const auto var_opcode = variable->opcode();
213     if (!variable || var_opcode != spv::Op::OpVariable ||
214         (variable->GetOperandAs<spv::StorageClass>(2) !=
215              spv::StorageClass::RayPayloadKHR &&
216          variable->GetOperandAs<spv::StorageClass>(2) !=
217              spv::StorageClass::IncomingRayPayloadKHR)) {
218       return _.diag(SPV_ERROR_INVALID_DATA, inst)
219              << "payload must be a OpVariable of storage "
220                 "class RayPayloadKHR or IncomingRayPayloadKHR";
221     }
222   }
223 
224   if (isValidId(hit_kind_index)) {
225     const uint32_t hit_kind_id = _.GetOperandTypeId(inst, hit_kind_index);
226     if (!_.IsUnsignedIntScalarType(hit_kind_id) ||
227         _.GetBitWidth(hit_kind_id) != 32) {
228       return _.diag(SPV_ERROR_INVALID_DATA, inst)
229              << "Hit Kind must be a 32-bit unsigned int scalar";
230     }
231   }
232 
233   if (isValidId(hit_object_attr_index)) {
234     const uint32_t hit_object_attr_id =
235         inst->GetOperandAs<uint32_t>(hit_object_attr_index);
236     auto variable = _.FindDef(hit_object_attr_id);
237     const auto var_opcode = variable->opcode();
238     if (!variable || var_opcode != spv::Op::OpVariable ||
239         (variable->GetOperandAs<spv::StorageClass>(2)) !=
240             spv::StorageClass::HitObjectAttributeNV) {
241       return _.diag(SPV_ERROR_INVALID_DATA, inst)
242              << "Hit Object Attributes id must be a OpVariable of storage "
243                 "class HitObjectAttributeNV";
244     }
245   }
246 
247   return SPV_SUCCESS;
248 }
249 
RayReorderNVPass(ValidationState_t & _,const Instruction * inst)250 spv_result_t RayReorderNVPass(ValidationState_t& _, const Instruction* inst) {
251   const spv::Op opcode = inst->opcode();
252   const uint32_t result_type = inst->type_id();
253 
254   auto RegisterOpcodeForValidModel = [](ValidationState_t& vs,
255                                         const Instruction* rtinst) {
256     std::string opcode_name = spvOpcodeString(rtinst->opcode());
257     vs.function(rtinst->function()->id())
258         ->RegisterExecutionModelLimitation(
259             [opcode_name](spv::ExecutionModel model, std::string* message) {
260               if (model != spv::ExecutionModel::RayGenerationKHR &&
261                   model != spv::ExecutionModel::ClosestHitKHR &&
262                   model != spv::ExecutionModel::MissKHR) {
263                 if (message) {
264                   *message = opcode_name +
265                              " requires RayGenerationKHR, ClosestHitKHR and "
266                              "MissKHR execution models";
267                 }
268                 return false;
269               }
270               return true;
271             });
272     return;
273   };
274 
275   switch (opcode) {
276     case spv::Op::OpHitObjectIsMissNV:
277     case spv::Op::OpHitObjectIsHitNV:
278     case spv::Op::OpHitObjectIsEmptyNV: {
279       RegisterOpcodeForValidModel(_, inst);
280       if (!_.IsBoolScalarType(result_type)) {
281         return _.diag(SPV_ERROR_INVALID_DATA, inst)
282                << "expected Result Type to be bool scalar type";
283       }
284 
285       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
286       break;
287     }
288 
289     case spv::Op::OpHitObjectGetShaderRecordBufferHandleNV: {
290       RegisterOpcodeForValidModel(_, inst);
291       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
292 
293       if (!_.IsIntVectorType(result_type) ||
294           (_.GetDimension(result_type) != 2) ||
295           (_.GetBitWidth(result_type) != 32))
296         return _.diag(SPV_ERROR_INVALID_DATA, inst)
297                << "Expected 32-bit integer type 2-component vector as Result "
298                   "Type: "
299                << spvOpcodeString(opcode);
300       break;
301     }
302 
303     case spv::Op::OpHitObjectGetHitKindNV:
304     case spv::Op::OpHitObjectGetPrimitiveIndexNV:
305     case spv::Op::OpHitObjectGetGeometryIndexNV:
306     case spv::Op::OpHitObjectGetInstanceIdNV:
307     case spv::Op::OpHitObjectGetInstanceCustomIndexNV:
308     case spv::Op::OpHitObjectGetShaderBindingTableRecordIndexNV: {
309       RegisterOpcodeForValidModel(_, inst);
310       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
311 
312       if (!_.IsIntScalarType(result_type) || !_.GetBitWidth(result_type))
313         return _.diag(SPV_ERROR_INVALID_DATA, inst)
314                << "Expected 32-bit integer type scalar as Result Type: "
315                << spvOpcodeString(opcode);
316       break;
317     }
318 
319     case spv::Op::OpHitObjectGetCurrentTimeNV:
320     case spv::Op::OpHitObjectGetRayTMaxNV:
321     case spv::Op::OpHitObjectGetRayTMinNV: {
322       RegisterOpcodeForValidModel(_, inst);
323       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
324 
325       if (!_.IsFloatScalarType(result_type) || _.GetBitWidth(result_type) != 32)
326         return _.diag(SPV_ERROR_INVALID_DATA, inst)
327                << "Expected 32-bit floating-point type scalar as Result Type: "
328                << spvOpcodeString(opcode);
329       break;
330     }
331 
332     case spv::Op::OpHitObjectGetObjectToWorldNV:
333     case spv::Op::OpHitObjectGetWorldToObjectNV: {
334       RegisterOpcodeForValidModel(_, inst);
335       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
336 
337       uint32_t num_rows = 0;
338       uint32_t num_cols = 0;
339       uint32_t col_type = 0;
340       uint32_t component_type = 0;
341 
342       if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
343                                &component_type)) {
344         return _.diag(SPV_ERROR_INVALID_DATA, inst)
345                << "expected matrix type as Result Type: "
346                << spvOpcodeString(opcode);
347       }
348 
349       if (num_cols != 4) {
350         return _.diag(SPV_ERROR_INVALID_DATA, inst)
351                << "expected Result Type matrix to have a Column Count of 4"
352                << spvOpcodeString(opcode);
353       }
354 
355       if (!_.IsFloatScalarType(component_type) ||
356           _.GetBitWidth(result_type) != 32 || num_rows != 3) {
357         return _.diag(SPV_ERROR_INVALID_DATA, inst)
358                << "expected Result Type matrix to have a Column Type of "
359                   "3-component 32-bit float vectors: "
360                << spvOpcodeString(opcode);
361       }
362       break;
363     }
364 
365     case spv::Op::OpHitObjectGetObjectRayOriginNV:
366     case spv::Op::OpHitObjectGetObjectRayDirectionNV:
367     case spv::Op::OpHitObjectGetWorldRayDirectionNV:
368     case spv::Op::OpHitObjectGetWorldRayOriginNV: {
369       RegisterOpcodeForValidModel(_, inst);
370       if (auto error = ValidateHitObjectPointer(_, inst, 2)) return error;
371 
372       if (!_.IsFloatVectorType(result_type) ||
373           (_.GetDimension(result_type) != 3) ||
374           (_.GetBitWidth(result_type) != 32))
375         return _.diag(SPV_ERROR_INVALID_DATA, inst)
376                << "Expected 32-bit floating-point type 3-component vector as "
377                   "Result Type: "
378                << spvOpcodeString(opcode);
379       break;
380     }
381 
382     case spv::Op::OpHitObjectGetAttributesNV: {
383       RegisterOpcodeForValidModel(_, inst);
384       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
385 
386       const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(1);
387       auto variable = _.FindDef(hit_object_attr_id);
388       const auto var_opcode = variable->opcode();
389       if (!variable || var_opcode != spv::Op::OpVariable ||
390           variable->GetOperandAs<spv::StorageClass>(2) !=
391               spv::StorageClass::HitObjectAttributeNV) {
392         return _.diag(SPV_ERROR_INVALID_DATA, inst)
393                << "Hit Object Attributes id must be a OpVariable of storage "
394                   "class HitObjectAttributeNV";
395       }
396       break;
397     }
398 
399     case spv::Op::OpHitObjectExecuteShaderNV: {
400       RegisterOpcodeForValidModel(_, inst);
401       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
402 
403       const uint32_t hit_object_attr_id = inst->GetOperandAs<uint32_t>(1);
404       auto variable = _.FindDef(hit_object_attr_id);
405       const auto var_opcode = variable->opcode();
406       if (!variable || var_opcode != spv::Op::OpVariable ||
407           (variable->GetOperandAs<spv::StorageClass>(2)) !=
408               spv::StorageClass::RayPayloadKHR) {
409         return _.diag(SPV_ERROR_INVALID_DATA, inst)
410                << "Hit Object Attributes id must be a OpVariable of storage "
411                   "class RayPayloadKHR";
412       }
413       break;
414     }
415 
416     case spv::Op::OpHitObjectRecordEmptyNV: {
417       RegisterOpcodeForValidModel(_, inst);
418       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
419       break;
420     }
421 
422     case spv::Op::OpHitObjectRecordMissNV: {
423       RegisterOpcodeForValidModel(_, inst);
424       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
425 
426       const uint32_t miss_index = _.GetOperandTypeId(inst, 1);
427       if (!_.IsUnsignedIntScalarType(miss_index) ||
428           _.GetBitWidth(miss_index) != 32) {
429         return _.diag(SPV_ERROR_INVALID_DATA, inst)
430                << "Miss Index must be a 32-bit int scalar";
431       }
432 
433       const uint32_t ray_origin = _.GetOperandTypeId(inst, 2);
434       if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
435           _.GetBitWidth(ray_origin) != 32) {
436         return _.diag(SPV_ERROR_INVALID_DATA, inst)
437                << "Ray Origin must be a 32-bit float 3-component vector";
438       }
439 
440       const uint32_t ray_tmin = _.GetOperandTypeId(inst, 3);
441       if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
442         return _.diag(SPV_ERROR_INVALID_DATA, inst)
443                << "Ray TMin must be a 32-bit float scalar";
444       }
445 
446       const uint32_t ray_direction = _.GetOperandTypeId(inst, 4);
447       if (!_.IsFloatVectorType(ray_direction) ||
448           _.GetDimension(ray_direction) != 3 ||
449           _.GetBitWidth(ray_direction) != 32) {
450         return _.diag(SPV_ERROR_INVALID_DATA, inst)
451                << "Ray Direction must be a 32-bit float 3-component vector";
452       }
453 
454       const uint32_t ray_tmax = _.GetOperandTypeId(inst, 5);
455       if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
456         return _.diag(SPV_ERROR_INVALID_DATA, inst)
457                << "Ray TMax must be a 32-bit float scalar";
458       }
459       break;
460     }
461 
462     case spv::Op::OpHitObjectRecordHitWithIndexNV: {
463       RegisterOpcodeForValidModel(_, inst);
464       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
465 
466       if (auto error = ValidateHitObjectInstructionCommonParameters(
467               _, inst, 1 /* Acceleration Struct */, 2 /* Instance Id */,
468               3 /* Primtive Id */, 4 /* Geometry Index */,
469               KRayParamInvalidId /* Ray Flags */,
470               KRayParamInvalidId /* Cull Mask */, 5 /* Hit Kind*/,
471               6 /* SBT index */, KRayParamInvalidId /* SBT Offset */,
472               KRayParamInvalidId /* SBT Stride */,
473               KRayParamInvalidId /* SBT Record Offset */,
474               KRayParamInvalidId /* SBT Record Stride */,
475               KRayParamInvalidId /* Miss Index */, 7 /* Ray Origin */,
476               8 /* Ray TMin */, 9 /* Ray Direction */, 10 /* Ray TMax */,
477               KRayParamInvalidId /* Payload */, 11 /* Hit Object Attribute */))
478         return error;
479 
480       break;
481     }
482 
483     case spv::Op::OpHitObjectRecordHitNV: {
484       RegisterOpcodeForValidModel(_, inst);
485       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
486 
487       if (auto error = ValidateHitObjectInstructionCommonParameters(
488               _, inst, 1 /* Acceleration Struct */, 2 /* Instance Id */,
489               3 /* Primtive Id */, 4 /* Geometry Index */,
490               KRayParamInvalidId /* Ray Flags */,
491               KRayParamInvalidId /* Cull Mask */, 5 /* Hit Kind*/,
492               KRayParamInvalidId /* SBT index */,
493               KRayParamInvalidId /* SBT Offset */,
494               KRayParamInvalidId /* SBT Stride */, 6 /* SBT Record Offset */,
495               7 /* SBT Record Stride */, KRayParamInvalidId /* Miss Index */,
496               8 /* Ray Origin */, 9 /* Ray TMin */, 10 /* Ray Direction */,
497               11 /* Ray TMax */, KRayParamInvalidId /* Payload */,
498               12 /* Hit Object Attribute */))
499         return error;
500 
501       break;
502     }
503 
504     case spv::Op::OpHitObjectTraceRayMotionNV: {
505       RegisterOpcodeForValidModel(_, inst);
506       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
507 
508       if (auto error = ValidateHitObjectInstructionCommonParameters(
509               _, inst, 1 /* Acceleration Struct */,
510               KRayParamInvalidId /* Instance Id */,
511               KRayParamInvalidId /* Primtive Id */,
512               KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
513               3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
514               KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
515               5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
516               KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
517               7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
518               10 /* Ray TMax */, 12 /* Payload */,
519               KRayParamInvalidId /* Hit Object Attribute */))
520         return error;
521       // Current Time
522       const uint32_t current_time_id = _.GetOperandTypeId(inst, 11);
523       if (!_.IsFloatScalarType(current_time_id) ||
524           _.GetBitWidth(current_time_id) != 32) {
525         return _.diag(SPV_ERROR_INVALID_DATA, inst)
526                << "Current Times must be a 32-bit float scalar type";
527       }
528 
529       break;
530     }
531 
532     case spv::Op::OpHitObjectTraceRayNV: {
533       RegisterOpcodeForValidModel(_, inst);
534       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
535 
536       if (auto error = ValidateHitObjectInstructionCommonParameters(
537               _, inst, 1 /* Acceleration Struct */,
538               KRayParamInvalidId /* Instance Id */,
539               KRayParamInvalidId /* Primtive Id */,
540               KRayParamInvalidId /* Geometry Index */, 2 /* Ray Flags */,
541               3 /* Cull Mask */, KRayParamInvalidId /* Hit Kind*/,
542               KRayParamInvalidId /* SBT index */, 4 /* SBT Offset */,
543               5 /* SBT Stride */, KRayParamInvalidId /* SBT Record Offset */,
544               KRayParamInvalidId /* SBT Record Stride */, 6 /* Miss Index */,
545               7 /* Ray Origin */, 8 /* Ray TMin */, 9 /* Ray Direction */,
546               10 /* Ray TMax */, 11 /* Payload */,
547               KRayParamInvalidId /* Hit Object Attribute */))
548         return error;
549       break;
550     }
551 
552     case spv::Op::OpReorderThreadWithHitObjectNV: {
553       std::string opcode_name = spvOpcodeString(inst->opcode());
554       _.function(inst->function()->id())
555           ->RegisterExecutionModelLimitation(
556               [opcode_name](spv::ExecutionModel model, std::string* message) {
557                 if (model != spv::ExecutionModel::RayGenerationKHR) {
558                   if (message) {
559                     *message = opcode_name +
560                                " requires RayGenerationKHR execution model";
561                   }
562                   return false;
563                 }
564                 return true;
565               });
566 
567       if (auto error = ValidateHitObjectPointer(_, inst, 0)) return error;
568 
569       if (inst->operands().size() > 1) {
570         if (inst->operands().size() != 3) {
571           return _.diag(SPV_ERROR_INVALID_DATA, inst)
572                  << "Hint and Bits are optional together i.e "
573                  << " Either both Hint and Bits should be provided or neither.";
574         }
575 
576         // Validate the optional opreands Hint and Bits
577         const uint32_t hint_id = _.GetOperandTypeId(inst, 1);
578         if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
579           return _.diag(SPV_ERROR_INVALID_DATA, inst)
580                  << "Hint must be a 32-bit int scalar";
581         }
582         const uint32_t bits_id = _.GetOperandTypeId(inst, 2);
583         if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
584           return _.diag(SPV_ERROR_INVALID_DATA, inst)
585                  << "bits must be a 32-bit int scalar";
586         }
587       }
588       break;
589     }
590 
591     case spv::Op::OpReorderThreadWithHintNV: {
592       std::string opcode_name = spvOpcodeString(inst->opcode());
593       _.function(inst->function()->id())
594           ->RegisterExecutionModelLimitation(
595               [opcode_name](spv::ExecutionModel model, std::string* message) {
596                 if (model != spv::ExecutionModel::RayGenerationKHR) {
597                   if (message) {
598                     *message = opcode_name +
599                                " requires RayGenerationKHR execution model";
600                   }
601                   return false;
602                 }
603                 return true;
604               });
605 
606       const uint32_t hint_id = _.GetOperandTypeId(inst, 0);
607       if (!_.IsIntScalarType(hint_id) || _.GetBitWidth(hint_id) != 32) {
608         return _.diag(SPV_ERROR_INVALID_DATA, inst)
609                << "Hint must be a 32-bit int scalar";
610       }
611 
612       const uint32_t bits_id = _.GetOperandTypeId(inst, 1);
613       if (!_.IsIntScalarType(bits_id) || _.GetBitWidth(bits_id) != 32) {
614         return _.diag(SPV_ERROR_INVALID_DATA, inst)
615                << "bits must be a 32-bit int scalar";
616       }
617     }
618 
619     default:
620       break;
621   }
622   return SPV_SUCCESS;
623 }
624 }  // namespace val
625 }  // namespace spvtools
626