1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Performs validation of arithmetic instructions.
16
17 #include <vector>
18
19 #include "source/opcode.h"
20 #include "source/val/instruction.h"
21 #include "source/val/validate.h"
22 #include "source/val/validation_state.h"
23
24 namespace spvtools {
25 namespace val {
26
27 // Validates correctness of arithmetic instructions.
ArithmeticsPass(ValidationState_t & _,const Instruction * inst)28 spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
29 const spv::Op opcode = inst->opcode();
30 const uint32_t result_type = inst->type_id();
31
32 switch (opcode) {
33 case spv::Op::OpFAdd:
34 case spv::Op::OpFSub:
35 case spv::Op::OpFMul:
36 case spv::Op::OpFDiv:
37 case spv::Op::OpFRem:
38 case spv::Op::OpFMod:
39 case spv::Op::OpFNegate: {
40 bool supportsCoopMat =
41 (opcode != spv::Op::OpFMul && opcode != spv::Op::OpFRem &&
42 opcode != spv::Op::OpFMod);
43 if (!_.IsFloatScalarType(result_type) &&
44 !_.IsFloatVectorType(result_type) &&
45 !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)) &&
46 !(opcode == spv::Op::OpFMul &&
47 _.IsCooperativeMatrixKHRType(result_type) &&
48 _.IsFloatCooperativeMatrixType(result_type)))
49 return _.diag(SPV_ERROR_INVALID_DATA, inst)
50 << "Expected floating scalar or vector type as Result Type: "
51 << spvOpcodeString(opcode);
52
53 for (size_t operand_index = 2; operand_index < inst->operands().size();
54 ++operand_index) {
55 if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
56 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
57 if (!_.IsCooperativeMatrixKHRType(type_id) ||
58 !_.IsFloatCooperativeMatrixType(type_id)) {
59 return _.diag(SPV_ERROR_INVALID_DATA, inst)
60 << "Expected arithmetic operands to be of Result Type: "
61 << spvOpcodeString(opcode) << " operand index "
62 << operand_index;
63 }
64 spv_result_t ret =
65 _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
66 if (ret != SPV_SUCCESS) return ret;
67 } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
68 return _.diag(SPV_ERROR_INVALID_DATA, inst)
69 << "Expected arithmetic operands to be of Result Type: "
70 << spvOpcodeString(opcode) << " operand index "
71 << operand_index;
72 }
73 break;
74 }
75
76 case spv::Op::OpUDiv:
77 case spv::Op::OpUMod: {
78 bool supportsCoopMat = (opcode == spv::Op::OpUDiv);
79 if (!_.IsUnsignedIntScalarType(result_type) &&
80 !_.IsUnsignedIntVectorType(result_type) &&
81 !(supportsCoopMat &&
82 _.IsUnsignedIntCooperativeMatrixType(result_type)))
83 return _.diag(SPV_ERROR_INVALID_DATA, inst)
84 << "Expected unsigned int scalar or vector type as Result Type: "
85 << spvOpcodeString(opcode);
86
87 for (size_t operand_index = 2; operand_index < inst->operands().size();
88 ++operand_index) {
89 if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
90 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
91 if (!_.IsCooperativeMatrixKHRType(type_id) ||
92 !_.IsUnsignedIntCooperativeMatrixType(type_id)) {
93 return _.diag(SPV_ERROR_INVALID_DATA, inst)
94 << "Expected arithmetic operands to be of Result Type: "
95 << spvOpcodeString(opcode) << " operand index "
96 << operand_index;
97 }
98 spv_result_t ret =
99 _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
100 if (ret != SPV_SUCCESS) return ret;
101 } else if (_.GetOperandTypeId(inst, operand_index) != result_type)
102 return _.diag(SPV_ERROR_INVALID_DATA, inst)
103 << "Expected arithmetic operands to be of Result Type: "
104 << spvOpcodeString(opcode) << " operand index "
105 << operand_index;
106 }
107 break;
108 }
109
110 case spv::Op::OpISub:
111 case spv::Op::OpIAdd:
112 case spv::Op::OpIMul:
113 case spv::Op::OpSDiv:
114 case spv::Op::OpSMod:
115 case spv::Op::OpSRem:
116 case spv::Op::OpSNegate: {
117 bool supportsCoopMat =
118 (opcode != spv::Op::OpIMul && opcode != spv::Op::OpSRem &&
119 opcode != spv::Op::OpSMod);
120 if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
121 !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
122 !(opcode == spv::Op::OpIMul &&
123 _.IsCooperativeMatrixKHRType(result_type) &&
124 _.IsIntCooperativeMatrixType(result_type)))
125 return _.diag(SPV_ERROR_INVALID_DATA, inst)
126 << "Expected int scalar or vector type as Result Type: "
127 << spvOpcodeString(opcode);
128
129 const uint32_t dimension = _.GetDimension(result_type);
130 const uint32_t bit_width = _.GetBitWidth(result_type);
131
132 for (size_t operand_index = 2; operand_index < inst->operands().size();
133 ++operand_index) {
134 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
135
136 if (supportsCoopMat && _.IsCooperativeMatrixKHRType(result_type)) {
137 if (!_.IsCooperativeMatrixKHRType(type_id) ||
138 !_.IsIntCooperativeMatrixType(type_id)) {
139 return _.diag(SPV_ERROR_INVALID_DATA, inst)
140 << "Expected arithmetic operands to be of Result Type: "
141 << spvOpcodeString(opcode) << " operand index "
142 << operand_index;
143 }
144 spv_result_t ret =
145 _.CooperativeMatrixShapesMatch(inst, result_type, type_id, false);
146 if (ret != SPV_SUCCESS) return ret;
147 }
148
149 if (!type_id ||
150 (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
151 !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)) &&
152 !(opcode == spv::Op::OpIMul &&
153 _.IsCooperativeMatrixKHRType(result_type) &&
154 _.IsIntCooperativeMatrixType(result_type))))
155 return _.diag(SPV_ERROR_INVALID_DATA, inst)
156 << "Expected int scalar or vector type as operand: "
157 << spvOpcodeString(opcode) << " operand index "
158 << operand_index;
159
160 if (_.GetDimension(type_id) != dimension)
161 return _.diag(SPV_ERROR_INVALID_DATA, inst)
162 << "Expected arithmetic operands to have the same dimension "
163 << "as Result Type: " << spvOpcodeString(opcode)
164 << " operand index " << operand_index;
165
166 if (_.GetBitWidth(type_id) != bit_width)
167 return _.diag(SPV_ERROR_INVALID_DATA, inst)
168 << "Expected arithmetic operands to have the same bit width "
169 << "as Result Type: " << spvOpcodeString(opcode)
170 << " operand index " << operand_index;
171 }
172 break;
173 }
174
175 case spv::Op::OpDot: {
176 if (!_.IsFloatScalarType(result_type))
177 return _.diag(SPV_ERROR_INVALID_DATA, inst)
178 << "Expected float scalar type as Result Type: "
179 << spvOpcodeString(opcode);
180
181 uint32_t first_vector_num_components = 0;
182
183 for (size_t operand_index = 2; operand_index < inst->operands().size();
184 ++operand_index) {
185 const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
186
187 if (!type_id || !_.IsFloatVectorType(type_id))
188 return _.diag(SPV_ERROR_INVALID_DATA, inst)
189 << "Expected float vector as operand: "
190 << spvOpcodeString(opcode) << " operand index "
191 << operand_index;
192
193 const uint32_t component_type = _.GetComponentType(type_id);
194 if (component_type != result_type)
195 return _.diag(SPV_ERROR_INVALID_DATA, inst)
196 << "Expected component type to be equal to Result Type: "
197 << spvOpcodeString(opcode) << " operand index "
198 << operand_index;
199
200 const uint32_t num_components = _.GetDimension(type_id);
201 if (operand_index == 2) {
202 first_vector_num_components = num_components;
203 } else if (num_components != first_vector_num_components) {
204 return _.diag(SPV_ERROR_INVALID_DATA, inst)
205 << "Expected operands to have the same number of components: "
206 << spvOpcodeString(opcode);
207 }
208 }
209 break;
210 }
211
212 case spv::Op::OpVectorTimesScalar: {
213 if (!_.IsFloatVectorType(result_type))
214 return _.diag(SPV_ERROR_INVALID_DATA, inst)
215 << "Expected float vector type as Result Type: "
216 << spvOpcodeString(opcode);
217
218 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
219 if (result_type != vector_type_id)
220 return _.diag(SPV_ERROR_INVALID_DATA, inst)
221 << "Expected vector operand type to be equal to Result Type: "
222 << spvOpcodeString(opcode);
223
224 const uint32_t component_type = _.GetComponentType(vector_type_id);
225
226 const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
227 if (component_type != scalar_type_id)
228 return _.diag(SPV_ERROR_INVALID_DATA, inst)
229 << "Expected scalar operand type to be equal to the component "
230 << "type of the vector operand: " << spvOpcodeString(opcode);
231
232 break;
233 }
234
235 case spv::Op::OpMatrixTimesScalar: {
236 if (!_.IsFloatMatrixType(result_type) &&
237 !(_.IsCooperativeMatrixType(result_type)))
238 return _.diag(SPV_ERROR_INVALID_DATA, inst)
239 << "Expected float matrix type as Result Type: "
240 << spvOpcodeString(opcode);
241
242 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
243 if (result_type != matrix_type_id)
244 return _.diag(SPV_ERROR_INVALID_DATA, inst)
245 << "Expected matrix operand type to be equal to Result Type: "
246 << spvOpcodeString(opcode);
247
248 const uint32_t component_type = _.GetComponentType(matrix_type_id);
249
250 const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
251 if (component_type != scalar_type_id)
252 return _.diag(SPV_ERROR_INVALID_DATA, inst)
253 << "Expected scalar operand type to be equal to the component "
254 << "type of the matrix operand: " << spvOpcodeString(opcode);
255
256 break;
257 }
258
259 case spv::Op::OpVectorTimesMatrix: {
260 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
261 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
262
263 if (!_.IsFloatVectorType(result_type))
264 return _.diag(SPV_ERROR_INVALID_DATA, inst)
265 << "Expected float vector type as Result Type: "
266 << spvOpcodeString(opcode);
267
268 const uint32_t res_component_type = _.GetComponentType(result_type);
269
270 if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
271 return _.diag(SPV_ERROR_INVALID_DATA, inst)
272 << "Expected float vector type as left operand: "
273 << spvOpcodeString(opcode);
274
275 if (res_component_type != _.GetComponentType(vector_type_id))
276 return _.diag(SPV_ERROR_INVALID_DATA, inst)
277 << "Expected component types of Result Type and vector to be "
278 << "equal: " << spvOpcodeString(opcode);
279
280 uint32_t matrix_num_rows = 0;
281 uint32_t matrix_num_cols = 0;
282 uint32_t matrix_col_type = 0;
283 uint32_t matrix_component_type = 0;
284 if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
285 &matrix_num_cols, &matrix_col_type,
286 &matrix_component_type))
287 return _.diag(SPV_ERROR_INVALID_DATA, inst)
288 << "Expected float matrix type as right operand: "
289 << spvOpcodeString(opcode);
290
291 if (res_component_type != matrix_component_type)
292 return _.diag(SPV_ERROR_INVALID_DATA, inst)
293 << "Expected component types of Result Type and matrix to be "
294 << "equal: " << spvOpcodeString(opcode);
295
296 if (matrix_num_cols != _.GetDimension(result_type))
297 return _.diag(SPV_ERROR_INVALID_DATA, inst)
298 << "Expected number of columns of the matrix to be equal to "
299 << "Result Type vector size: " << spvOpcodeString(opcode);
300
301 if (matrix_num_rows != _.GetDimension(vector_type_id))
302 return _.diag(SPV_ERROR_INVALID_DATA, inst)
303 << "Expected number of rows of the matrix to be equal to the "
304 << "vector operand size: " << spvOpcodeString(opcode);
305
306 break;
307 }
308
309 case spv::Op::OpMatrixTimesVector: {
310 const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
311 const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
312
313 if (!_.IsFloatVectorType(result_type))
314 return _.diag(SPV_ERROR_INVALID_DATA, inst)
315 << "Expected float vector type as Result Type: "
316 << spvOpcodeString(opcode);
317
318 uint32_t matrix_num_rows = 0;
319 uint32_t matrix_num_cols = 0;
320 uint32_t matrix_col_type = 0;
321 uint32_t matrix_component_type = 0;
322 if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
323 &matrix_num_cols, &matrix_col_type,
324 &matrix_component_type))
325 return _.diag(SPV_ERROR_INVALID_DATA, inst)
326 << "Expected float matrix type as left operand: "
327 << spvOpcodeString(opcode);
328
329 if (result_type != matrix_col_type)
330 return _.diag(SPV_ERROR_INVALID_DATA, inst)
331 << "Expected column type of the matrix to be equal to Result "
332 "Type: "
333 << spvOpcodeString(opcode);
334
335 if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
336 return _.diag(SPV_ERROR_INVALID_DATA, inst)
337 << "Expected float vector type as right operand: "
338 << spvOpcodeString(opcode);
339
340 if (matrix_component_type != _.GetComponentType(vector_type_id))
341 return _.diag(SPV_ERROR_INVALID_DATA, inst)
342 << "Expected component types of the operands to be equal: "
343 << spvOpcodeString(opcode);
344
345 if (matrix_num_cols != _.GetDimension(vector_type_id))
346 return _.diag(SPV_ERROR_INVALID_DATA, inst)
347 << "Expected number of columns of the matrix to be equal to the "
348 << "vector size: " << spvOpcodeString(opcode);
349
350 break;
351 }
352
353 case spv::Op::OpMatrixTimesMatrix: {
354 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
355 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
356
357 uint32_t res_num_rows = 0;
358 uint32_t res_num_cols = 0;
359 uint32_t res_col_type = 0;
360 uint32_t res_component_type = 0;
361 if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
362 &res_col_type, &res_component_type))
363 return _.diag(SPV_ERROR_INVALID_DATA, inst)
364 << "Expected float matrix type as Result Type: "
365 << spvOpcodeString(opcode);
366
367 uint32_t left_num_rows = 0;
368 uint32_t left_num_cols = 0;
369 uint32_t left_col_type = 0;
370 uint32_t left_component_type = 0;
371 if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
372 &left_col_type, &left_component_type))
373 return _.diag(SPV_ERROR_INVALID_DATA, inst)
374 << "Expected float matrix type as left operand: "
375 << spvOpcodeString(opcode);
376
377 uint32_t right_num_rows = 0;
378 uint32_t right_num_cols = 0;
379 uint32_t right_col_type = 0;
380 uint32_t right_component_type = 0;
381 if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
382 &right_col_type, &right_component_type))
383 return _.diag(SPV_ERROR_INVALID_DATA, inst)
384 << "Expected float matrix type as right operand: "
385 << spvOpcodeString(opcode);
386
387 if (!_.IsFloatScalarType(res_component_type))
388 return _.diag(SPV_ERROR_INVALID_DATA, inst)
389 << "Expected float matrix type as Result Type: "
390 << spvOpcodeString(opcode);
391
392 if (res_col_type != left_col_type)
393 return _.diag(SPV_ERROR_INVALID_DATA, inst)
394 << "Expected column types of Result Type and left matrix to be "
395 << "equal: " << spvOpcodeString(opcode);
396
397 if (res_component_type != right_component_type)
398 return _.diag(SPV_ERROR_INVALID_DATA, inst)
399 << "Expected component types of Result Type and right matrix to "
400 "be "
401 << "equal: " << spvOpcodeString(opcode);
402
403 if (res_num_cols != right_num_cols)
404 return _.diag(SPV_ERROR_INVALID_DATA, inst)
405 << "Expected number of columns of Result Type and right matrix "
406 "to "
407 << "be equal: " << spvOpcodeString(opcode);
408
409 if (left_num_cols != right_num_rows)
410 return _.diag(SPV_ERROR_INVALID_DATA, inst)
411 << "Expected number of columns of left matrix and number of "
412 "rows "
413 << "of right matrix to be equal: " << spvOpcodeString(opcode);
414
415 assert(left_num_rows == res_num_rows);
416 break;
417 }
418
419 case spv::Op::OpOuterProduct: {
420 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
421 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
422
423 uint32_t res_num_rows = 0;
424 uint32_t res_num_cols = 0;
425 uint32_t res_col_type = 0;
426 uint32_t res_component_type = 0;
427 if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
428 &res_col_type, &res_component_type))
429 return _.diag(SPV_ERROR_INVALID_DATA, inst)
430 << "Expected float matrix type as Result Type: "
431 << spvOpcodeString(opcode);
432
433 if (left_type_id != res_col_type)
434 return _.diag(SPV_ERROR_INVALID_DATA, inst)
435 << "Expected column type of Result Type to be equal to the type "
436 << "of the left operand: " << spvOpcodeString(opcode);
437
438 if (!right_type_id || !_.IsFloatVectorType(right_type_id))
439 return _.diag(SPV_ERROR_INVALID_DATA, inst)
440 << "Expected float vector type as right operand: "
441 << spvOpcodeString(opcode);
442
443 if (res_component_type != _.GetComponentType(right_type_id))
444 return _.diag(SPV_ERROR_INVALID_DATA, inst)
445 << "Expected component types of the operands to be equal: "
446 << spvOpcodeString(opcode);
447
448 if (res_num_cols != _.GetDimension(right_type_id))
449 return _.diag(SPV_ERROR_INVALID_DATA, inst)
450 << "Expected number of columns of the matrix to be equal to the "
451 << "vector size of the right operand: "
452 << spvOpcodeString(opcode);
453
454 break;
455 }
456
457 case spv::Op::OpIAddCarry:
458 case spv::Op::OpISubBorrow:
459 case spv::Op::OpUMulExtended:
460 case spv::Op::OpSMulExtended: {
461 std::vector<uint32_t> result_types;
462 if (!_.GetStructMemberTypes(result_type, &result_types))
463 return _.diag(SPV_ERROR_INVALID_DATA, inst)
464 << "Expected a struct as Result Type: "
465 << spvOpcodeString(opcode);
466
467 if (result_types.size() != 2)
468 return _.diag(SPV_ERROR_INVALID_DATA, inst)
469 << "Expected Result Type struct to have two members: "
470 << spvOpcodeString(opcode);
471
472 if (opcode == spv::Op::OpSMulExtended) {
473 if (!_.IsIntScalarType(result_types[0]) &&
474 !_.IsIntVectorType(result_types[0]))
475 return _.diag(SPV_ERROR_INVALID_DATA, inst)
476 << "Expected Result Type struct member types to be integer "
477 "scalar "
478 << "or vector: " << spvOpcodeString(opcode);
479 } else {
480 if (!_.IsUnsignedIntScalarType(result_types[0]) &&
481 !_.IsUnsignedIntVectorType(result_types[0]))
482 return _.diag(SPV_ERROR_INVALID_DATA, inst)
483 << "Expected Result Type struct member types to be unsigned "
484 << "integer scalar or vector: " << spvOpcodeString(opcode);
485 }
486
487 if (result_types[0] != result_types[1])
488 return _.diag(SPV_ERROR_INVALID_DATA, inst)
489 << "Expected Result Type struct member types to be identical: "
490 << spvOpcodeString(opcode);
491
492 const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
493 const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
494
495 if (left_type_id != result_types[0] || right_type_id != result_types[0])
496 return _.diag(SPV_ERROR_INVALID_DATA, inst)
497 << "Expected both operands to be of Result Type member type: "
498 << spvOpcodeString(opcode);
499
500 break;
501 }
502
503 case spv::Op::OpCooperativeMatrixMulAddNV: {
504 const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
505 const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
506 const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
507 const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
508
509 if (!_.IsCooperativeMatrixNVType(A_type_id)) {
510 return _.diag(SPV_ERROR_INVALID_DATA, inst)
511 << "Expected cooperative matrix type as A Type: "
512 << spvOpcodeString(opcode);
513 }
514 if (!_.IsCooperativeMatrixNVType(B_type_id)) {
515 return _.diag(SPV_ERROR_INVALID_DATA, inst)
516 << "Expected cooperative matrix type as B Type: "
517 << spvOpcodeString(opcode);
518 }
519 if (!_.IsCooperativeMatrixNVType(C_type_id)) {
520 return _.diag(SPV_ERROR_INVALID_DATA, inst)
521 << "Expected cooperative matrix type as C Type: "
522 << spvOpcodeString(opcode);
523 }
524 if (!_.IsCooperativeMatrixNVType(D_type_id)) {
525 return _.diag(SPV_ERROR_INVALID_DATA, inst)
526 << "Expected cooperative matrix type as Result Type: "
527 << spvOpcodeString(opcode);
528 }
529
530 const auto A = _.FindDef(A_type_id);
531 const auto B = _.FindDef(B_type_id);
532 const auto C = _.FindDef(C_type_id);
533 const auto D = _.FindDef(D_type_id);
534
535 std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
536 A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
537
538 A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
539 B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
540 C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
541 D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
542
543 A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
544 B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
545 C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
546 D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
547
548 A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
549 B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
550 C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
551 D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
552
553 const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
554 std::tuple<bool, bool, uint32_t> Y) {
555 return (std::get<1>(X) && std::get<1>(Y) &&
556 std::get<2>(X) != std::get<2>(Y));
557 };
558
559 if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
560 notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
561 notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
562 return _.diag(SPV_ERROR_INVALID_DATA, inst)
563 << "Cooperative matrix scopes must match: "
564 << spvOpcodeString(opcode);
565 }
566
567 if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
568 notEqual(C_rows, D_rows)) {
569 return _.diag(SPV_ERROR_INVALID_DATA, inst)
570 << "Cooperative matrix 'M' mismatch: "
571 << spvOpcodeString(opcode);
572 }
573
574 if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
575 notEqual(C_cols, D_cols)) {
576 return _.diag(SPV_ERROR_INVALID_DATA, inst)
577 << "Cooperative matrix 'N' mismatch: "
578 << spvOpcodeString(opcode);
579 }
580
581 if (notEqual(A_cols, B_rows)) {
582 return _.diag(SPV_ERROR_INVALID_DATA, inst)
583 << "Cooperative matrix 'K' mismatch: "
584 << spvOpcodeString(opcode);
585 }
586 break;
587 }
588
589 case spv::Op::OpCooperativeMatrixMulAddKHR: {
590 const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
591 const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
592 const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
593 const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
594
595 if (!_.IsCooperativeMatrixAType(A_type_id)) {
596 return _.diag(SPV_ERROR_INVALID_DATA, inst)
597 << "Cooperative matrix type must be A Type: "
598 << spvOpcodeString(opcode);
599 }
600 if (!_.IsCooperativeMatrixBType(B_type_id)) {
601 return _.diag(SPV_ERROR_INVALID_DATA, inst)
602 << "Cooperative matrix type must be B Type: "
603 << spvOpcodeString(opcode);
604 }
605 if (!_.IsCooperativeMatrixAccType(C_type_id)) {
606 return _.diag(SPV_ERROR_INVALID_DATA, inst)
607 << "Cooperative matrix type must be Accumulator Type: "
608 << spvOpcodeString(opcode);
609 }
610 if (!_.IsCooperativeMatrixKHRType(D_type_id)) {
611 return _.diag(SPV_ERROR_INVALID_DATA, inst)
612 << "Expected cooperative matrix type as Result Type: "
613 << spvOpcodeString(opcode);
614 }
615
616 const auto A = _.FindDef(A_type_id);
617 const auto B = _.FindDef(B_type_id);
618 const auto C = _.FindDef(C_type_id);
619 const auto D = _.FindDef(D_type_id);
620
621 std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
622 A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
623
624 A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
625 B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
626 C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
627 D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
628
629 A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
630 B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
631 C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
632 D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
633
634 A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
635 B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
636 C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
637 D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
638
639 const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
640 std::tuple<bool, bool, uint32_t> Y) {
641 return (std::get<1>(X) && std::get<1>(Y) &&
642 std::get<2>(X) != std::get<2>(Y));
643 };
644
645 if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
646 notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
647 notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
648 return _.diag(SPV_ERROR_INVALID_DATA, inst)
649 << "Cooperative matrix scopes must match: "
650 << spvOpcodeString(opcode);
651 }
652
653 if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
654 notEqual(C_rows, D_rows)) {
655 return _.diag(SPV_ERROR_INVALID_DATA, inst)
656 << "Cooperative matrix 'M' mismatch: "
657 << spvOpcodeString(opcode);
658 }
659
660 if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
661 notEqual(C_cols, D_cols)) {
662 return _.diag(SPV_ERROR_INVALID_DATA, inst)
663 << "Cooperative matrix 'N' mismatch: "
664 << spvOpcodeString(opcode);
665 }
666
667 if (notEqual(A_cols, B_rows)) {
668 return _.diag(SPV_ERROR_INVALID_DATA, inst)
669 << "Cooperative matrix 'K' mismatch: "
670 << spvOpcodeString(opcode);
671 }
672 break;
673 }
674
675 case spv::Op::OpCooperativeMatrixReduceNV: {
676 if (!_.IsCooperativeMatrixKHRType(result_type)) {
677 return _.diag(SPV_ERROR_INVALID_DATA, inst)
678 << "Result Type must be a cooperative matrix type: "
679 << spvOpcodeString(opcode);
680 }
681
682 const auto result_comp_type_id =
683 _.FindDef(result_type)->GetOperandAs<uint32_t>(1);
684
685 const auto matrix_id = inst->GetOperandAs<uint32_t>(2);
686 const auto matrix = _.FindDef(matrix_id);
687 const auto matrix_type_id = matrix->type_id();
688 if (!_.IsCooperativeMatrixKHRType(matrix_type_id)) {
689 return _.diag(SPV_ERROR_INVALID_DATA, inst)
690 << "Matrix must have a cooperative matrix type: "
691 << spvOpcodeString(opcode);
692 }
693 const auto matrix_type = _.FindDef(matrix_type_id);
694 const auto matrix_comp_type_id = matrix_type->GetOperandAs<uint32_t>(1);
695 if (matrix_comp_type_id != result_comp_type_id) {
696 return _.diag(SPV_ERROR_INVALID_DATA, inst)
697 << "Result Type and Matrix type must have the same component "
698 "type: "
699 << spvOpcodeString(opcode);
700 }
701 if (_.FindDef(result_type)->GetOperandAs<uint32_t>(2) !=
702 matrix_type->GetOperandAs<uint32_t>(2)) {
703 return _.diag(SPV_ERROR_INVALID_DATA, inst)
704 << "Result Type and Matrix type must have the same scope: "
705 << spvOpcodeString(opcode);
706 }
707
708 if (!_.IsCooperativeMatrixAccType(result_type)) {
709 return _.diag(SPV_ERROR_INVALID_DATA, inst)
710 << "Result Type must have UseAccumulator: "
711 << spvOpcodeString(opcode);
712 }
713 if (!_.IsCooperativeMatrixAccType(matrix_type_id)) {
714 return _.diag(SPV_ERROR_INVALID_DATA, inst)
715 << "Matrix type must have UseAccumulator: "
716 << spvOpcodeString(opcode);
717 }
718
719 const auto reduce_value = inst->GetOperandAs<uint32_t>(3);
720
721 if ((reduce_value &
722 uint32_t(
723 spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) &&
724 (reduce_value & uint32_t(spv::CooperativeMatrixReduceMask::Row |
725 spv::CooperativeMatrixReduceMask::Column))) {
726 return _.diag(SPV_ERROR_INVALID_DATA, inst)
727 << "Reduce 2x2 must not be used with Row/Column: "
728 << spvOpcodeString(opcode);
729 }
730
731 std::tuple<bool, bool, uint32_t> result_rows, result_cols, matrix_rows,
732 matrix_cols;
733 result_rows =
734 _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(3));
735 result_cols =
736 _.EvalInt32IfConst(_.FindDef(result_type)->GetOperandAs<uint32_t>(4));
737 matrix_rows = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(3));
738 matrix_cols = _.EvalInt32IfConst(matrix_type->GetOperandAs<uint32_t>(4));
739
740 if (reduce_value &
741 uint32_t(
742 spv::CooperativeMatrixReduceMask::CooperativeMatrixReduce2x2)) {
743 if (std::get<1>(result_rows) && std::get<1>(result_cols) &&
744 std::get<1>(matrix_rows) && std::get<1>(matrix_cols) &&
745 (std::get<2>(result_rows) != std::get<2>(matrix_rows) / 2 ||
746 std::get<2>(result_cols) != std::get<2>(matrix_cols) / 2)) {
747 return _.diag(SPV_ERROR_INVALID_DATA, inst)
748 << "For Reduce2x2, result rows/cols must be half of matrix "
749 "rows/cols: "
750 << spvOpcodeString(opcode);
751 }
752 }
753 if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Row)) {
754 if (std::get<1>(result_rows) && std::get<1>(matrix_rows) &&
755 std::get<2>(result_rows) != std::get<2>(matrix_rows)) {
756 return _.diag(SPV_ERROR_INVALID_DATA, inst)
757 << "For ReduceRow, result rows must match matrix rows: "
758 << spvOpcodeString(opcode);
759 }
760 }
761 if (reduce_value == uint32_t(spv::CooperativeMatrixReduceMask::Column)) {
762 if (std::get<1>(result_cols) && std::get<1>(matrix_cols) &&
763 std::get<2>(result_cols) != std::get<2>(matrix_cols)) {
764 return _.diag(SPV_ERROR_INVALID_DATA, inst)
765 << "For ReduceColumn, result cols must match matrix cols: "
766 << spvOpcodeString(opcode);
767 }
768 }
769
770 const auto combine_func_id = inst->GetOperandAs<uint32_t>(4);
771 const auto combine_func = _.FindDef(combine_func_id);
772 if (!combine_func || combine_func->opcode() != spv::Op::OpFunction) {
773 return _.diag(SPV_ERROR_INVALID_DATA, inst)
774 << "CombineFunc must be a function: " << spvOpcodeString(opcode);
775 }
776 const auto function_type_id = combine_func->GetOperandAs<uint32_t>(3);
777 const auto function_type = _.FindDef(function_type_id);
778 if (function_type->operands().size() != 4) {
779 return _.diag(SPV_ERROR_INVALID_DATA, inst)
780 << "CombineFunc must have two parameters: "
781 << spvOpcodeString(opcode);
782 }
783 for (uint32_t i = 0; i < 3; ++i) {
784 // checks return type and two params
785 const auto param_type_id = function_type->GetOperandAs<uint32_t>(i + 1);
786 if (param_type_id != matrix_comp_type_id) {
787 return _.diag(SPV_ERROR_INVALID_DATA, inst)
788 << "CombineFunc return type and parameters must match matrix "
789 "component type: "
790 << spvOpcodeString(opcode);
791 }
792 }
793
794 break;
795 }
796
797 default:
798 break;
799 }
800
801 return SPV_SUCCESS;
802 }
803
804 } // namespace val
805 } // namespace spvtools
806