xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_arithmetics.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Performs validation of arithmetic instructions.
16 
17 #include <vector>
18 
19 #include "source/opcode.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validate.h"
22 #include "source/val/validation_state.h"
23 
24 namespace spvtools {
25 namespace val {
26 
27 // Validates correctness of arithmetic instructions.
ArithmeticsPass(ValidationState_t & _,const Instruction * inst)28 spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
29   const spv::Op opcode = inst->opcode();
30   const uint32_t result_type = inst->type_id();
31 
32   switch (opcode) {
33     case spv::Op::OpFAdd:
34     case spv::Op::OpFSub:
35     case spv::Op::OpFMul:
36     case spv::Op::OpFDiv:
37     case spv::Op::OpFRem:
38     case spv::Op::OpFMod:
39     case spv::Op::OpFNegate: {
40       bool supportsCoopMat =
41           (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFRem &&
42            opcode != spv::Op::OpFMod);
43       if (!_.IsFloatScalarType(result_type) &&
44           !_.IsFloatVectorType(result_type) &&
45           !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
46           !(opcode == spv::Op::OpFMul &&
47             _.IsCooperativeMatrixKHRType(result_type) &&
48             _.IsFloatCooperativeMatrixType(result_type)))
49         return _.diag(SPV_ERROR_INVALID_DATA, inst)
50                << "Expected floating scalar or vector type as Result Type: "
51                << spvOpcodeString(opcode);
52 
53       for (size_t operand_index = 2; operand_index < inst->operands().size();
54            ++operand_index) {
55         if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
56           const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
57           if (!_.IsCooperativeMatrixKHRType(type_id) ||
58               !_.IsFloatCooperativeMatrixType(type_id)) {
59             return _.diag(SPV_ERROR_INVALID_DATA, inst)
60                    << "Expected arithmetic operands to be of Result Type: "
61                    << spvOpcodeString(opcode) << " operand index "
62                    << operand_index;
63           }
64           spv_result_t ret =
65               _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
66           if (ret != SPV_SUCCESS) return ret;
67         } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
68           return _.diag(SPV_ERROR_INVALID_DATA, inst)
69                  << "Expected arithmetic operands to be of Result Type: "
70                  << spvOpcodeString(opcode) << " operand index "
71                  << operand_index;
72       }
73       break;
74     }
75 
76     case spv::Op::OpUDiv:
77     case spv::Op::OpUMod: {
78       bool supportsCoopMat = (opcode == spv::Op::OpUDiv);
79       if (!_.IsUnsignedIntScalarType(result_type) &&
80           !_.IsUnsignedIntVectorType(result_type) &&
81           !(supportsCoopMat &&
82             _.IsUnsignedIntCooperativeMatrixType(result_type)))
83         return _.diag(SPV_ERROR_INVALID_DATA, inst)
84                << "Expected unsigned int scalar or vector type as Result Type: "
85                << spvOpcodeString(opcode);
86 
87       for (size_t operand_index = 2; operand_index < inst->operands().size();
88            ++operand_index) {
89         if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
90           const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
91           if (!_.IsCooperativeMatrixKHRType(type_id) ||
92               !_.IsUnsignedIntCooperativeMatrixType(type_id)) {
93             return _.diag(SPV_ERROR_INVALID_DATA, inst)
94                    << "Expected arithmetic operands to be of Result Type: "
95                    << spvOpcodeString(opcode) << " operand index "
96                    << operand_index;
97           }
98           spv_result_t ret =
99               _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
100           if (ret != SPV_SUCCESS) return ret;
101         } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
102           return _.diag(SPV_ERROR_INVALID_DATA, inst)
103                  << "Expected arithmetic operands to be of Result Type: "
104                  << spvOpcodeString(opcode) << " operand index "
105                  << operand_index;
106       }
107       break;
108     }
109 
110     case spv::Op::OpISub:
111     case spv::Op::OpIAdd:
112     case spv::Op::OpIMul:
113     case spv::Op::OpSDiv:
114     case spv::Op::OpSMod:
115     case spv::Op::OpSRem:
116     case spv::Op::OpSNegate: {
117       bool supportsCoopMat =
118           (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
119            opcode != spv::Op::OpSMod);
120       if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
121           !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
122           !(opcode == spv::Op::OpIMul &&
123             _.IsCooperativeMatrixKHRType(result_type) &&
124             _.IsIntCooperativeMatrixType(result_type)))
125         return _.diag(SPV_ERROR_INVALID_DATA, inst)
126                << "Expected int scalar or vector type as Result Type: "
127                << spvOpcodeString(opcode);
128 
129       const uint32_t dimension = _.GetDimension(result_type);
130       const uint32_t bit_width = _.GetBitWidth(result_type);
131 
132       for (size_t operand_index = 2; operand_index < inst->operands().size();
133            ++operand_index) {
134         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
135 
136         if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
137           if (!_.IsCooperativeMatrixKHRType(type_id) ||
138               !_.IsIntCooperativeMatrixType(type_id)) {
139             return _.diag(SPV_ERROR_INVALID_DATA, inst)
140                    << "Expected arithmetic operands to be of Result Type: "
141                    << spvOpcodeString(opcode) << " operand index "
142                    << operand_index;
143           }
144           spv_result_t ret =
145               _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
146           if (ret != SPV_SUCCESS) return ret;
147         }
148 
149         if (!type_id ||
150             (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
151              !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
152              !(opcode == spv::Op::OpIMul &&
153                _.IsCooperativeMatrixKHRType(result_type) &&
154                _.IsIntCooperativeMatrixType(result_type))))
155           return _.diag(SPV_ERROR_INVALID_DATA, inst)
156                  << "Expected int scalar or vector type as operand: "
157                  << spvOpcodeString(opcode) << " operand index "
158                  << operand_index;
159 
160         if (_.GetDimension(type_id) != dimension)
161           return _.diag(SPV_ERROR_INVALID_DATA, inst)
162                  << "Expected arithmetic operands to have the same dimension "
163                  << "as Result Type: " << spvOpcodeString(opcode)
164                  << " operand index " << operand_index;
165 
166         if (_.GetBitWidth(type_id) != bit_width)
167           return _.diag(SPV_ERROR_INVALID_DATA, inst)
168                  << "Expected arithmetic operands to have the same bit width "
169                  << "as Result Type: " << spvOpcodeString(opcode)
170                  << " operand index " << operand_index;
171       }
172       break;
173     }
174 
175     case spv::Op::OpDot: {
176       if (!_.IsFloatScalarType(result_type))
177         return _.diag(SPV_ERROR_INVALID_DATA, inst)
178                << "Expected float scalar type as Result Type: "
179                << spvOpcodeString(opcode);
180 
181       uint32_t first_vector_num_components = 0;
182 
183       for (size_t operand_index = 2; operand_index < inst->operands().size();
184            ++operand_index) {
185         const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
186 
187         if (!type_id || !_.IsFloatVectorType(type_id))
188           return _.diag(SPV_ERROR_INVALID_DATA, inst)
189                  << "Expected float vector as operand: "
190                  << spvOpcodeString(opcode) << " operand index "
191                  << operand_index;
192 
193         const uint32_t component_type = _.GetComponentType(type_id);
194         if (component_type != result_type)
195           return _.diag(SPV_ERROR_INVALID_DATA, inst)
196                  << "Expected component type to be equal to Result Type: "
197                  << spvOpcodeString(opcode) << " operand index "
198                  << operand_index;
199 
200         const uint32_t num_components = _.GetDimension(type_id);
201         if (operand_index == 2) {
202           first_vector_num_components = num_components;
203         } else if (num_components != first_vector_num_components) {
204           return _.diag(SPV_ERROR_INVALID_DATA, inst)
205                  << "Expected operands to have the same number of components: "
206                  << spvOpcodeString(opcode);
207         }
208       }
209       break;
210     }
211 
212     case spv::Op::OpVectorTimesScalar: {
213       if (!_.IsFloatVectorType(result_type))
214         return _.diag(SPV_ERROR_INVALID_DATA, inst)
215                << "Expected float vector type as Result Type: "
216                << spvOpcodeString(opcode);
217 
218       const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
219       if (result_type != vector_type_id)
220         return _.diag(SPV_ERROR_INVALID_DATA, inst)
221                << "Expected vector operand type to be equal to Result Type: "
222                << spvOpcodeString(opcode);
223 
224       const uint32_t component_type = _.GetComponentType(vector_type_id);
225 
226       const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
227       if (component_type != scalar_type_id)
228         return _.diag(SPV_ERROR_INVALID_DATA, inst)
229                << "Expected scalar operand type to be equal to the component "
230                << "type of the vector operand: " << spvOpcodeString(opcode);
231 
232       break;
233     }
234 
235     case spv::Op::OpMatrixTimesScalar: {
236       if (!_.IsFloatMatrixType(result_type) &&
237           !(_.IsCooperativeMatrixType(result_type)))
238         return _.diag(SPV_ERROR_INVALID_DATA, inst)
239                << "Expected float matrix type as Result Type: "
240                << spvOpcodeString(opcode);
241 
242       const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
243       if (result_type != matrix_type_id)
244         return _.diag(SPV_ERROR_INVALID_DATA, inst)
245                << "Expected matrix operand type to be equal to Result Type: "
246                << spvOpcodeString(opcode);
247 
248       const uint32_t component_type = _.GetComponentType(matrix_type_id);
249 
250       const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
251       if (component_type != scalar_type_id)
252         return _.diag(SPV_ERROR_INVALID_DATA, inst)
253                << "Expected scalar operand type to be equal to the component "
254                << "type of the matrix operand: " << spvOpcodeString(opcode);
255 
256       break;
257     }
258 
259     case spv::Op::OpVectorTimesMatrix: {
260       const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
261       const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
262 
263       if (!_.IsFloatVectorType(result_type))
264         return _.diag(SPV_ERROR_INVALID_DATA, inst)
265                << "Expected float vector type as Result Type: "
266                << spvOpcodeString(opcode);
267 
268       const uint32_t res_component_type = _.GetComponentType(result_type);
269 
270       if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
271         return _.diag(SPV_ERROR_INVALID_DATA, inst)
272                << "Expected float vector type as left operand: "
273                << spvOpcodeString(opcode);
274 
275       if (res_component_type != _.GetComponentType(vector_type_id))
276         return _.diag(SPV_ERROR_INVALID_DATA, inst)
277                << "Expected component types of Result Type and vector to be "
278                << "equal: " << spvOpcodeString(opcode);
279 
280       uint32_t matrix_num_rows = 0;
281       uint32_t matrix_num_cols = 0;
282       uint32_t matrix_col_type = 0;
283       uint32_t matrix_component_type = 0;
284       if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
285                                &matrix_num_cols, &matrix_col_type,
286                                &matrix_component_type))
287         return _.diag(SPV_ERROR_INVALID_DATA, inst)
288                << "Expected float matrix type as right operand: "
289                << spvOpcodeString(opcode);
290 
291       if (res_component_type != matrix_component_type)
292         return _.diag(SPV_ERROR_INVALID_DATA, inst)
293                << "Expected component types of Result Type and matrix to be "
294                << "equal: " << spvOpcodeString(opcode);
295 
296       if (matrix_num_cols != _.GetDimension(result_type))
297         return _.diag(SPV_ERROR_INVALID_DATA, inst)
298                << "Expected number of columns of the matrix to be equal to "
299                << "Result Type vector size: " << spvOpcodeString(opcode);
300 
301       if (matrix_num_rows != _.GetDimension(vector_type_id))
302         return _.diag(SPV_ERROR_INVALID_DATA, inst)
303                << "Expected number of rows of the matrix to be equal to the "
304                << "vector operand size: " << spvOpcodeString(opcode);
305 
306       break;
307     }
308 
309     case spv::Op::OpMatrixTimesVector: {
310       const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
311       const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
312 
313       if (!_.IsFloatVectorType(result_type))
314         return _.diag(SPV_ERROR_INVALID_DATA, inst)
315                << "Expected float vector type as Result Type: "
316                << spvOpcodeString(opcode);
317 
318       uint32_t matrix_num_rows = 0;
319       uint32_t matrix_num_cols = 0;
320       uint32_t matrix_col_type = 0;
321       uint32_t matrix_component_type = 0;
322       if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
323                                &matrix_num_cols, &matrix_col_type,
324                                &matrix_component_type))
325         return _.diag(SPV_ERROR_INVALID_DATA, inst)
326                << "Expected float matrix type as left operand: "
327                << spvOpcodeString(opcode);
328 
329       if (result_type != matrix_col_type)
330         return _.diag(SPV_ERROR_INVALID_DATA, inst)
331                << "Expected column type of the matrix to be equal to Result "
332                   "Type: "
333                << spvOpcodeString(opcode);
334 
335       if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
336         return _.diag(SPV_ERROR_INVALID_DATA, inst)
337                << "Expected float vector type as right operand: "
338                << spvOpcodeString(opcode);
339 
340       if (matrix_component_type != _.GetComponentType(vector_type_id))
341         return _.diag(SPV_ERROR_INVALID_DATA, inst)
342                << "Expected component types of the operands to be equal: "
343                << spvOpcodeString(opcode);
344 
345       if (matrix_num_cols != _.GetDimension(vector_type_id))
346         return _.diag(SPV_ERROR_INVALID_DATA, inst)
347                << "Expected number of columns of the matrix to be equal to the "
348                << "vector size: " << spvOpcodeString(opcode);
349 
350       break;
351     }
352 
353     case spv::Op::OpMatrixTimesMatrix: {
354       const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
355       const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
356 
357       uint32_t res_num_rows = 0;
358       uint32_t res_num_cols = 0;
359       uint32_t res_col_type = 0;
360       uint32_t res_component_type = 0;
361       if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
362                                &res_col_type, &res_component_type))
363         return _.diag(SPV_ERROR_INVALID_DATA, inst)
364                << "Expected float matrix type as Result Type: "
365                << spvOpcodeString(opcode);
366 
367       uint32_t left_num_rows = 0;
368       uint32_t left_num_cols = 0;
369       uint32_t left_col_type = 0;
370       uint32_t left_component_type = 0;
371       if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
372                                &left_col_type, &left_component_type))
373         return _.diag(SPV_ERROR_INVALID_DATA, inst)
374                << "Expected float matrix type as left operand: "
375                << spvOpcodeString(opcode);
376 
377       uint32_t right_num_rows = 0;
378       uint32_t right_num_cols = 0;
379       uint32_t right_col_type = 0;
380       uint32_t right_component_type = 0;
381       if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
382                                &right_col_type, &right_component_type))
383         return _.diag(SPV_ERROR_INVALID_DATA, inst)
384                << "Expected float matrix type as right operand: "
385                << spvOpcodeString(opcode);
386 
387       if (!_.IsFloatScalarType(res_component_type))
388         return _.diag(SPV_ERROR_INVALID_DATA, inst)
389                << "Expected float matrix type as Result Type: "
390                << spvOpcodeString(opcode);
391 
392       if (res_col_type != left_col_type)
393         return _.diag(SPV_ERROR_INVALID_DATA, inst)
394                << "Expected column types of Result Type and left matrix to be "
395                << "equal: " << spvOpcodeString(opcode);
396 
397       if (res_component_type != right_component_type)
398         return _.diag(SPV_ERROR_INVALID_DATA, inst)
399                << "Expected component types of Result Type and right matrix to "
400                   "be "
401                << "equal: " << spvOpcodeString(opcode);
402 
403       if (res_num_cols != right_num_cols)
404         return _.diag(SPV_ERROR_INVALID_DATA, inst)
405                << "Expected number of columns of Result Type and right matrix "
406                   "to "
407                << "be equal: " << spvOpcodeString(opcode);
408 
409       if (left_num_cols != right_num_rows)
410         return _.diag(SPV_ERROR_INVALID_DATA, inst)
411                << "Expected number of columns of left matrix and number of "
412                   "rows "
413                << "of right matrix to be equal: " << spvOpcodeString(opcode);
414 
415       assert(left_num_rows == res_num_rows);
416       break;
417     }
418 
419     case spv::Op::OpOuterProduct: {
420       const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
421       const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
422 
423       uint32_t res_num_rows = 0;
424       uint32_t res_num_cols = 0;
425       uint32_t res_col_type = 0;
426       uint32_t res_component_type = 0;
427       if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
428                                &res_col_type, &res_component_type))
429         return _.diag(SPV_ERROR_INVALID_DATA, inst)
430                << "Expected float matrix type as Result Type: "
431                << spvOpcodeString(opcode);
432 
433       if (left_type_id != res_col_type)
434         return _.diag(SPV_ERROR_INVALID_DATA, inst)
435                << "Expected column type of Result Type to be equal to the type "
436                << "of the left operand: " << spvOpcodeString(opcode);
437 
438       if (!right_type_id || !_.IsFloatVectorType(right_type_id))
439         return _.diag(SPV_ERROR_INVALID_DATA, inst)
440                << "Expected float vector type as right operand: "
441                << spvOpcodeString(opcode);
442 
443       if (res_component_type != _.GetComponentType(right_type_id))
444         return _.diag(SPV_ERROR_INVALID_DATA, inst)
445                << "Expected component types of the operands to be equal: "
446                << spvOpcodeString(opcode);
447 
448       if (res_num_cols != _.GetDimension(right_type_id))
449         return _.diag(SPV_ERROR_INVALID_DATA, inst)
450                << "Expected number of columns of the matrix to be equal to the "
451                << "vector size of the right operand: "
452                << spvOpcodeString(opcode);
453 
454       break;
455     }
456 
457     case spv::Op::OpIAddCarry:
458     case spv::Op::OpISubBorrow:
459     case spv::Op::OpUMulExtended:
460     case spv::Op::OpSMulExtended: {
461       std::vector<uint32_t> result_types;
462       if (!_.GetStructMemberTypes(result_type, &result_types))
463         return _.diag(SPV_ERROR_INVALID_DATA, inst)
464                << "Expected a struct as Result Type: "
465                << spvOpcodeString(opcode);
466 
467       if (result_types.size() != 2)
468         return _.diag(SPV_ERROR_INVALID_DATA, inst)
469                << "Expected Result Type struct to have two members: "
470                << spvOpcodeString(opcode);
471 
472       if (opcode == spv::Op::OpSMulExtended) {
473         if (!_.IsIntScalarType(result_types[0]) &&
474             !_.IsIntVectorType(result_types[0]))
475           return _.diag(SPV_ERROR_INVALID_DATA, inst)
476                  << "Expected Result Type struct member types to be integer "
477                     "scalar "
478                  << "or vector: " << spvOpcodeString(opcode);
479       } else {
480         if (!_.IsUnsignedIntScalarType(result_types[0]) &&
481             !_.IsUnsignedIntVectorType(result_types[0]))
482           return _.diag(SPV_ERROR_INVALID_DATA, inst)
483                  << "Expected Result Type struct member types to be unsigned "
484                  << "integer scalar or vector: " << spvOpcodeString(opcode);
485       }
486 
487       if (result_types[0] != result_types[1])
488         return _.diag(SPV_ERROR_INVALID_DATA, inst)
489                << "Expected Result Type struct member types to be identical: "
490                << spvOpcodeString(opcode);
491 
492       const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
493       const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
494 
495       if (left_type_id != result_types[0] || right_type_id != result_types[0])
496         return _.diag(SPV_ERROR_INVALID_DATA, inst)
497                << "Expected both operands to be of Result Type member type: "
498                << spvOpcodeString(opcode);
499 
500       break;
501     }
502 
503     case spv::Op::OpCooperativeMatrixMulAddNV: {
504       const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
505       const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
506       const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
507       const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
508 
509       if (!_.IsCooperativeMatrixNVType(A_type_id)) {
510         return _.diag(SPV_ERROR_INVALID_DATA, inst)
511                << "Expected cooperative matrix type as A Type: "
512                << spvOpcodeString(opcode);
513       }
514       if (!_.IsCooperativeMatrixNVType(B_type_id)) {
515         return _.diag(SPV_ERROR_INVALID_DATA, inst)
516                << "Expected cooperative matrix type as B Type: "
517                << spvOpcodeString(opcode);
518       }
519       if (!_.IsCooperativeMatrixNVType(C_type_id)) {
520         return _.diag(SPV_ERROR_INVALID_DATA, inst)
521                << "Expected cooperative matrix type as C Type: "
522                << spvOpcodeString(opcode);
523       }
524       if (!_.IsCooperativeMatrixNVType(D_type_id)) {
525         return _.diag(SPV_ERROR_INVALID_DATA, inst)
526                << "Expected cooperative matrix type as Result Type: "
527                << spvOpcodeString(opcode);
528       }
529 
530       const auto A = _.FindDef(A_type_id);
531       const auto B = _.FindDef(B_type_id);
532       const auto C = _.FindDef(C_type_id);
533       const auto D = _.FindDef(D_type_id);
534 
535       std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
536           A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
537 
538       A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
539       B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
540       C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
541       D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
542 
543       A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
544       B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
545       C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
546       D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
547 
548       A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
549       B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
550       C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
551       D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
552 
553       const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
554                                std::tuple<bool, bool, uint32_t> Y) {
555         return (std::get<1>(X) && std::get<1>(Y) &&
556                 std::get<2>(X) != std::get<2>(Y));
557       };
558 
559       if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
560           notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
561           notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
562         return _.diag(SPV_ERROR_INVALID_DATA, inst)
563                << "Cooperative matrix scopes must match: "
564                << spvOpcodeString(opcode);
565       }
566 
567       if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
568           notEqual(C_rows, D_rows)) {
569         return _.diag(SPV_ERROR_INVALID_DATA, inst)
570                << "Cooperative matrix 'M' mismatch: "
571                << spvOpcodeString(opcode);
572       }
573 
574       if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
575           notEqual(C_cols, D_cols)) {
576         return _.diag(SPV_ERROR_INVALID_DATA, inst)
577                << "Cooperative matrix 'N' mismatch: "
578                << spvOpcodeString(opcode);
579       }
580 
581       if (notEqual(A_cols, B_rows)) {
582         return _.diag(SPV_ERROR_INVALID_DATA, inst)
583                << "Cooperative matrix 'K' mismatch: "
584                << spvOpcodeString(opcode);
585       }
586       break;
587     }
588 
589     case spv::Op::OpCooperativeMatrixMulAddKHR: {
590       const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
591       const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
592       const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
593       const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
594 
595       if (!_.IsCooperativeMatrixAType(A_type_id)) {
596         return _.diag(SPV_ERROR_INVALID_DATA, inst)
597                << "Cooperative matrix type must be A Type: "
598                << spvOpcodeString(opcode);
599       }
600       if (!_.IsCooperativeMatrixBType(B_type_id)) {
601         return _.diag(SPV_ERROR_INVALID_DATA, inst)
602                << "Cooperative matrix type must be B Type: "
603                << spvOpcodeString(opcode);
604       }
605       if (!_.IsCooperativeMatrixAccType(C_type_id)) {
606         return _.diag(SPV_ERROR_INVALID_DATA, inst)
607                << "Cooperative matrix type must be Accumulator Type: "
608                << spvOpcodeString(opcode);
609       }
610       if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
611         return _.diag(SPV_ERROR_INVALID_DATA, inst)
612                << "Expected cooperative matrix type as Result Type: "
613                << spvOpcodeString(opcode);
614       }
615 
616       const auto A = _.FindDef(A_type_id);
617       const auto B = _.FindDef(B_type_id);
618       const auto C = _.FindDef(C_type_id);
619       const auto D = _.FindDef(D_type_id);
620 
621       std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
622           A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
623 
624       A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
625       B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
626       C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
627       D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
628 
629       A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
630       B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
631       C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
632       D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
633 
634       A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
635       B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
636       C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
637       D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
638 
639       const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
640                                std::tuple<bool, bool, uint32_t> Y) {
641         return (std::get<1>(X) && std::get<1>(Y) &&
642                 std::get<2>(X) != std::get<2>(Y));
643       };
644 
645       if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
646           notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
647           notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
648         return _.diag(SPV_ERROR_INVALID_DATA, inst)
649                << "Cooperative matrix scopes must match: "
650                << spvOpcodeString(opcode);
651       }
652 
653       if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
654           notEqual(C_rows, D_rows)) {
655         return _.diag(SPV_ERROR_INVALID_DATA, inst)
656                << "Cooperative matrix 'M' mismatch: "
657                << spvOpcodeString(opcode);
658       }
659 
660       if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
661           notEqual(C_cols, D_cols)) {
662         return _.diag(SPV_ERROR_INVALID_DATA, inst)
663                << "Cooperative matrix 'N' mismatch: "
664                << spvOpcodeString(opcode);
665       }
666 
667       if (notEqual(A_cols, B_rows)) {
668         return _.diag(SPV_ERROR_INVALID_DATA, inst)
669                << "Cooperative matrix 'K' mismatch: "
670                << spvOpcodeString(opcode);
671       }
672       break;
673     }
674 
675     case spv::Op::OpCooperativeMatrixReduceNV: {
676       if (!_.IsCooperativeMatrixKHRType(result_type)) {
677         return _.diag(SPV_ERROR_INVALID_DATA, inst)
678                << "Result Type must be a cooperative matrix type: "
679                << spvOpcodeString(opcode);
680       }
681 
682       const auto result_comp_type_id =
683           _.FindDef(result_type)->GetOperandAs<uint32_t>(1);
684 
685       const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
686       const auto matrix = _.FindDef(matrix_id);
687       const auto matrix_type_id = matrix->type_id();
688       if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
689         return _.diag(SPV_ERROR_INVALID_DATA, inst)
690                << "Matrix must have a cooperative matrix type: "
691                << spvOpcodeString(opcode);
692       }
693       const auto matrix_type = _.FindDef(matrix_type_id);
694       const auto matrix_comp_type_id = matrix_type->GetOperandAs<uint32_t>(1);
695       if (matrix_comp_type_id != result_comp_type_id) {
696         return _.diag(SPV_ERROR_INVALID_DATA, inst)
697                << "Result Type and Matrix type must have the same component "
698                   "type: "
699                << spvOpcodeString(opcode);
700       }
701       if (_.FindDef(result_type)->GetOperandAs<uint32_t>(2) !=
702           matrix_type->GetOperandAs<uint32_t>(2)) {
703         return _.diag(SPV_ERROR_INVALID_DATA, inst)
704                << "Result Type and Matrix type must have the same scope: "
705                << spvOpcodeString(opcode);
706       }
707 
708       if (!_.IsCooperativeMatrixAccType(result_type)) {
709         return _.diag(SPV_ERROR_INVALID_DATA, inst)
710                << "Result Type must have UseAccumulator: "
711                << spvOpcodeString(opcode);
712       }
713       if (!_.IsCooperativeMatrixAccType(matrix_type_id)) {
714         return _.diag(SPV_ERROR_INVALID_DATA, inst)
715                << "Matrix type must have UseAccumulator: "
716                << spvOpcodeString(opcode);
717       }
718 
719       const auto reduce_value = inst->GetOperandAs<uint32_t>(3);
720 
721       if ((reduce_value &
722            uint32_t(
723                spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) &&
724           (reduce_value & uint32_t(spv::CooperativeMatrixReduceMask::Row |
725                                    spv::CooperativeMatrixReduceMask::Column))) {
726         return _.diag(SPV_ERROR_INVALID_DATA, inst)
727                << "Reduce 2x2 must not be used with Row/Column: "
728                << spvOpcodeString(opcode);
729       }
730 
731       std::tuple<bool, bool, uint32_t> result_rows, result_cols, matrix_rows,
732           matrix_cols;
733       result_rows =
734           _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(3));
735       result_cols =
736           _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(4));
737       matrix_rows = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(3));
738       matrix_cols = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(4));
739 
740       if (reduce_value &
741           uint32_t(
742               spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) {
743         if (std::get<1>(result_rows) && std::get<1>(result_cols) &&
744             std::get<1>(matrix_rows) && std::get<1>(matrix_cols) &&
745             (std::get<2>(result_rows) != std::get<2>(matrix_rows) / 2 ||
746              std::get<2>(result_cols) != std::get<2>(matrix_cols) / 2)) {
747           return _.diag(SPV_ERROR_INVALID_DATA, inst)
748                  << "For Reduce2x2, result rows/cols must be half of matrix "
749                     "rows/cols: "
750                  << spvOpcodeString(opcode);
751         }
752       }
753       if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Row)) {
754         if (std::get<1>(result_rows) && std::get<1>(matrix_rows) &&
755             std::get<2>(result_rows) != std::get<2>(matrix_rows)) {
756           return _.diag(SPV_ERROR_INVALID_DATA, inst)
757                  << "For ReduceRow, result rows must match matrix rows: "
758                  << spvOpcodeString(opcode);
759         }
760       }
761       if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Column)) {
762         if (std::get<1>(result_cols) && std::get<1>(matrix_cols) &&
763             std::get<2>(result_cols) != std::get<2>(matrix_cols)) {
764           return _.diag(SPV_ERROR_INVALID_DATA, inst)
765                  << "For ReduceColumn, result cols must match matrix cols: "
766                  << spvOpcodeString(opcode);
767         }
768       }
769 
770       const auto combine_func_id = inst->GetOperandAs<uint32_t>(4);
771       const auto combine_func = _.FindDef(combine_func_id);
772       if (!combine_func || combine_func->opcode() != spv::Op::OpFunction) {
773         return _.diag(SPV_ERROR_INVALID_DATA, inst)
774                << "CombineFunc must be a function: " << spvOpcodeString(opcode);
775       }
776       const auto function_type_id = combine_func->GetOperandAs<uint32_t>(3);
777       const auto function_type = _.FindDef(function_type_id);
778       if (function_type->operands().size() != 4) {
779         return _.diag(SPV_ERROR_INVALID_DATA, inst)
780                << "CombineFunc must have two parameters: "
781                << spvOpcodeString(opcode);
782       }
783       for (uint32_t i = 0; i < 3; ++i) {
784         // checks return type and two params
785         const auto param_type_id = function_type->GetOperandAs<uint32_t>(i + 1);
786         if (param_type_id != matrix_comp_type_id) {
787           return _.diag(SPV_ERROR_INVALID_DATA, inst)
788                  << "CombineFunc return type and parameters must match matrix "
789                     "component type: "
790                  << spvOpcodeString(opcode);
791         }
792       }
793 
794       break;
795     }
796 
797     default:
798       break;
799   }
800 
801   return SPV_SUCCESS;
802 }
803 
804 }  // namespace val
805 }  // namespace spvtools
806