/*------------------------------------------------------------------------
 * Vulkan Conformance Tests
 * ------------------------
 *
 * Copyright (c) 2019 The Khronos Group Inc.
 * Copyright (c) 2018-2019 NVIDIA Corporation
 * Copyright (c) 2023 LunarG, Inc.
 * Copyright (c) 2023 Nintendo
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *//*!
 * \file
 * \brief Vulkan Cooperative Matrix tests
 *//*--------------------------------------------------------------------*/

#include "vktComputeCooperativeMatrixTests.hpp"

#include "vkBufferWithMemory.hpp"
#include "vkImageWithMemory.hpp"
#include "vkQueryUtil.hpp"
#include "vkBuilderUtil.hpp"
#include "vkCmdUtil.hpp"
#include "vkTypeUtil.hpp"
#include "vkObjUtil.hpp"

#include "vktTestGroupUtil.hpp"
#include "vktTestCase.hpp"

#include "deDefs.h"
#include "tcuFloat.hpp"
#include "deMath.h"
#include "deRandom.h"
#include "deSharedPtr.hpp"
#include "deString.h"

#include "tcuTestCase.hpp"
#include "tcuTestLog.hpp"

#include <string>
#include <sstream>
#include <set>
#include <algorithm>
#include <limits> // std::numeric_limits is used by the data-conversion helpers below

namespace vkt
{
namespace compute
{
namespace
{
using namespace vk;
using namespace std;

//#define COOPERATIVE_MATRIX_EXTENDED_DEBUG 1

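// The KHR and NV cooperative matrix enums are required to share numeric
// values; the asserts below are what lets this file convert between the two
// extensions' types with plain casts.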
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT16_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT32_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT64_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT8_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT8_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT64_NV);

DE_STATIC_ASSERT((uint32_t)VK_SCOPE_DEVICE_KHR == (uint32_t)VK_SCOPE_DEVICE_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_WORKGROUP_KHR == (uint32_t)VK_SCOPE_WORKGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_SUBGROUP_KHR == (uint32_t)VK_SCOPE_SUBGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_QUEUE_FAMILY_KHR == (uint32_t)VK_SCOPE_QUEUE_FAMILY_NV);

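// Which cooperative matrix "use" a case targets: the NV extension (which has
// no per-matrix use), or one of the KHR matrix uses (A, B, or Accumulator).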
typedef enum
{
    UT_NV = 0,
    UT_KHR_A,
    UT_KHR_B,
    UT_KHR_Result,
} UseType;

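// The operation exercised by the generated shader; each enumerator
// corresponds to one case of the switch in initPrograms().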
typedef enum
{
    TT_LENGTH = 0,
    TT_CONSTANT,
    TT_CONVERT,
    TT_COMPOSITE,
    TT_COMPOSITE_RVALUE,
    TT_ADD,
    TT_SUB,
    TT_DIV,
    TT_MUL,
    TT_NEGATE,
    TT_MATRIXTIMESSCALAR,
    TT_FUNC,
    TT_MATRIXMULADD,
    TT_COMPOSITE_ARRAY,
    TT_MATRIXMULADD_ARRAY,
    TT_MATRIXMULADD_SATURATED,
    TT_MATRIXMULADD_WRAPPING,
    TT_MATRIXMULADD_STRIDE0,
    TT_MULTICOMPONENT_LOAD,
    TT_MULTICOMPONENT_SAVE,
} TestType;

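// Where matrix data is staged for loads/stores: plain storage buffers,
// workgroup (shared) memory, either of those with variable pointers, or
// physical-storage-buffer pointers (buffer device address).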
typedef enum
{
    SC_BUFFER = 0,
    SC_WORKGROUP,
    SC_WORKGROUP_VARIABLE_POINTERS,
    SC_BUFFER_VARIABLE_POINTERS,
    SC_PHYSICAL_STORAGE_BUFFER,
} StorageClass;

enum SubgroupSizeMode
{
    SUBGROUP_SIZE_NONE = 0,
    SUBGROUP_SIZE_MIN = 1,
    SUBGROUP_SIZE_MAX = 2,
};

const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;

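// Complete description of one test case; both shader generation
// (initPrograms) and execution/validation (iterate) are driven by it.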
struct CaseDef
{
    TestType testType;
    uint32_t subgroupsPerWorkgroupX;
    uint32_t subgroupsPerWorkgroupY;
    uint32_t workgroupsX;
    uint32_t workgroupsY;
    VkComponentTypeKHR inputType;
    VkComponentTypeKHR outputType;
    bool colMajor;
    StorageClass storageClass;
    UseType useType;
    SubgroupSizeMode subgroupSizeMode;
    vk::ComputePipelineConstructionType computePipelineConstructionType;
    uint32_t inputComponentCount;
    uint32_t outputComponentCount;
};

bool isKhr(UseType useType)
{
    return useType != UT_NV;
}

bool isMatrixMulAddOp(TestType testType)
{
    return testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY || testType == TT_MATRIXMULADD_SATURATED ||
           testType == TT_MATRIXMULADD_WRAPPING || testType == TT_MATRIXMULADD_STRIDE0;
}

template <typename T>
VkResult getCooperativeMatrixProperties(const InstanceInterface &, VkPhysicalDevice, uint32_t *, T *)
{
    TCU_THROW(InternalError, "Not Implemented");
}

VkResult getCooperativeMatrixProperties(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
                                        uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesKHR *pProperties)
{
    return vki.getPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, pPropertyCount, pProperties);
}

VkResult getCooperativeMatrixProperties(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
                                        uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesNV *pProperties)
{
    return vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physicalDevice, pPropertyCount, pProperties);
}

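// Rewrites an NV-style properties struct as the equivalent KHR struct so the
// rest of the code can work with a single representation. The NV struct has
// no saturatingAccumulation field, so it is pinned to VK_FALSE here.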
VkCooperativeMatrixPropertiesKHR convertCooperativeMatrixProperties(const VkCooperativeMatrixPropertiesNV &properties)
{
    VkCooperativeMatrixPropertiesKHR result = initVulkanStructure();

    result.sType = (VkStructureType)properties.sType;
    result.pNext = (void *)properties.pNext;
    result.MSize = (uint32_t)properties.MSize;
    result.NSize = (uint32_t)properties.NSize;
    result.KSize = (uint32_t)properties.KSize;
    result.AType = (VkComponentTypeKHR)properties.AType;
    result.BType = (VkComponentTypeKHR)properties.BType;
    result.CType = (VkComponentTypeKHR)properties.CType;
    result.ResultType = (VkComponentTypeKHR)properties.DType;
    result.saturatingAccumulation = (VkBool32)VK_FALSE;
    result.scope = (VkScopeKHR)properties.scope;

    return result;
}

std::vector<VkCooperativeMatrixPropertiesKHR> convertCooperativeMatrixProperties(
    const std::vector<VkCooperativeMatrixPropertiesNV> &properties)
{
    std::vector<VkCooperativeMatrixPropertiesKHR> result(properties.size());

    for (size_t i = 0; i < properties.size(); ++i)
        result[i] = convertCooperativeMatrixProperties(properties[i]);

    return result;
}

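// Standard two-call enumeration: query the property count, then fetch the
// property list itself. T is either the KHR or the NV properties struct.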
template <typename T>
void getCooperativeMatrixPropertiesAll(Context &context, std::vector<T> &properties)
{
    uint32_t propertyCount = 0;

    VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount,
                                            (T *)DE_NULL));

    if (propertyCount > 0)
    {
        const T sample = initVulkanStructureConst();

        properties.resize(propertyCount, sample);

        VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(),
                                                &propertyCount, properties.data()));
    }
    else
    {
        properties.clear();
    }
}

std::vector<VkCooperativeMatrixPropertiesKHR> getCooperativeMatrixPropertiesConverted(Context &context, const bool khr)
{
    std::vector<VkCooperativeMatrixPropertiesKHR> properties;

    if (khr)
    {
        getCooperativeMatrixPropertiesAll(context, properties);
    }
    else
    {
        std::vector<VkCooperativeMatrixPropertiesNV> propertiesNV;

        getCooperativeMatrixPropertiesAll(context, propertiesNV);

        properties = convertCooperativeMatrixProperties(propertiesNV);
    }

    return properties;
}

uint32_t getSubgroupSizeFromMode(Context &context, const SubgroupSizeMode subgroupSizeMode)
{
#ifndef CTS_USES_VULKANSC
    const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
        context.getSubgroupSizeControlProperties();
#else
    const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
        context.getSubgroupSizeControlPropertiesEXT();
#endif // CTS_USES_VULKANSC

    switch (subgroupSizeMode)
    {
    case SUBGROUP_SIZE_MAX:
        return subgroupSizeControlProperties.maxSubgroupSize;
    case SUBGROUP_SIZE_MIN:
        return subgroupSizeControlProperties.minSubgroupSize;
    case SUBGROUP_SIZE_NONE:
        return context.getSubgroupProperties().subgroupSize;
    default:
        TCU_THROW(NotSupportedError, "Unsupported Subgroup size");
    }
}

class CooperativeMatrixTestInstance : public TestInstance
{
public:
    CooperativeMatrixTestInstance(Context &context, const CaseDef &data);
    ~CooperativeMatrixTestInstance(void);
    tcu::TestStatus iterate(void);

private:
    CaseDef m_data;
};

CooperativeMatrixTestInstance::CooperativeMatrixTestInstance(Context &context, const CaseDef &data)
    : vkt::TestInstance(context)
    , m_data(data)
{
}

CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance(void)
{
}

class CooperativeMatrixTestCase : public TestCase
{
public:
    CooperativeMatrixTestCase(tcu::TestContext &context, const char *name, const CaseDef data);
    ~CooperativeMatrixTestCase(void);
    virtual void initPrograms(SourceCollections &programCollection) const;
    virtual TestInstance *createInstance(Context &context) const;
    virtual void checkSupport(Context &context) const;

private:
    CaseDef m_data;
};

CooperativeMatrixTestCase::CooperativeMatrixTestCase(tcu::TestContext &context, const char *name, const CaseDef data)
    : vkt::TestCase(context, name)
    , m_data(data)
{
}

CooperativeMatrixTestCase::~CooperativeMatrixTestCase(void)
{
}

void CooperativeMatrixTestCase::checkSupport(Context &context) const
{
    if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
    {
        TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
    }

    if (isKhr(m_data.useType))
    {
        if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
        {
            TCU_THROW(NotSupportedError,
                      "VkPhysicalDeviceCooperativeMatrixFeaturesKHR::cooperativeMatrix not supported");
        }
    }
    else
    {
        if (!context.getCooperativeMatrixFeaturesNV().cooperativeMatrix)
        {
            TCU_THROW(NotSupportedError,
                      "VkPhysicalDeviceCooperativeMatrixFeaturesNV::cooperativeMatrix not supported");
        }
    }

    if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
    {
        TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
    }

    if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
        !context.getVariablePointersFeatures().variablePointers)
    {
        TCU_THROW(NotSupportedError, "variable pointers not supported");
    }

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
    {
        TCU_THROW(NotSupportedError, "buffer device address not supported");
    }

    if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
        (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_KHR))
    {
        TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
    }

    std::vector<VkCooperativeMatrixPropertiesKHR> properties =
        getCooperativeMatrixPropertiesConverted(context, isKhr(m_data.useType));
    bool supported[2] = {false, false};
    const auto isMMA = isMatrixMulAddOp(m_data.testType);
    const auto isMMASat = m_data.testType == TT_MATRIXMULADD_SATURATED;

    for (size_t i = 0; i < properties.size(); ++i)
    {
        const VkCooperativeMatrixPropertiesKHR *p = &properties[i];

        if (p->scope != VK_SCOPE_SUBGROUP_KHR)
            continue;

        if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
            continue;

        if (isMMA)
        {
            if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
                p->ResultType == m_data.outputType)
            {
                supported[0] = supported[1] = true;
            }
        }
        else
        {
            const VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};

            for (uint32_t j = 0; j < 2; ++j)
            {
                switch (m_data.useType)
                {
                case UT_NV:
                {
                    if (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] ||
                        p->ResultType == types[j])
                        supported[j] = true;

                    break;
                }
                case UT_KHR_A:
                {
                    if (p->AType == types[j])
                        supported[j] = true;

                    break;
                }
                case UT_KHR_B:
                {
                    if (p->BType == types[j])
                        supported[j] = true;

                    break;
                }
                case UT_KHR_Result:
                {
                    if (p->ResultType == types[j])
                        supported[j] = true;

                    break;
                }
                default:
                    TCU_THROW(InternalError, "Unsupported use type");
                }
            }
        }
    }

    if (!supported[0] || !supported[1])
        TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");

    checkShaderObjectRequirements(context.getInstanceInterface(), context.getPhysicalDevice(),
                                  m_data.computePipelineConstructionType);
}

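// Per-component-type information, indexed by the numeric value of
// VkComponentTypeKHR (float16/32/64, then sint8..64, then uint8..64).
// coopmatTypeName is the matrix type name used by the NV GLSL extension.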
struct
{
    const char *typeName;
    const char *coopmatTypeName;
    uint32_t bits;
    bool isSigned;
} componentTypeInfo[] = {
    {"float16_t", "fcoopmatNV", 16, true}, {"float32_t", "fcoopmatNV", 32, true}, {"float64_t", "fcoopmatNV", 64, true},
    {"int8_t", "icoopmatNV", 8, true},     {"int16_t", "icoopmatNV", 16, true},   {"int32_t", "icoopmatNV", 32, true},
    {"int64_t", "icoopmatNV", 64, true},   {"uint8_t", "ucoopmatNV", 8, false},   {"uint16_t", "ucoopmatNV", 16, false},
    {"uint32_t", "ucoopmatNV", 32, false}, {"uint64_t", "ucoopmatNV", 64, false},
};

bool isFloatType(VkComponentTypeKHR t)
{
    switch (t)
    {
    case VK_COMPONENT_TYPE_FLOAT16_KHR:
    case VK_COMPONENT_TYPE_FLOAT32_KHR:
    case VK_COMPONENT_TYPE_FLOAT64_KHR:
        return true;
    default:
        return false;
    }
}

bool isSIntType(VkComponentTypeKHR t)
{
    switch (t)
    {
    case VK_COMPONENT_TYPE_SINT8_KHR:
    case VK_COMPONENT_TYPE_SINT16_KHR:
    case VK_COMPONENT_TYPE_SINT32_KHR:
    case VK_COMPONENT_TYPE_SINT64_KHR:
        return true;
    default:
        return false;
    }
}

void CooperativeMatrixTestCase::initPrograms(SourceCollections &programCollection) const
{
    const char *suffix = (isKhr(m_data.useType) ? "" : "NV");
    const char *ext = isKhr(m_data.useType) ? "#extension GL_KHR_cooperative_matrix : enable\n" :
                                              "#extension GL_NV_cooperative_matrix : enable\n"
                                              "#extension GL_NV_integer_cooperative_matrix : enable\n";
    const char *sat = (m_data.testType == TT_MATRIXMULADD_SATURATED) ? ", gl_MatrixOperandsSaturatingAccumulation" : "";
    std::stringstream css;
    css << "#version 450 core\n";
    css << "#pragma use_vulkan_memory_model\n";
    css << "#extension GL_KHR_shader_subgroup_basic : enable\n"
           "#extension GL_KHR_memory_scope_semantics : enable\n"
        << ext
        << "#extension GL_EXT_shader_explicit_arithmetic_types : enable\n"
           "#extension GL_EXT_buffer_reference : enable\n"
           "// strides overridden by spec constants\n"
           "layout(constant_id = 2) const int AStride = 1;\n"
           "layout(constant_id = 3) const int BStride = 1;\n"
           "layout(constant_id = 4) const int CStride = 1;\n"
           "layout(constant_id = 5) const int OStride = 1;\n"
           "layout(constant_id = 6) const int M = 1;\n"
           "layout(constant_id = 7) const int N = 1;\n"
           "layout(constant_id = 8) const int K = 1;\n"
           "layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";

    if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
        css << "#pragma use_variable_pointers\n";

    struct
    {
        string rows, cols;
    } dims[4];

    // dims[0..3] describe the A, B, C, and output matrices: MxK, KxN, MxN,
    // MxN for the multiply-add tests, and MxN for all four otherwise.
    if (isMatrixMulAddOp(m_data.testType))
    {
        dims[0].rows = "M";
        dims[0].cols = "K";
        dims[1].rows = "K";
        dims[1].cols = "N";
        dims[2].rows = "M";
        dims[2].cols = "N";
        dims[3].rows = "M";
        dims[3].cols = "N";
    }
    else
    {
        dims[0].rows = "M";
        dims[0].cols = "N";
        dims[1].rows = "M";
        dims[1].cols = "N";
        dims[2].rows = "M";
        dims[2].cols = "N";
        dims[3].rows = "M";
        dims[3].cols = "N";
    }

    const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
    const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
    const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
    const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;
    string inputType;
    string outputType;
    string divisorA;
    string divisorB;
    string divisorC;
    string divisorO;
    string *divisors[4] = {&divisorA, &divisorB, &divisorC, &divisorO};

    if (m_data.testType == TT_MULTICOMPONENT_LOAD)
    {
        const char *componentSuffix = m_data.inputComponentCount == 2 ? "vec2" :
                                      m_data.inputComponentCount == 4 ? "vec4" :
                                                                        "";

        inputType = string(1, componentTypeInfo[m_data.inputType].coopmatTypeName[0]) +
                    de::toString(componentTypeInfo[m_data.inputType].bits) + componentSuffix;

        typeStrA = inputType.c_str();
        typeStrB = inputType.c_str();
        divisorA = m_data.inputComponentCount == 2 ? "/2" : m_data.inputComponentCount == 4 ? "/4" : "";
        divisorB = divisorA;
    }

    if (m_data.testType == TT_MULTICOMPONENT_SAVE)
    {
        const char *componentSuffix = m_data.outputComponentCount == 2 ? "vec2" :
                                      m_data.outputComponentCount == 4 ? "vec4" :
                                                                         "";

        outputType = string(1, componentTypeInfo[m_data.outputType].coopmatTypeName[0]) +
                     de::toString(componentTypeInfo[m_data.outputType].bits) + componentSuffix;

        typeStrC = outputType.c_str();
        typeStrO = outputType.c_str();
        divisorC = m_data.outputComponentCount == 2 ? "/2" : m_data.outputComponentCount == 4 ? "/4" : "";
        divisorO = divisorC;
    }

    css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
    css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", "
        << m_data.subgroupsPerWorkgroupY << ");\n";

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
    {
        css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
        css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
        css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
        css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
        css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; "
               "} params;\n";
    }
    else
    {
        css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
        css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
        css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
        css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
    }

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols
            << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
        css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols
            << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
        css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols
            << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
        css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols
            << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
    }


    std::stringstream matAType, matBType, matCType, outputMatType;

    if (isKhr(m_data.useType))
    {
        const bool useSame = !isMatrixMulAddOp(m_data.testType);
        const char *sameType = m_data.useType == UT_KHR_A      ? "gl_MatrixUseA" :
                               m_data.useType == UT_KHR_B      ? "gl_MatrixUseB" :
                               m_data.useType == UT_KHR_Result ? "gl_MatrixUseAccumulator" :
                                                                 "Invalid use";
        const char *atype = useSame ? sameType : "gl_MatrixUseA";
        const char *btype = useSame ? sameType : "gl_MatrixUseB";
        const char *ctype = useSame ? sameType : "gl_MatrixUseAccumulator";
        const char *rtype = useSame ? sameType : "gl_MatrixUseAccumulator";

        matAType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[0].rows
                 << ", " << dims[0].cols << ", " << atype << ">";
        matBType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[1].rows
                 << ", " << dims[1].cols << ", " << btype << ">";
        matCType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, "
                 << dims[2].rows << ", " << dims[2].cols << ", " << ctype << ">";
        outputMatType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, "
                      << dims[3].rows << ", " << dims[3].cols << ", " << rtype << ">";
    }
    else
    {
        matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<"
                 << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", "
                 << dims[0].cols << ">";
        matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<"
                 << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", "
                 << dims[1].cols << ">";
        matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<"
                 << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", "
                 << dims[2].cols << ">";
        outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<"
                      << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", "
                      << dims[3].cols << ">";
    }

    css << matAType.str() << " matA;\n";
    css << matBType.str() << " matB;\n";
    css << matCType.str() << " matC;\n";
    css << outputMatType.str() << " matO;\n";

    if (m_data.testType == TT_CONSTANT)
        css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";

    if (m_data.testType == TT_FUNC)
        css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";

    css << "void main()\n"
           "{\n"
           // matrixID is the x,y index of the matrix owned by this subgroup.
           "   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
           "   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";

    if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
    {
        css << "   InputA inputA = params.inputA;\n";
        css << "   InputB inputB = params.inputB;\n";
        css << "   InputC inputC = params.inputC;\n";
        css << "   Output outputO = params.outputO;\n";
    }

    string strides[4];
    for (uint32_t i = 0; i < 4; ++i)
    {
        strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") +
                     de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
    }

    // element<i> is the starting element in buffer memory.
    // elementS<i> is the starting element in shared memory.
    css << "   uint element0 = (" << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows)
        << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x)" << divisorA
        << ";\n"
           "   uint element1 = ("
        << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + "
        << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x)" << divisorB
        << ";\n"
           "   uint element2 = ("
        << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + "
        << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x)" << divisorC
        << ";\n"
           "   uint element3 = ("
        << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + "
        << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x)" << divisorO
        << ";\n"
           "   uint elementS0, elementS1, elementS2, elementS3;\n";

    // For shared memory tests, copy the matrix from buffer memory into
    // workgroup memory. For simplicity, do it all on a single thread.
    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        const char *name[] = {
            "sharedA",
            "sharedB",
            "sharedC",
        };
        const char *inputName[] = {
            "inputA",
            "inputB",
            "inputC",
        };
        for (uint32_t m = 0; m < 4; ++m)
        {
            string sharedStride = strides[m] + " / workgroupsX";
            css << "   elementS" << m << " = (" << sharedStride << " * "
                << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + "
                << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x)" << *divisors[m] << ";\n";
        }
        css << "   if (subgroupElect()) {\n";
        // copy all three input buffers.
        for (uint32_t m = 0; m < 3; ++m)
        {
            string sharedStride = strides[m] + " / workgroupsX";
            css << "       for (int i = 0; i < " << dims[m].rows
                << "; ++i) {\n"
                   "       for (int j = 0; j < "
                << dims[m].cols
                << "; ++j) {\n"
                   "           int localElementInput = ("
                << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
                << *divisors[m]
                << ";\n"
                   "           int localElementShared = ("
                << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j")
                << ")" << *divisors[m]
                << ";\n"
                   "           "
                << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m
                << " + localElementInput];\n"
                   "       }\n"
                   "       }\n";
            strides[m] = sharedStride;
        }
        css << "   }\n";
        css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, "
               "gl_SemanticsAcquireRelease);\n";
    }

    const char *colMajorNV = (m_data.colMajor ? "true" : "false");
    const char *colMajorKHR =
        (m_data.colMajor ? "gl_CooperativeMatrixLayoutColumnMajor" : "gl_CooperativeMatrixLayoutRowMajor");
    const char *colMajor = (isKhr(m_data.useType) ? colMajorKHR : colMajorNV);

    string loadStrides[3] = {strides[0] + divisorA, strides[1] + divisorB, strides[2] + divisorC};
    // Load with a stride of 0
    if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
        loadStrides[0] = loadStrides[1] = loadStrides[2] = "0";

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        css << "   coopMatLoad" << suffix << "(matA, sharedA, elementS0, " << loadStrides[0] << ", " << colMajor
            << ");\n"
               "   coopMatLoad"
            << suffix << "(matB, sharedB, elementS1, " << loadStrides[1] << ", " << colMajor
            << ");\n"
               "   coopMatLoad"
            << suffix << "(matC, sharedC, elementS2, " << loadStrides[2] << ", " << colMajor << ");\n";
    }
    else
    {
        css << "   coopMatLoad" << suffix << "(matA, inputA.x, element0, " << loadStrides[0] << ", " << colMajor
            << ");\n"
               "   coopMatLoad"
            << suffix << "(matB, inputB.x, element1, " << loadStrides[1] << ", " << colMajor
            << ");\n"
               "   coopMatLoad"
            << suffix << "(matC, inputC.x, element2, " << loadStrides[2] << ", " << colMajor << ");\n";
    }

    if (m_data.testType == TT_COMPOSITE_ARRAY || m_data.testType == TT_MATRIXMULADD_ARRAY)
    {
        css << "   " << matAType.str() << " matAArr[2];\n   matAArr[1] = matA; matAArr[0] = " << matAType.str()
            << "(0.0);\n"
               "   "
            << matBType.str() << " matBArr[2];\n   matBArr[1] = matB; matBArr[0] = " << matBType.str()
            << "(0.0);\n"
               "   "
            << matCType.str() << " matCArr[2];\n   matCArr[1] = matC; matCArr[0] = " << matCType.str()
            << "(0.0);\n"
               "   "
            << outputMatType.str() << " matOArr[2];\n";
    }

    switch (m_data.testType)
    {
    default:
        DE_ASSERT(0);
        // fall through
    case TT_LENGTH:
        css << "   matO = " << outputMatType.str() << "(matO.length());\n";
        break;
    case TT_CONSTANT:
        css << "   matO = matConst;\n";
        break;
    case TT_CONVERT:
        css << "   matO = " << outputMatType.str() << "(matA);\n";
        break;
    case TT_COMPOSITE:
        css << "   " << matAType.str() << " t = " << matAType.str()
            << "(matB[0]);\n"
               "   for (int i = 1; i < matA.length(); ++i) {\n"
               "       matO[i] = matA[i] + matB[i];\n"
               "   }\n"
               "   if (matA.length() > 0)\n"
               "       matO[0] = matA[0] + t[0];\n";
        break;
    case TT_COMPOSITE_RVALUE:
        css << "   for (int i = 1; i < matA.length(); ++i) {\n"
               "       matO[i] = matA[i] + matB[i];\n"
               "   }\n"
               "   "
            << matAType.str()
            << " t = matA;\n"
               "   if (matA.length() > 0) {\n"
               "       matO[0] = (t += matB)[0];\n"
               "   }\n";
        break;
    case TT_COMPOSITE_ARRAY:
        css << "   for (int i = 0; i < matA.length(); ++i) {\n"
               "       matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
               "   }\n";
        break;
    case TT_ADD:
        css << "   matO = matA + matB;\n";
        break;
    case TT_SUB:
        css << "   matO = matA - matB;\n";
        break;
    case TT_DIV:
        css << "   matO = matA / matB;\n";
        break;
    case TT_MUL:
        css << "   matO = matA * matB;\n";
        break;
    case TT_NEGATE:
        css << "   matO = -matA;\n";
        break;
    case TT_FUNC:
        css << "   matO = f(matA);\n";
        break;
    case TT_MATRIXTIMESSCALAR:
        css << "   matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
        break;
    case TT_MATRIXMULADD_STRIDE0:
    case TT_MATRIXMULADD_WRAPPING:
    case TT_MATRIXMULADD_SATURATED:
    case TT_MATRIXMULADD:
        css << "   matO = coopMatMulAdd" << suffix << "(matA, matB, matC" << sat << ");\n";
        break;
    case TT_MATRIXMULADD_ARRAY:
        css << "   matOArr[1] = coopMatMulAdd" << suffix << "(matAArr[1], matBArr[1], matCArr[1]);\n";
        break;
    case TT_MULTICOMPONENT_LOAD:
        css << "   matO = matA;\n";
        break;
    case TT_MULTICOMPONENT_SAVE:
        css << "   matO = matA;\n";
        break;
    }

    if (m_data.testType == TT_COMPOSITE_ARRAY || m_data.testType == TT_MATRIXMULADD_ARRAY)
    {
        css << "   matOArr[0] = " << outputMatType.str() << "(0.0);\n";
        css << "   matO = matOArr[1];\n";
    }

    if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
    {
        string sharedStride = strides[3] + " / workgroupsX";
        css << "   coopMatStore" << suffix << "(matO, sharedO, elementS3, " << sharedStride << divisorO << ", "
            << colMajor << ");\n";
        css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, "
               "gl_SemanticsAcquireRelease);\n";
        css << "   if (subgroupElect()) {\n";
        css << "       for (int i = 0; i < " << dims[3].rows
            << "; ++i) {\n"
               "       for (int j = 0; j < "
            << dims[3].cols
            << "; ++j) {\n"
               "           int localElementInput = ("
            << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
            << *divisors[3]
            << ";\n"
               "           int localElementShared = ("
            << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
            << *divisors[3]
            << ";\n"
               "           outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
               "       }\n"
               "       }\n";
        css << "   }\n";
    }
    else
    {
        css << "   coopMatStore" << suffix << "(matO, outputO.x, element3, " << strides[3] << divisorO << ", "
            << colMajor << ");\n";
    }

    css << "}\n";

    const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);

    programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
}

TestInstance *CooperativeMatrixTestCase::createInstance(Context &context) const
{
    return new CooperativeMatrixTestInstance(context, m_data);
}

void setDataFloat(void *base, VkComponentTypeKHR dt, uint32_t i, float value)
{
    if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
    {
        ((float *)base)[i] = value;
    }
    else
    {
        DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
        ((tcu::float16_t *)base)[i] = tcu::Float16(value).bits();
    }
}

float getDataFloat(void *base, VkComponentTypeKHR dt, uint32_t i)
{
    if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
    {
        return ((float *)base)[i];
    }
    else
    {
        DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
        return tcu::Float16(((const tcu::float16_t *)base)[i]).asFloat();
    }
}

void setDataInt(void *base, VkComponentTypeKHR dt, uint32_t i, uint32_t value)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);

    switch (dt)
    {
    case VK_COMPONENT_TYPE_UINT8_KHR:
        ((uint8_t *)base)[i] = (uint8_t)value;
        break;
    case VK_COMPONENT_TYPE_UINT16_KHR:
        ((uint16_t *)base)[i] = (uint16_t)value;
        break;
    case VK_COMPONENT_TYPE_UINT32_KHR:
        ((uint32_t *)base)[i] = (uint32_t)value;
        break;
    case VK_COMPONENT_TYPE_SINT8_KHR:
        ((int8_t *)base)[i] = (int8_t)value;
        break;
    case VK_COMPONENT_TYPE_SINT16_KHR:
        ((int16_t *)base)[i] = (int16_t)value;
        break;
    case VK_COMPONENT_TYPE_SINT32_KHR:
        ((int32_t *)base)[i] = (int32_t)value;
        break;
    default:
        TCU_THROW(InternalError, "Unsupported type");
    }
}

uint32_t getDataInt(void *base, VkComponentTypeKHR dt, uint32_t i)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);

    switch (dt)
    {
    case VK_COMPONENT_TYPE_UINT8_KHR:
        return ((uint8_t *)base)[i];
    case VK_COMPONENT_TYPE_UINT16_KHR:
        return ((uint16_t *)base)[i];
    case VK_COMPONENT_TYPE_UINT32_KHR:
        return ((uint32_t *)base)[i];
    case VK_COMPONENT_TYPE_SINT8_KHR:
        return ((int8_t *)base)[i];
    case VK_COMPONENT_TYPE_SINT16_KHR:
        return ((int16_t *)base)[i];
    case VK_COMPONENT_TYPE_SINT32_KHR:
        return ((int32_t *)base)[i];
    default:
        TCU_THROW(InternalError, "Unsupported type");
    }
}

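// Reads element i as type dt, then converts it to T. Negative float inputs
// are clamped to zero first when T is unsigned, so the conversion is defined.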
template <typename T>
T getDataConvertedToT(void *base, VkComponentTypeKHR dt, uint32_t i)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);

    switch (dt)
    {
    case VK_COMPONENT_TYPE_UINT8_KHR:
        return (T)((uint8_t *)base)[i];
    case VK_COMPONENT_TYPE_UINT16_KHR:
        return (T)((uint16_t *)base)[i];
    case VK_COMPONENT_TYPE_UINT32_KHR:
        return (T)((uint32_t *)base)[i];
    case VK_COMPONENT_TYPE_SINT8_KHR:
        return (T)((int8_t *)base)[i];
    case VK_COMPONENT_TYPE_SINT16_KHR:
        return (T)((int16_t *)base)[i];
    case VK_COMPONENT_TYPE_SINT32_KHR:
        return (T)((int32_t *)base)[i];
    case VK_COMPONENT_TYPE_FLOAT32_KHR:
    {
        float temp = ((float *)base)[i];
        if (std::numeric_limits<T>::min() == 0)
            temp = std::max(temp, 0.0f);
        return (T)temp;
    }
    case VK_COMPONENT_TYPE_FLOAT16_KHR:
    {
        float temp = tcu::Float16(((tcu::float16_t *)base)[i]).asFloat();
        if (std::numeric_limits<T>::min() == 0)
            temp = std::max(temp, 0.0f);
        return (T)temp;
    }
    default:
        TCU_THROW(InternalError, "Unsupported type");
    }
}

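// Saturating addition for signed types, written to avoid relying on undefined
// signed overflow: clamp to the type's max/min whenever the exact sum would
// fall outside its range.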
template <typename T>
T satAdd(T a, T b)
{
    if (a > 0)
    {
        if (b > std::numeric_limits<T>::max() - a)
            return std::numeric_limits<T>::max();
    }
    else if (b < std::numeric_limits<T>::min() - a)
    {
        return std::numeric_limits<T>::min();
    }

    return (T)(a + b);
}

uint32_t satAddData(VkComponentTypeKHR dt, uint32_t a, uint32_t b)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);

    switch (dt)
    {
    case VK_COMPONENT_TYPE_UINT8_KHR:
        return deMinu32(a + b, std::numeric_limits<uint8_t>::max());
    case VK_COMPONENT_TYPE_UINT16_KHR:
        return deMinu32(a + b, std::numeric_limits<uint16_t>::max());
    case VK_COMPONENT_TYPE_UINT32_KHR:
        return (a + b >= a) ? a + b : std::numeric_limits<uint32_t>::max();
    case VK_COMPONENT_TYPE_SINT8_KHR:
        return (uint32_t)satAdd((int8_t)a, (int8_t)b);
    case VK_COMPONENT_TYPE_SINT16_KHR:
        return (uint32_t)satAdd((int16_t)a, (int16_t)b);
    case VK_COMPONENT_TYPE_SINT32_KHR:
        return (uint32_t)satAdd((int32_t)a, (int32_t)b);
    default:
        TCU_THROW(InternalError, "Unsupported type");
    }
}

uint32_t getLimit(VkComponentTypeKHR dt, bool positive)
{
    DE_ASSERT(componentTypeInfo[dt].bits <= 32);

    switch (dt)
    {
    case VK_COMPONENT_TYPE_UINT8_KHR:
        return uint32_t(positive ? std::numeric_limits<uint8_t>::max() : std::numeric_limits<uint8_t>::min());
    case VK_COMPONENT_TYPE_UINT16_KHR:
        return uint32_t(positive ? std::numeric_limits<uint16_t>::max() : std::numeric_limits<uint16_t>::min());
    case VK_COMPONENT_TYPE_UINT32_KHR:
        return uint32_t(positive ? std::numeric_limits<uint32_t>::max() : std::numeric_limits<uint32_t>::min());
    case VK_COMPONENT_TYPE_SINT8_KHR:
        return uint32_t(positive ? std::numeric_limits<int8_t>::max() : std::numeric_limits<int8_t>::min());
    case VK_COMPONENT_TYPE_SINT16_KHR:
        return uint32_t(positive ? std::numeric_limits<int16_t>::max() : std::numeric_limits<int16_t>::min());
    case VK_COMPONENT_TYPE_SINT32_KHR:
        return uint32_t(positive ? std::numeric_limits<int32_t>::max() : std::numeric_limits<int32_t>::min());
    default:
        TCU_THROW(InternalError, "Unsupported type");
    }
}

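// Writes a strided run of count elements starting at start: the element at
// index at is set to val, every other element to zero. Used to build the
// one-hot rows/columns for the saturation tests.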
void setSingleElementInt(void *data, VkComponentTypeKHR dt, uint32_t start, uint32_t count, uint32_t step, uint32_t at,
                         uint32_t val)
{
    for (uint32_t i = 0; i < count; i++)
        setDataInt(data, dt, start + i * step, (i == at) ? val : 0);
}

#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
string dumpWholeMatrix(void *data, VkComponentTypeKHR dt, bool colMajor, uint32_t matrixElemCount, uint32_t stride)
{
    const uint32_t rowsCount = colMajor ? stride : matrixElemCount / stride;
    const uint32_t colsCount = colMajor ? matrixElemCount / stride : stride;
    bool floatType = isFloatType(dt);
    bool sIntType = isSIntType(dt);
    std::stringstream ss;

    DE_ASSERT(rowsCount * colsCount == matrixElemCount);

    for (uint32_t r = 0; r < rowsCount; r++)
    {
        for (uint32_t c = 0; c < colsCount; c++)
        {
            const uint32_t i = colMajor ? rowsCount * c + r : colsCount * r + c;

            if (floatType)
                ss << getDataFloat(data, dt, i) << "\t";
            else if (sIntType)
                ss << (int32_t)getDataInt(data, dt, i) << "\t";
            else
                ss << getDataInt(data, dt, i) << "\t";
        }

        ss << std::endl;
    }

    return ss.str();
}
#endif

tcu::TestStatus CooperativeMatrixTestInstance::iterate(void)
{
    const DeviceInterface &vk = m_context.getDeviceInterface();
    const VkDevice device = m_context.getDevice();
    Allocator &allocator = m_context.getDefaultAllocator();
    MemoryRequirement memoryDeviceAddress =
        m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
                m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ?
            MemoryRequirement::DeviceAddress :
            MemoryRequirement::Any;
    qpTestResult finalres = QP_TEST_RESULT_NOT_SUPPORTED;
    tcu::TestLog &log = m_context.getTestContext().getLog();
    const bool saturated = (m_data.testType == TT_MATRIXMULADD_SATURATED);
    const uint32_t subgroupSize = getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode);
    const float epsilon = 1.0f / float(1ull << 17); // 1/131072, approximately 7.6e-6
    vk::VkPhysicalDeviceProperties vkproperties;

    m_context.getInstanceInterface().getPhysicalDeviceProperties(m_context.getPhysicalDevice(), &vkproperties);

    deRandom rnd;
    deRandom_init(&rnd, 1234);

    std::vector<VkCooperativeMatrixPropertiesKHR> properties =
        getCooperativeMatrixPropertiesConverted(m_context, isKhr(m_data.useType));

    struct TestTuple
    {
        TestTuple()
        {
        }
        TestTuple(uint32_t m, uint32_t n, uint32_t k) : M(m), N(n), K(k)
        {
        }

        // Lexicographic ordering on (M, N, K) so tuples can live in a std::set.
        bool operator<(const TestTuple &other) const
        {
            return M < other.M || (M == other.M && N < other.N) || (M == other.M && N == other.N && K < other.K);
        }

        uint32_t M, N, K;
    };

    vector<TestTuple> testSizes;

    if (isMatrixMulAddOp(m_data.testType))
    {
        for (size_t i = 0; i < properties.size(); ++i)
        {
            VkCooperativeMatrixPropertiesKHR *p = &properties[i];

            if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
                p->ResultType == m_data.outputType && p->scope == VK_SCOPE_SUBGROUP_KHR)
            {
                testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
            }
        }
    }
    else
    {
        set<TestTuple> typeSizes[2];
        VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
        const bool aType = (m_data.useType == UT_KHR_A) || (m_data.useType == UT_NV);
        const bool bType = (m_data.useType == UT_KHR_B) || (m_data.useType == UT_NV);
        const bool rType = (m_data.useType == UT_KHR_Result) || (m_data.useType == UT_NV);

        for (uint32_t i = 0; i < properties.size(); ++i)
        {
            VkCooperativeMatrixPropertiesKHR *p = &properties[i];

            if (p->scope != VK_SCOPE_SUBGROUP_KHR)
                continue;

            for (uint32_t j = 0; j < 2; ++j)
            {
                // For these tests, m_data.M/N are always the matrix size. Check if they match
                // any input or output in the list.
                if (aType && p->AType == types[j])
                    typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
                if (bType && p->BType == types[j])
                    typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
                if (rType && (p->CType == types[j] || p->ResultType == types[j]))
                    typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
            }
        }
        // Test those sizes that are supported for both the input and output type.
        std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(), typeSizes[1].begin(), typeSizes[1].end(),
                              std::back_inserter(testSizes));
    }

    properties.resize(0);

    for (unsigned int s = 0; s < testSizes.size(); ++s)
    {
        // When testing a multiply, MxNxK is the type of matrix multiply.
        // Otherwise, MxN is the size of the input/output matrices.
        uint32_t M, N, K;
        M = testSizes[s].M;
        N = testSizes[s].N;
        K = testSizes[s].K;

        log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K
            << tcu::TestLog::EndMessage;

        struct
        {
            uint32_t rows, cols;
        } dims[4];

        if (isMatrixMulAddOp(m_data.testType))
        {
            dims[0].rows = M;
            dims[0].cols = K;
            dims[1].rows = K;
            dims[1].cols = N;
            dims[2].rows = M;
            dims[2].cols = N;
            dims[3].rows = M;
            dims[3].cols = N;
        }
        else
        {
            dims[0].rows = M;
            dims[0].cols = N;
            dims[1].rows = M;
            dims[1].cols = N;
            dims[2].rows = M;
            dims[2].cols = N;
            dims[3].rows = M;
            dims[3].cols = N;
        }

        VkComponentTypeKHR dataTypes[4];
        size_t elementSize[4];
        VkDeviceSize bufferSizes[5];
        de::MovePtr<BufferWithMemory> buffers[5];
        vk::VkDescriptorBufferInfo bufferDescriptors[5];
        uint32_t strides[4]; // in elements
        uint32_t loadStrides[4];
        uint32_t totalElements[4];
        size_t sharedMemoryUsage[4];
        size_t totalSharedMemoryUsage = 0;

        for (uint32_t i = 0; i < 5; ++i)
        {
            if (i < 4)
            {
                // A/B use input type, C/D use output type
                dataTypes[i] = (i < 2) ? m_data.inputType : m_data.outputType;
                elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;

                strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX *
                             m_data.workgroupsX;
                loadStrides[i] = strides[i];
                totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) *
                                   m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;
                sharedMemoryUsage[i] = dims[i].cols * dims[i].rows * m_data.subgroupsPerWorkgroupX *
                                       m_data.subgroupsPerWorkgroupY * elementSize[i] *
                                       ((i < 2) ? m_data.inputComponentCount : m_data.outputComponentCount);

                bufferSizes[i] = totalElements[i] * elementSize[i];

                // Check there is enough shared memory supported
                if ((m_data.useType != UT_NV) &&
                    ((m_data.storageClass == SC_WORKGROUP) || (m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)))
                {
                    totalSharedMemoryUsage += sharedMemoryUsage[i];
                    if (totalSharedMemoryUsage > vkproperties.limits.maxComputeSharedMemorySize)
                        throw tcu::NotSupportedError("Not enough shared memory supported.");
                }
            }
            else
            {
                bufferSizes[4] = sizeof(VkDeviceAddress) * 4;
            }

            try
            {
                buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                    vk, device, allocator,
                    makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
                                                             (memoryDeviceAddress == MemoryRequirement::DeviceAddress ?
                                                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT :
                                                                  0)),
                    MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent |
                        memoryDeviceAddress));
            }
            catch (const tcu::NotSupportedError &)
            {
                buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
                    vk, device, allocator,
                    makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_DST_BIT |
                                                             VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
                                                             (memoryDeviceAddress == MemoryRequirement::DeviceAddress ?
                                                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT :
                                                                  0)),
                    MemoryRequirement::HostVisible | memoryDeviceAddress));
            }

            bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
        }

        // Load with a stride of 0
        if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
            loadStrides[0] = loadStrides[1] = loadStrides[2] = loadStrides[3] = 0;

        void *ptrs[5];
        for (uint32_t i = 0; i < 5; ++i)
        {
            ptrs[i] = buffers[i]->getAllocation().getHostPtr();
        }

        vk::DescriptorSetLayoutBuilder layoutBuilder;

        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
        layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);

        vk::Unique<vk::VkDescriptorSetLayout> descriptorSetLayout(layoutBuilder.build(vk, device));

        vk::Unique<vk::VkDescriptorPool> descriptorPool(
            vk::DescriptorPoolBuilder()
                .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
                .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
        vk::Unique<vk::VkDescriptorSet> descriptorSet(
            makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));

        vk::DescriptorSetUpdateBuilder setUpdateBuilder;
        if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
        {
            VkBufferDeviceAddressInfo info{
                VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType  sType;
                DE_NULL,                                      // const void*  pNext;
                0,                                            // VkBuffer  buffer
            };
            VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
            for (uint32_t i = 0; i < 4; ++i)
            {
                info.buffer = **buffers[i];
                VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
                addrsInMemory[i] = addr;
            }
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
        }
        else
        {
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
            setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
                                         VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
        }

        setUpdateBuilder.update(vk, device);

        VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;

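        // Specialization constants in constant_id order: 0/1 set the local
        // workgroup size (x = subgroupSize * subgroupsPerWorkgroupX, y =
        // subgroupsPerWorkgroupY), 2-5 are the A/B/C/O strides, and 6-8 are
        // M/N/K, matching the layout(constant_id = ...) declarations in the
        // generated shader.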
        const uint32_t specData[9] = {
            subgroupSize * m_data.subgroupsPerWorkgroupX,
            m_data.subgroupsPerWorkgroupY,
            strides[0],
            strides[1],
            strides[2],
            strides[3],
            M,
            N,
            K,
        };

        const vk::VkSpecializationMapEntry entries[9] = {
            {0, (uint32_t)(sizeof(uint32_t) * 0), sizeof(uint32_t)},
            {1, (uint32_t)(sizeof(uint32_t) * 1), sizeof(uint32_t)},
            {2, (uint32_t)(sizeof(uint32_t) * 2), sizeof(uint32_t)},
            {3, (uint32_t)(sizeof(uint32_t) * 3), sizeof(uint32_t)},
            {4, (uint32_t)(sizeof(uint32_t) * 4), sizeof(uint32_t)},
            {5, (uint32_t)(sizeof(uint32_t) * 5), sizeof(uint32_t)},
            {6, (uint32_t)(sizeof(uint32_t) * 6), sizeof(uint32_t)},
            {7, (uint32_t)(sizeof(uint32_t) * 7), sizeof(uint32_t)},
            {8, (uint32_t)(sizeof(uint32_t) * 8), sizeof(uint32_t)},
        };

        const vk::VkSpecializationInfo specInfo = {
            9,                // mapEntryCount
            entries,          // pMapEntries
            sizeof(specData), // dataSize
            specData          // pData
        };

        for (uint32_t i = 0; i < 4; ++i)
            for (uint32_t j = 0; j < totalElements[i]; ++j)
            {
                if (isFloatType(dataTypes[i]))
                {
                    if (!isMatrixMulAddOp(m_data.testType))
                        setDataFloat(ptrs[i], dataTypes[i], j,
                                     ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f) / 2.0f);
                    else
                        setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f) / 2.0f);
                }
                else
                {
                    if (m_data.testType == TT_MATRIXMULADD_WRAPPING)
                    {
                        // Choose matrix values that should cause overflow and underflow, to
                        // verify wrapping behavior. Use the full range of values for A and B.
                        // For matrix C, use values clustered near where the type wraps (zero
                        // for unsigned, 2^(N-1) for signed).
                        uint32_t bits = componentTypeInfo[dataTypes[i]].bits;
                        uint32_t value;
                        if (i == 2)
                        {
                            value = (deRandom_getUint32(&rnd) & 0xff) - 128;
                            if (componentTypeInfo[dataTypes[i]].isSigned)
                                value += (1U << (bits - 1));
                        }
                        else
                        {
                            uint32_t mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
                            value = deRandom_getUint32(&rnd) & mask;
                        }
                        setDataInt(ptrs[i], dataTypes[i], j, value);
                    }
                    else if (m_data.testType == TT_MATRIXMULADD_SATURATED)
                    {
                        setDataInt(ptrs[i], dataTypes[i], j, 0);
                    }
                    else
                    {
                        uint32_t value = (deRandom_getUint32(&rnd) & 0xff) - 128;
                        setDataInt(ptrs[i], dataTypes[i], j, value);
                    }
                }
            }

        if (m_data.testType == TT_MATRIXMULADD_SATURATED)
        {
            // Set 1st row of A to 1,0,0...
            setSingleElementInt(ptrs[0], dataTypes[0], 0, dims[0].cols, (m_data.colMajor ? strides[0] : 1), 0, 1);

            // Set 1st column of B to 1,0,0...
            setSingleElementInt(ptrs[1], dataTypes[1], 0, dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 0, 1);

            // Set C element at {0,0} to the maximum value of its type, so the addition in D = A*B + C overflows for this element.
            setDataInt(ptrs[2], dataTypes[2], 0, getLimit(dataTypes[2], true));

            // Check underflow if all involved types support negative values.
            if (isSIntType(dataTypes[1]) && isSIntType(dataTypes[2]) && isSIntType(dataTypes[3]))
            {
                // Set 2nd row of A to 0,1,0,0...
                setSingleElementInt(ptrs[0], dataTypes[0], (m_data.colMajor ? 1 : strides[0]), dims[0].cols,
                                    (m_data.colMajor ? strides[0] : 1), 1, 1);

                // Set 2nd column of B to 0,-1,0,0...
                setSingleElementInt(ptrs[1], dataTypes[1], (m_data.colMajor ? strides[1] : 1), dims[1].rows,
                                    (m_data.colMajor ? 1 : strides[1]), 1, -1);

                // Set C element at {1,1} to the minimum value of its type, so the addition in D = A*B + C underflows for this element.
                setDataInt(ptrs[2], dataTypes[2], strides[2] + 1, getLimit(dataTypes[2], false));
            }
        }

        flushAlloc(vk, device, buffers[0]->getAllocation());
        flushAlloc(vk, device, buffers[1]->getAllocation());
        flushAlloc(vk, device, buffers[2]->getAllocation());
        flushAlloc(vk, device, buffers[3]->getAllocation());

        ComputePipelineWrapper pipeline(vk, device, m_data.computePipelineConstructionType,
                                        m_context.getBinaryCollection().get("test"));
        pipeline.setDescriptorSetLayout(descriptorSetLayout.get());
        pipeline.setSpecializationInfo(specInfo);
        pipeline.setSubgroupSize(m_data.subgroupSizeMode == SUBGROUP_SIZE_NONE ?
                                     0 :
                                     getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode));
        pipeline.buildPipeline();

        const VkQueue queue = m_context.getUniversalQueue();
        Move<VkCommandPool> cmdPool = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
        Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);

        beginCommandBuffer(vk, *cmdBuffer, 0u);

        vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, pipeline.getPipelineLayout(), 0u, 1, &*descriptorSet, 0u,
                                 DE_NULL);
        pipeline.bind(*cmdBuffer);

        vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);

        const VkMemoryBarrier barrier = {
            VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType
            nullptr,                          // pNext
            VK_ACCESS_SHADER_WRITE_BIT,       // srcAccessMask
            VK_ACCESS_HOST_READ_BIT,          // dstAccessMask
        };
        vk.cmdPipelineBarrier(*cmdBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
                              (VkDependencyFlags)0, 1, &barrier, 0, nullptr, 0, nullptr);

        endCommandBuffer(vk, *cmdBuffer);

        submitCommandsAndWait(vk, device, queue, cmdBuffer.get());

        invalidateAlloc(vk, device, buffers[3]->getAllocation());

1552 qpTestResult res = QP_TEST_RESULT_PASS;
1553
1554 if (m_data.testType == TT_CONVERT)
1555 {
1556 for (uint32_t i = 0; i < totalElements[3]; ++i)
1557 {
1558 // Store results as double, which has enough range to hold all the other types exactly.
1559 double inputA, output;
1560
1561 // This loads the data according to dataTypes[0], and then converts to the template parameter type
1562 switch (dataTypes[3])
1563 {
1564 case VK_COMPONENT_TYPE_UINT8_KHR:
1565 inputA = getDataConvertedToT<uint8_t>(ptrs[0], dataTypes[0], i);
1566 break;
1567 case VK_COMPONENT_TYPE_UINT16_KHR:
1568 inputA = getDataConvertedToT<uint16_t>(ptrs[0], dataTypes[0], i);
1569 break;
1570 case VK_COMPONENT_TYPE_UINT32_KHR:
1571 inputA = getDataConvertedToT<uint32_t>(ptrs[0], dataTypes[0], i);
1572 break;
1573 case VK_COMPONENT_TYPE_SINT8_KHR:
1574 inputA = getDataConvertedToT<int8_t>(ptrs[0], dataTypes[0], i);
1575 break;
1576 case VK_COMPONENT_TYPE_SINT16_KHR:
1577 inputA = getDataConvertedToT<int16_t>(ptrs[0], dataTypes[0], i);
1578 break;
1579 case VK_COMPONENT_TYPE_SINT32_KHR:
1580 inputA = getDataConvertedToT<int32_t>(ptrs[0], dataTypes[0], i);
1581 break;
1582 case VK_COMPONENT_TYPE_FLOAT32_KHR:
1583 inputA = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
1584 break;
1585 case VK_COMPONENT_TYPE_FLOAT16_KHR:
1586 {
1587 float temp = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
1588 inputA = tcu::Float16(temp).asDouble();
1589 break;
1590 }
1591 default:
1592 TCU_THROW(InternalError, "Unexpected type");
1593 }
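                // Read back the shader's output through the same destination-type conversion, so the
                // reference and the result are rounded identically before comparison.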

                switch (dataTypes[3])
                {
                case VK_COMPONENT_TYPE_UINT8_KHR:
                    output = getDataConvertedToT<uint8_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_UINT16_KHR:
                    output = getDataConvertedToT<uint16_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_UINT32_KHR:
                    output = getDataConvertedToT<uint32_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_SINT8_KHR:
                    output = getDataConvertedToT<int8_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_SINT16_KHR:
                    output = getDataConvertedToT<int16_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_SINT32_KHR:
                    output = getDataConvertedToT<int32_t>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_FLOAT32_KHR:
                    output = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
                    break;
                case VK_COMPONENT_TYPE_FLOAT16_KHR:
                {
                    float temp = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
                    output = tcu::Float16(temp).asDouble();
                    break;
                }
                default:
                    TCU_THROW(InternalError, "Unexpected type");
                }

                if (inputA != output)
                {
                    res = QP_TEST_RESULT_FAIL;
                    break;
                }
            }
        }
        else if (isFloatType(dataTypes[0]))
        {
            if (!isMatrixMulAddOp(m_data.testType))
            {
                for (uint32_t i = 0; i < totalElements[3]; ++i)
                {
                    float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
                    float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
                    float output = getDataFloat(ptrs[3], dataTypes[3], i);
                    switch (m_data.testType)
                    {
                    case TT_LENGTH:
                        if (output < 1.0f || output > (float)(N * M))
                            res = QP_TEST_RESULT_FAIL;
                        // We expect the matrix to be spread evenly across invocations; it is
                        // surprising (though not necessarily illegal) if it is not.
                        if (output != (float)(N * M / subgroupSize) && res == QP_TEST_RESULT_PASS)
                            res = QP_TEST_RESULT_QUALITY_WARNING;
                        break;
                    case TT_CONSTANT:
                        if (output != 1.0f)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_COMPOSITE:
                    case TT_COMPOSITE_RVALUE:
                    case TT_COMPOSITE_ARRAY:
                    case TT_ADD:
                        if (output != inputA + inputB)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_SUB:
                        if (output != inputA - inputB)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_DIV:
                    {
                        float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR) ?
                                        1.0f / 1024.0f :
                                        1.0f / (8.0f * 1024.0f * 1024.0f);
                        // Division is allowed 2.5 ULP of error, but we use a slightly looser 3 ULP bound.
                        ulp *= 3;
                        if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
                            res = QP_TEST_RESULT_FAIL;
                    }
                    break;
                    case TT_MUL:
                    {
                        if (dataTypes[0] == VK_COMPONENT_TYPE_FLOAT16_KHR)
                        {
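                            // The shader multiplies at fp16 precision, so compute the reference at fp32
                            // and round it through fp16 (round-to-nearest) to match.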
                            const float expected32 = inputA * inputB;
                            const tcu::float16_t expected16 = tcu::Float16(expected32).bits();
                            const float expected = tcu::Float16(expected16).asFloat();

                            if (output != expected)
                                res = QP_TEST_RESULT_FAIL;
                        }
                        else
                        {
                            if (output != inputA * inputB)
                                res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    }
                    case TT_NEGATE:
                    case TT_FUNC:
                        if (output != -inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MATRIXTIMESSCALAR:
                        if (output != 6.0 * inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MULTICOMPONENT_LOAD:
                    {
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    }
                    case TT_MULTICOMPONENT_SAVE:
                    {
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    }
                    default:
                        TCU_THROW(InternalError, "Unimplemented");
                    }
                }
            }
            else
            {
                uint32_t ik, kj, ij;
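                // Each subgroup computes one MxN tile of D, accumulating over K. The flattened
                // index of an element combines its row/column within the tile with the tile
                // offsets scaled by the buffer stride. loadStrides normally equals strides,
                // but is zero for the stride==0 variant, so every row (or column) of a tile
                // then loads the same data.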
                for (uint32_t mX = 0; mX < m_data.subgroupsPerWorkgroupX * m_data.workgroupsX; ++mX)
                {
                    for (uint32_t mY = 0; mY < m_data.subgroupsPerWorkgroupY * m_data.workgroupsY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                float ref = 0;
                                for (uint32_t k = 0; k < K; ++k)
                                {
                                    if (m_data.colMajor)
                                        ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
                                    else
                                        ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;

                                    float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);

                                    if (m_data.colMajor)
                                        kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
                                    else
                                        kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;

                                    float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);

                                    ref += Aik * Bkj;
                                }

                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
                                else
                                    ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;

                                float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);

                                ref += Cij;

                                // When loading with stride 0, the index of the D element differs from that of C
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * (mY * N + j);
                                else
                                    ij = mX * N + j + strides[2] * (mY * M + i);

                                float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);

                                if (fabs(ref - Dij) > epsilon)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
        }
        else
        {
            if (!isMatrixMulAddOp(m_data.testType))
            {
                for (uint32_t i = 0; i < totalElements[3]; ++i)
                {
                    uint32_t inputA = getDataInt(ptrs[0], dataTypes[0], i);
                    uint32_t inputB = getDataInt(ptrs[1], dataTypes[1], i);
                    uint32_t output = getDataInt(ptrs[3], dataTypes[3], i);
                    int resultSize = componentTypeInfo[dataTypes[3]].bits;
                    uint32_t mask = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
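                    // Results narrower than 32 bits are compared modulo 2^resultSize, since the
                    // reference math below is done at 32-bit width.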
                    switch (m_data.testType)
                    {
                    case TT_LENGTH:
                        if (output < 1 || output > N * M)
                            res = QP_TEST_RESULT_FAIL;
                        // We expect the matrix to be spread evenly across invocations; it is
                        // surprising (though not necessarily illegal) if it is not.
                        if (output != N * M / subgroupSize && res == QP_TEST_RESULT_PASS)
                            res = QP_TEST_RESULT_QUALITY_WARNING;
                        break;
                    case TT_CONSTANT:
                        if (output != 1)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_COMPOSITE:
                    case TT_COMPOSITE_RVALUE:
                    case TT_COMPOSITE_ARRAY:
                    case TT_ADD:
                        if ((output & mask) != ((inputA + inputB) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_SUB:
                        if ((output & mask) != ((inputA - inputB) & mask))
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_DIV:
                    {
                        if (isSIntType(dataTypes[3]))
                        {
                            if (inputB != 0 && ((int32_t)output & mask) != (((int32_t)inputA / (int32_t)inputB) & mask))
                                res = QP_TEST_RESULT_FAIL;
                        }
                        else
                        {
                            if (inputB != 0 && output != inputA / inputB)
                                res = QP_TEST_RESULT_FAIL;
                        }
                    }
                    break;
                    case TT_MUL:
                    {
                        if (((int32_t)output & mask) != (((int32_t)inputA * (int32_t)inputB) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }

                        break;
                    }
                    case TT_NEGATE:
                    case TT_FUNC:
                        if ((output & mask) != ((-(int32_t)inputA) & mask))
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    case TT_MATRIXTIMESSCALAR:
                        if ((output & mask) != ((6 * inputA) & mask))
                        {
                            res = QP_TEST_RESULT_FAIL;
                        }
                        break;
                    case TT_MULTICOMPONENT_LOAD:
                    {
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    }
                    case TT_MULTICOMPONENT_SAVE:
                    {
                        if (output != inputA)
                            res = QP_TEST_RESULT_FAIL;
                        break;
                    }
                    default:
                        TCU_THROW(InternalError, "Unimplemented");
                    }
                }
            }
            else
            {
                uint32_t ik, kj, ij;
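                // Same per-tile indexing as the floating-point reference loop above, with integer
                // accumulation and optional saturation handled below.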
                for (uint32_t mX = 0; mX < m_data.subgroupsPerWorkgroupX * m_data.workgroupsX; ++mX)
                {
                    for (uint32_t mY = 0; mY < m_data.subgroupsPerWorkgroupY * m_data.workgroupsY; ++mY)
                    {
                        for (uint32_t i = 0; i < M; ++i)
                        {
                            for (uint32_t j = 0; j < N; ++j)
                            {
                                uint32_t ref = 0;

                                for (uint32_t k = 0; k < K; ++k)
                                {
                                    if (m_data.colMajor)
                                        ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
                                    else
                                        ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;

                                    uint32_t Aik = getDataInt(ptrs[0], dataTypes[0], ik);

                                    if (m_data.colMajor)
                                        kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
                                    else
                                        kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;

                                    uint32_t Bkj = getDataInt(ptrs[1], dataTypes[1], kj);

                                    ref += Aik * Bkj;
                                }

                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
                                else
                                    ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;

                                uint32_t Cij = getDataInt(ptrs[2], dataTypes[2], ij);

                                if (saturated)
                                {
                                    ref = satAddData(dataTypes[2], ref, Cij);
                                }
                                else
                                {
                                    ref += Cij;
                                    // Truncate the result to the size of the output (C/D) component type.
                                    uint32_t bits = componentTypeInfo[dataTypes[3]].bits;
                                    uint32_t mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
                                    ref &= mask;
                                }
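                                // e.g. a uint8 result computed as 200 + 100 wraps to (300 & 0xFF) == 44.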

                                // When loading with stride 0, the index of the D element differs from that of C
                                if (m_data.colMajor)
                                    ij = mX * M + i + strides[2] * (mY * N + j);
                                else
                                    ij = mX * N + j + strides[2] * (mY * M + i);

                                uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);

                                if (ref != Dij)
                                {
                                    res = QP_TEST_RESULT_FAIL;
                                }
                            }
                        }
                    }
                }
            }
        }

        if (res != QP_TEST_RESULT_PASS)
        {
            finalres = res;

            log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K
                << tcu::TestLog::EndMessage;

#ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
            for (int i = 0; i < 4; i++)
            {
                const char *matrixNames[] = {"A", "B", "C", "D"};

                log << tcu::TestLog::Message << "Matrix " << matrixNames[i]
                    << "[rows=" << m_data.subgroupsPerWorkgroupY * m_data.workgroupsY * dims[i].rows
                    << ", cols=" << m_data.subgroupsPerWorkgroupX * m_data.workgroupsX * dims[i].cols << "]:\n"
                    << dumpWholeMatrix(ptrs[i], dataTypes[i], m_data.colMajor, totalElements[i], strides[i])
                    << tcu::TestLog::EndMessage;
            }
#endif
        }
        else
        {
            if (finalres == QP_TEST_RESULT_NOT_SUPPORTED)
                finalres = res;
        }
    }

    return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
}

const char *getUseType(UseType useType)
{
    switch (useType)
    {
    case UT_NV:
        return "nv";
    case UT_KHR_A:
        return "khr_a";
    case UT_KHR_B:
        return "khr_b";
    case UT_KHR_Result:
        return "khr_r";
    default:
        TCU_THROW(InternalError, "Unknown use type");
    }
}

tcu::TestCaseGroup *createCooperativeMatrixTestsInternal(
    tcu::TestContext &testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType, UseType useType)
{
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, getUseType(useType)));

    typedef struct
    {
        uint32_t value;
        const char *name;
    } TestGroupCase;

    typedef struct
    {
        uint32_t value[2];
        const char *name;
    } TestGroupCase2;

    typedef struct
    {
        SubgroupSizeMode value;
        const char *name;
    } SubGroupSizes;

    typedef struct
    {
        const char *name;
        const char *description;
        uint32_t componentCount;
    } MulticomponentTypes;

    typedef struct
    {
        const char *name;
        const char *description;
        TestType testType;
    } IOTypes;

    TestGroupCase ttCases[] = {
        // OpCooperativeMatrixLength
        {TT_LENGTH, "length"},
        // OpConstantComposite
        {TT_CONSTANT, "constant"},
        // OpCompositeConstruct
        {TT_COMPOSITE, "composite"},
        // OpCompositeExtract
        {TT_COMPOSITE_RVALUE, "composite_rvalue"},
        // OpFAdd/OpIAdd
        {TT_ADD, "add"},
        // OpFSub/OpISub
        {TT_SUB, "sub"},
        // OpFDiv/OpSDiv/OpUDiv
        {TT_DIV, "div"},
        // OpFMul/OpIMul
        {TT_MUL, "mul"},
        // OpFNegate/OpSNegate
        {TT_NEGATE, "negate"},
        // OpMatrixTimesScalar
        {TT_MATRIXTIMESSCALAR, "matrixtimesscalar"},
        // OpFunctionParameter
        {TT_FUNC, "func"},
        // OpCooperativeMatrixMulAdd
        {TT_MATRIXMULADD, "matrixmuladd"},
        // OpCompositeConstruct w/array
        {TT_COMPOSITE_ARRAY, "composite_array"},
        // OpCooperativeMatrixMulAdd w/array
        {TT_MATRIXMULADD_ARRAY, "matrixmuladd_array"},
        // OpCooperativeMatrixMulAdd w/saturation
        {TT_MATRIXMULADD_SATURATED, "matrixmuladd_saturated"},
        // OpCooperativeMatrixMulAdd w/wrapping
        {TT_MATRIXMULADD_WRAPPING, "matrixmuladd_wrapping"},
        // OpCooperativeMatrixMulAdd w/stride==0
        {TT_MATRIXMULADD_STRIDE0, "matrixmuladd_stride0"},
    };
    TestGroupCase2 dtCases[] = {
        // A/B are fp32 C/D are fp32
        {{VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR}, "float32_float32"},
        // A/B are fp32 C/D are fp16
        {{VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR}, "float32_float16"},
        // A/B are fp16 C/D are fp32
        {{VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR}, "float16_float32"},
        // A/B are fp16 C/D are fp16
        {{VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR}, "float16_float16"},
        // A/B are u8 C/D are u8
        {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT8_KHR}, "uint8_uint8"},
        // A/B are u8 C/D are u32
        {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT32_KHR}, "uint8_uint32"},
        // A/B are s8 C/D are s8
        {{VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT8_KHR}, "sint8_sint8"},
        // A/B are s8 C/D are s32
        {{VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "sint8_sint32"},
        // A/B are u8 C/D are s32
        {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "uint8_sint32"},
        // A/B are u32 C/D are u32
        {{VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT32_KHR}, "uint32_uint32"},
        // A/B are u32 C/D are u8
        {{VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT8_KHR}, "uint32_uint8"},
        // A/B are s32 C/D are s32
        {{VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "sint32_sint32"},
        // A/B are s32 C/D are s8
        {{VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR}, "sint32_sint8"},
    };
    SubGroupSizes sgsCases[] = {
        // Default subgroup size
        {SUBGROUP_SIZE_NONE, ""},
        // Minimum subgroup size
        {SUBGROUP_SIZE_MIN, "_min"},
        // Maximum subgroup size
        {SUBGROUP_SIZE_MAX, "_max"},
    };

    TestGroupCase colCases[] = {
        {0, "rowmajor"},
        {1, "colmajor"},
    };

    TestGroupCase scCases[] = {
        // SSBO
        {SC_BUFFER, "buffer"},
        // shared memory
        {SC_WORKGROUP, "workgroup"},
        // SSBO w/variable pointers
        {SC_BUFFER_VARIABLE_POINTERS, "buffer_varptr"},
        // shared memory w/variable pointers
        {SC_WORKGROUP_VARIABLE_POINTERS, "workgroup_varptr"},
        // physical_storage_buffer
        {SC_PHYSICAL_STORAGE_BUFFER, "physical_buffer"},
    };

    // Types tested for conversions. Excludes 64b types.
    VkComponentTypeKHR allTypes[] = {
        VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR,
        VK_COMPONENT_TYPE_SINT16_KHR,  VK_COMPONENT_TYPE_SINT32_KHR,  VK_COMPONENT_TYPE_UINT8_KHR,
        VK_COMPONENT_TYPE_UINT16_KHR,  VK_COMPONENT_TYPE_UINT32_KHR,
    };

    // Types tested for load/store from/into multicomponent types
    MulticomponentTypes multicomponentTypes[] = {
        {"vec2", "2-component vector type as input or output", 2},
        {"vec4", "4-component vector type as input or output", 4},
    };

    // Load and store operation variants tested with multicomponent types
    IOTypes ioTypes[] = {
        {"load", "Test multicomponent type as input in load operation", TT_MULTICOMPONENT_LOAD},
        {"save", "Test multicomponent type as output in store operation", TT_MULTICOMPONENT_SAVE},
    };

    for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
    {
        const TestType testType = (TestType)ttCases[ttNdx].value;

        for (int sgsNdx = 0; sgsNdx < DE_LENGTH_OF_ARRAY(sgsCases); sgsNdx++)
        {
            if (testType != TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE)
                continue;

            if (testType == TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE && useType == UT_NV)
                continue;
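            // i.e. only the basic matrixmuladd test gets explicit subgroup-size (_min/_max)
            // variants, and not for the NV extension.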

            const string name = string(ttCases[ttNdx].name) + sgsCases[sgsNdx].name;
            de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, name.c_str()));

            for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
            {
                de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name));
                for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
                {
                    de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
                    for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                    {
                        const VkComponentTypeKHR inputType  = (VkComponentTypeKHR)dtCases[dtNdx].value[0];
                        const VkComponentTypeKHR outputType = (VkComponentTypeKHR)dtCases[dtNdx].value[1];
                        const bool isMatrixMul              = isMatrixMulAddOp(testType);

                        // useType isn't used for matrixmul shaders. Don't generate 3 copies of those tests.
                        if (isMatrixMul && (useType == UT_KHR_A || useType == UT_KHR_B))
                        {
                            continue;
                        }

                        // NV extension doesn't support mixing signedness
                        if (isMatrixMul && (useType == UT_NV) && isSIntType(inputType) != isSIntType(outputType))
                        {
                            continue;
                        }

                        if (!isMatrixMul && inputType != outputType)
                            continue;

                        if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
                            continue;

                        if (testType == TT_MUL && useType == UT_NV)
                            continue;

                        if (testType == TT_MATRIXMULADD_SATURATED && (isFloatType(inputType) || useType == UT_NV))
                            continue;

                        if (testType == TT_MATRIXMULADD_WRAPPING && (isFloatType(inputType) || useType == UT_NV))
                            continue;

                        if (testType == TT_MATRIXMULADD_STRIDE0 && useType == UT_NV)
                            continue;

                        if (testType == TT_LENGTH && useType != UT_NV &&
                            (outputType == VK_COMPONENT_TYPE_SINT8_KHR || outputType == VK_COMPONENT_TYPE_UINT8_KHR))
                            continue;

                        CaseDef c = {
                            testType,                           // TestType testtype;
                            2u,                                 // uint32_t subgroupsPerWorkgroupX;
                            2u,                                 // uint32_t subgroupsPerWorkgroupY;
                            4u,                                 // uint32_t workgroupsX;
                            4u,                                 // uint32_t workgroupsY;
                            inputType,                          // VkComponentTypeKHR inputType;
                            outputType,                         // VkComponentTypeKHR outputType;
                            !!colCases[colNdx].value,           // bool colMajor;
                            (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
                            useType,                            // UseType useType;
                            sgsCases[sgsNdx].value,             // SubgroupSizeMode subgroupSizeMode;
                            computePipelineConstructionType,    // vk::ComputePipelineConstructionType computePipelineConstructionType;
                            1,                                  // uint32_t inputComponentCount;
                            1,                                  // uint32_t outputComponentCount;
                        };

                        scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
                    }
                    dtGroup->addChild(scGroup.release());
                }
                ttGroup->addChild(dtGroup.release());
            }
            group->addChild(ttGroup.release());
        }
    }

    {
        const string name = string("convert");
        const string desc = string("OpFConvert/OpSConvert/OpUConvert/OpBitcast");
        de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, name.c_str()));

        for (int dtNdx1 = 0; dtNdx1 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx1++)
        {
            for (int dtNdx2 = 0; dtNdx2 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx2++)
            {
                const VkComponentTypeKHR inputType  = (VkComponentTypeKHR)allTypes[dtNdx1];
                const VkComponentTypeKHR outputType = (VkComponentTypeKHR)allTypes[dtNdx2];
                const string name2 = string("input_") + string(componentTypeInfo[inputType].typeName) +
                                     string("_output_") + string(componentTypeInfo[outputType].typeName);
                de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name2.c_str()));
                for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
                {
                    de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
                    for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                    {

                        CaseDef c = {
                            TT_CONVERT,                         // TestType testtype;
                            2u,                                 // uint32_t subgroupsPerWorkgroupX;
                            2u,                                 // uint32_t subgroupsPerWorkgroupY;
                            4u,                                 // uint32_t workgroupsX;
                            4u,                                 // uint32_t workgroupsY;
                            inputType,                          // VkComponentTypeKHR inputType;
                            outputType,                         // VkComponentTypeKHR outputType;
                            !!colCases[colNdx].value,           // bool colMajor;
                            (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
                            useType,                            // UseType useType;
                            SUBGROUP_SIZE_NONE,                 // SubgroupSizeMode subgroupSizeMode;
                            computePipelineConstructionType,    // vk::ComputePipelineConstructionType computePipelineConstructionType;
                            1,                                  // uint32_t inputComponentCount;
                            1,                                  // uint32_t outputComponentCount;
                        };

                        scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
                    }
                    dtGroup->addChild(scGroup.release());
                }
                ttGroup->addChild(dtGroup.release());
            }
        }
        group->addChild(ttGroup.release());
    }

    if (useType != UT_NV)
    {
        de::MovePtr<tcu::TestCaseGroup> ttGroup(
            new tcu::TestCaseGroup(testCtx, "multicomponent", "Multicomponent types tests"));
        for (int ctNdx = 0; ctNdx < DE_LENGTH_OF_ARRAY(multicomponentTypes); ctNdx++)
        {
            de::MovePtr<tcu::TestCaseGroup> ctGroup(new tcu::TestCaseGroup(testCtx, multicomponentTypes[ctNdx].name,
                                                                           multicomponentTypes[ctNdx].description));
            const uint32_t componentCount = multicomponentTypes[ctNdx].componentCount;

            for (int ioNdx = 0; ioNdx < DE_LENGTH_OF_ARRAY(ioTypes); ioNdx++)
            {
                de::MovePtr<tcu::TestCaseGroup> ioGroup(
                    new tcu::TestCaseGroup(testCtx, ioTypes[ioNdx].name, ioTypes[ioNdx].description));
                const TestType testType = ioTypes[ioNdx].testType;
                const uint32_t inputComponentCount = testType == TT_MULTICOMPONENT_LOAD ? componentCount : 1;
                const uint32_t outputComponentCount = testType == TT_MULTICOMPONENT_LOAD ? 1 : componentCount;

                for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(allTypes); dtNdx++)
                {
                    const VkComponentTypeKHR inputType = allTypes[dtNdx];
                    const string name = componentTypeInfo[inputType].typeName;

                    de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name.c_str(), ""));
                    for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
                    {
                        de::MovePtr<tcu::TestCaseGroup> scGroup(
                            new tcu::TestCaseGroup(testCtx, scCases[scNdx].name, ""));
                        for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
                        {
                            CaseDef c = {
                                testType,                           // TestType testtype;
                                2u,                                 // uint32_t subgroupsPerWorkgroupX;
                                2u,                                 // uint32_t subgroupsPerWorkgroupY;
                                4u,                                 // uint32_t workgroupsX;
                                4u,                                 // uint32_t workgroupsY;
                                inputType,                          // VkComponentTypeKHR inputType;
                                inputType,                          // VkComponentTypeKHR outputType;
                                !!colCases[colNdx].value,           // bool colMajor;
                                (StorageClass)scCases[scNdx].value, // StorageClass storageClass;
                                useType,                            // UseType useType;
                                SUBGROUP_SIZE_NONE,                 // SubgroupSizeMode subgroupSizeMode;
                                computePipelineConstructionType,    // vk::ComputePipelineConstructionType computePipelineConstructionType;
                                inputComponentCount,                // uint32_t inputComponentCount;
                                outputComponentCount,               // uint32_t outputComponentCount;
                            };

                            scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
                        }
                        dtGroup->addChild(scGroup.release());
                    }
                    ioGroup->addChild(dtGroup.release());
                }
                ctGroup->addChild(ioGroup.release());
            }
            ttGroup->addChild(ctGroup.release());
        }
        group->addChild(ttGroup.release());
    }

    return group.release();
}

} // namespace

tcu::TestCaseGroup *createCooperativeMatrixTests(tcu::TestContext &testCtx,
                                                 vk::ComputePipelineConstructionType computePipelineConstructionType)
{
    de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "cooperative_matrix"));

    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_NV));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_A));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_B));
    group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_Result));

    return group.release();
}

} // namespace compute
} // namespace vkt