1 /*------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2019 The Khronos Group Inc.
6  * Copyright (c) 2018-2019 NVIDIA Corporation
7  * Copyright (c) 2023 LunarG, Inc.
8  * Copyright (c) 2023 Nintendo
9  *
10  * Licensed under the Apache License, Version 2.0 (the "License");
11  * you may not use this file except in compliance with the License.
12  * You may obtain a copy of the License at
13  *
14  *      http://www.apache.org/licenses/LICENSE-2.0
15  *
16  * Unless required by applicable law or agreed to in writing, software
17  * distributed under the License is distributed on an "AS IS" BASIS,
18  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19  * See the License for the specific language governing permissions and
20  * limitations under the License.
21  *
22  *//*!
23  * \file
24  * \brief Vulkan Cooperative Matrix tests
25  *//*--------------------------------------------------------------------*/
26 
27 #include "vktComputeCooperativeMatrixTests.hpp"
28 
29 #include "vkBufferWithMemory.hpp"
30 #include "vkImageWithMemory.hpp"
31 #include "vkQueryUtil.hpp"
32 #include "vkBuilderUtil.hpp"
33 #include "vkCmdUtil.hpp"
34 #include "vkTypeUtil.hpp"
35 #include "vkObjUtil.hpp"
36 
37 #include "vktTestGroupUtil.hpp"
38 #include "vktTestCase.hpp"
39 
40 #include "deDefs.h"
41 #include "tcuFloat.hpp"
42 #include "deMath.h"
43 #include "deRandom.h"
44 #include "deSharedPtr.hpp"
45 #include "deString.h"
46 
47 #include "tcuTestCase.hpp"
48 #include "tcuTestLog.hpp"
49 
50 #include <string>
51 #include <sstream>
52 #include <set>
53 #include <algorithm>
54 
55 namespace vkt
56 {
57 namespace compute
58 {
59 namespace
60 {
61 using namespace vk;
62 using namespace std;
63 
64 //#define COOPERATIVE_MATRIX_EXTENDED_DEBUG 1
65 
// The code below freely casts between the KHR and NV flavors of the
// component-type and scope enums (e.g. in convertCooperativeMatrixProperties),
// which is only valid if the numeric values match. Verify that at compile time.
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT16_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT32_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_FLOAT64_KHR == (uint32_t)VK_COMPONENT_TYPE_FLOAT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT8_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_SINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_SINT64_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT8_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT8_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT16_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT16_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT32_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT32_NV);
DE_STATIC_ASSERT((uint32_t)VK_COMPONENT_TYPE_UINT64_KHR == (uint32_t)VK_COMPONENT_TYPE_UINT64_NV);

DE_STATIC_ASSERT((uint32_t)VK_SCOPE_DEVICE_KHR == (uint32_t)VK_SCOPE_DEVICE_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_WORKGROUP_KHR == (uint32_t)VK_SCOPE_WORKGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_SUBGROUP_KHR == (uint32_t)VK_SCOPE_SUBGROUP_NV);
DE_STATIC_ASSERT((uint32_t)VK_SCOPE_QUEUE_FAMILY_KHR == (uint32_t)VK_SCOPE_QUEUE_FAMILY_NV);
82 
// Selects which cooperative-matrix extension variant a test targets and,
// for the KHR extension, which matrix "use" is exercised (mapped to
// gl_MatrixUseA / gl_MatrixUseB / gl_MatrixUseAccumulator in initPrograms).
typedef enum
{
    UT_NV = 0,     // legacy GL_NV_cooperative_matrix path
    UT_KHR_A,      // GL_KHR_cooperative_matrix, matrix use A
    UT_KHR_B,      // GL_KHR_cooperative_matrix, matrix use B
    UT_KHR_Result, // GL_KHR_cooperative_matrix, matrix use Accumulator
} UseType;
90 
// Operation exercised by the generated compute shader. The TT_MATRIXMULADD*
// values form the matrix-multiply-add family (see isMatrixMulAddOp()), which
// uses MxK * KxN + MxN matrix dimensions instead of uniform MxN.
typedef enum
{
    TT_LENGTH = 0,
    TT_CONSTANT, // output written from a constant-initialized matrix
    TT_CONVERT,
    TT_COMPOSITE,
    TT_COMPOSITE_RVALUE,
    TT_ADD,
    TT_SUB,
    TT_DIV,
    TT_MUL,
    TT_NEGATE,
    TT_MATRIXTIMESSCALAR,
    TT_FUNC, // matrix passed through a user-defined shader function
    TT_MATRIXMULADD,
    TT_COMPOSITE_ARRAY,
    TT_MATRIXMULADD_ARRAY,
    TT_MATRIXMULADD_SATURATED, // uses gl_MatrixOperandsSaturatingAccumulation
    TT_MATRIXMULADD_WRAPPING,
    TT_MATRIXMULADD_STRIDE0,  // loads A/B/C with a stride of 0
    TT_MULTICOMPONENT_LOAD,   // loads inputs through vec2/vec4 buffer elements
    TT_MULTICOMPONENT_SAVE,   // stores outputs through vec2/vec4 buffer elements
} TestType;
114 
// Storage class used for the shader-side matrix data.
typedef enum
{
    SC_BUFFER = 0,                  // plain storage buffers
    SC_WORKGROUP,                   // staged through shared (workgroup) memory
    SC_WORKGROUP_VARIABLE_POINTERS, // shared memory, with variable pointers enabled
    SC_BUFFER_VARIABLE_POINTERS,    // storage buffers, with variable pointers enabled
    SC_PHYSICAL_STORAGE_BUFFER,     // buffer_reference / physical storage buffer addresses
} StorageClass;
123 
// How the subgroup size for the test is chosen (see getSubgroupSizeFromMode()).
enum SubgroupSizeMode
{
    SUBGROUP_SIZE_NONE = 0, // use the device's default subgroup size
    SUBGROUP_SIZE_MIN  = 1, // force minSubgroupSize
    SUBGROUP_SIZE_MAX  = 2, // force maxSubgroupSize
};
130 
// These are compute-only tests, so the full shader-stage mask is just compute.
const VkFlags allShaderStages = VK_SHADER_STAGE_COMPUTE_BIT;
132 
// All parameters defining a single test variant; filled in by the test group
// builder and consumed by both shader generation and support checks.
struct CaseDef
{
    TestType testType;
    uint32_t subgroupsPerWorkgroupX; // subgroup grid within one workgroup (x)
    uint32_t subgroupsPerWorkgroupY; // subgroup grid within one workgroup (y)
    uint32_t workgroupsX;            // dispatch size (x)
    uint32_t workgroupsY;            // dispatch size (y)
    VkComponentTypeKHR inputType;    // component type of input matrices A and B
    VkComponentTypeKHR outputType;   // component type of matrices C and Output
    bool colMajor;                   // column-major matrix layout in memory when true
    StorageClass storageClass;       // where matrix data is loaded from / stored to
    UseType useType;                 // extension variant (NV vs KHR) and matrix use
    SubgroupSizeMode subgroupSizeMode;
    vk::ComputePipelineConstructionType computePipelineConstructionType;
    uint32_t inputComponentCount;  // 2/4 selects vec2/vec4 loads (TT_MULTICOMPONENT_LOAD)
    uint32_t outputComponentCount; // 2/4 selects vec2/vec4 stores (TT_MULTICOMPONENT_SAVE)
};
150 
isKhr(UseType useType)151 bool isKhr(UseType useType)
152 {
153     return useType != UT_NV;
154 }
155 
isMatrixMulAddOp(TestType testType)156 bool isMatrixMulAddOp(TestType testType)
157 {
158     return testType == TT_MATRIXMULADD || testType == TT_MATRIXMULADD_ARRAY || testType == TT_MATRIXMULADD_SATURATED ||
159            testType == TT_MATRIXMULADD_WRAPPING || testType == TT_MATRIXMULADD_STRIDE0;
160 }
161 
162 template <typename T>
getCooperativeMatrixProperties(const InstanceInterface &,VkPhysicalDevice,uint32_t *,T *)163 VkResult getCooperativeMatrixProperties(const InstanceInterface &, VkPhysicalDevice, uint32_t *, T *)
164 {
165     TCU_THROW(InternalError, "Not Implementetd");
166 }
167 
getCooperativeMatrixProperties(const InstanceInterface & vki,VkPhysicalDevice physicalDevice,uint32_t * pPropertyCount,VkCooperativeMatrixPropertiesKHR * pProperties)168 VkResult getCooperativeMatrixProperties(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
169                                         uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesKHR *pProperties)
170 {
171     return vki.getPhysicalDeviceCooperativeMatrixPropertiesKHR(physicalDevice, pPropertyCount, pProperties);
172 }
173 
getCooperativeMatrixProperties(const InstanceInterface & vki,VkPhysicalDevice physicalDevice,uint32_t * pPropertyCount,VkCooperativeMatrixPropertiesNV * pProperties)174 VkResult getCooperativeMatrixProperties(const InstanceInterface &vki, VkPhysicalDevice physicalDevice,
175                                         uint32_t *pPropertyCount, VkCooperativeMatrixPropertiesNV *pProperties)
176 {
177     return vki.getPhysicalDeviceCooperativeMatrixPropertiesNV(physicalDevice, pPropertyCount, pProperties);
178 }
179 
convertCooperativeMatrixProperties(const VkCooperativeMatrixPropertiesNV & properties)180 VkCooperativeMatrixPropertiesKHR convertCooperativeMatrixProperties(const VkCooperativeMatrixPropertiesNV &properties)
181 {
182     VkCooperativeMatrixPropertiesKHR result = initVulkanStructure();
183 
184     result.sType                  = (VkStructureType)properties.sType;
185     result.pNext                  = (void *)properties.pNext;
186     result.MSize                  = (uint32_t)properties.MSize;
187     result.NSize                  = (uint32_t)properties.NSize;
188     result.KSize                  = (uint32_t)properties.KSize;
189     result.AType                  = (VkComponentTypeKHR)properties.AType;
190     result.BType                  = (VkComponentTypeKHR)properties.BType;
191     result.CType                  = (VkComponentTypeKHR)properties.CType;
192     result.ResultType             = (VkComponentTypeKHR)properties.DType;
193     result.saturatingAccumulation = (VkBool32)VK_FALSE;
194     result.scope                  = (VkScopeKHR)properties.scope;
195 
196     return result;
197 }
198 
convertCooperativeMatrixProperties(const std::vector<VkCooperativeMatrixPropertiesNV> & properties)199 std::vector<VkCooperativeMatrixPropertiesKHR> convertCooperativeMatrixProperties(
200     const std::vector<VkCooperativeMatrixPropertiesNV> &properties)
201 {
202     std::vector<VkCooperativeMatrixPropertiesKHR> result(properties.size());
203 
204     for (size_t i = 0; i < properties.size(); ++i)
205         result[i] = convertCooperativeMatrixProperties(properties[i]);
206 
207     return result;
208 }
209 
210 template <typename T>
getCooperativeMatrixPropertiesAll(Context & context,std::vector<T> & properties)211 void getCooperativeMatrixPropertiesAll(Context &context, std::vector<T> &properties)
212 {
213     uint32_t propertyCount = 0;
214 
215     VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(), &propertyCount,
216                                             (T *)DE_NULL));
217 
218     if (propertyCount > 0)
219     {
220         const T sample = initVulkanStructureConst();
221 
222         properties.resize(propertyCount, sample);
223 
224         VK_CHECK(getCooperativeMatrixProperties(context.getInstanceInterface(), context.getPhysicalDevice(),
225                                                 &propertyCount, properties.data()));
226     }
227     else
228     {
229         properties.clear();
230     }
231 }
232 
getCooperativeMatrixPropertiesConverted(Context & context,const bool khr)233 std::vector<VkCooperativeMatrixPropertiesKHR> getCooperativeMatrixPropertiesConverted(Context &context, const bool khr)
234 {
235     std::vector<VkCooperativeMatrixPropertiesKHR> properties;
236 
237     if (khr)
238     {
239         getCooperativeMatrixPropertiesAll(context, properties);
240     }
241     else
242     {
243         std::vector<VkCooperativeMatrixPropertiesNV> propertiesNV;
244 
245         getCooperativeMatrixPropertiesAll(context, propertiesNV);
246 
247         properties = convertCooperativeMatrixProperties(propertiesNV);
248     }
249 
250     return properties;
251 }
252 
getSubgroupSizeFromMode(Context & context,const SubgroupSizeMode subgroupSizeMode)253 uint32_t getSubgroupSizeFromMode(Context &context, const SubgroupSizeMode subgroupSizeMode)
254 {
255 #ifndef CTS_USES_VULKANSC
256     const VkPhysicalDeviceSubgroupSizeControlProperties &subgroupSizeControlProperties =
257         context.getSubgroupSizeControlProperties();
258 #else
259     const VkPhysicalDeviceSubgroupSizeControlPropertiesEXT &subgroupSizeControlProperties =
260         context.getSubgroupSizeControlPropertiesEXT();
261 #endif // CTS_USES_VULKANSC
262 
263     switch (subgroupSizeMode)
264     {
265     case SUBGROUP_SIZE_MAX:
266         return subgroupSizeControlProperties.maxSubgroupSize;
267     case SUBGROUP_SIZE_MIN:
268         return subgroupSizeControlProperties.minSubgroupSize;
269     case SUBGROUP_SIZE_NONE:
270         return context.getSubgroupProperties().subgroupSize;
271     default:
272         TCU_THROW(NotSupportedError, "Unsupported Subgroup size");
273     }
274 }
275 
// Per-test execution object; iterate() (defined later in this file) runs the
// test described by the stored CaseDef.
class CooperativeMatrixTestInstance : public TestInstance
{
public:
    CooperativeMatrixTestInstance(Context &context, const CaseDef &data);
    ~CooperativeMatrixTestInstance(void);
    tcu::TestStatus iterate(void);

private:
    // Immutable parameters describing the case being run.
    CaseDef m_data;
};
286 
// Just stash the case parameters; no other setup happens at construction.
CooperativeMatrixTestInstance::CooperativeMatrixTestInstance(Context &context, const CaseDef &data)
    : vkt::TestInstance(context)
    , m_data(data)
{
}
292 
CooperativeMatrixTestInstance::~CooperativeMatrixTestInstance(void)
{
    // No manually managed resources to release.
}
296 
// Test-case object: generates the compute shader (initPrograms), validates
// device support (checkSupport) and creates the runtime instance.
class CooperativeMatrixTestCase : public TestCase
{
public:
    CooperativeMatrixTestCase(tcu::TestContext &context, const char *name, const CaseDef data);
    ~CooperativeMatrixTestCase(void);
    virtual void initPrograms(SourceCollections &programCollection) const;
    virtual TestInstance *createInstance(Context &context) const;
    virtual void checkSupport(Context &context) const;

private:
    // Immutable parameters describing the case being built.
    CaseDef m_data;
};
309 
// Just stash the case parameters; shader generation happens in initPrograms().
CooperativeMatrixTestCase::CooperativeMatrixTestCase(tcu::TestContext &context, const char *name, const CaseDef data)
    : vkt::TestCase(context, name)
    , m_data(data)
{
}
315 
CooperativeMatrixTestCase::~CooperativeMatrixTestCase(void)
{
    // No manually managed resources to release.
}
319 
checkSupport(Context & context) const320 void CooperativeMatrixTestCase::checkSupport(Context &context) const
321 {
322     if (!context.contextSupports(vk::ApiVersion(0, 1, 1, 0)))
323     {
324         TCU_THROW(NotSupportedError, "Vulkan 1.1 not supported");
325     }
326 
327     if (isKhr(m_data.useType))
328     {
329         if (!context.getCooperativeMatrixFeatures().cooperativeMatrix)
330         {
331             TCU_THROW(NotSupportedError,
332                       "VkPhysicalDeviceCooperativeMatrixFeaturesKHR::cooperativeMatrix not supported");
333         }
334     }
335     else
336     {
337         if (!context.getCooperativeMatrixFeaturesNV().cooperativeMatrix)
338         {
339             TCU_THROW(NotSupportedError,
340                       "VkPhysicalDeviceCooperativeMatrixFeaturesNV::cooperativeMatrix not supported");
341         }
342     }
343 
344     if (!context.getVulkanMemoryModelFeatures().vulkanMemoryModel)
345     {
346         TCU_THROW(NotSupportedError, "vulkanMemoryModel not supported");
347     }
348 
349     if ((m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS || m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS) &&
350         !context.getVariablePointersFeatures().variablePointers)
351     {
352         TCU_THROW(NotSupportedError, "variable pointers not supported");
353     }
354 
355     if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER && !context.isBufferDeviceAddressSupported())
356     {
357         TCU_THROW(NotSupportedError, "buffer device address not supported");
358     }
359 
360     if (!context.getShaderFloat16Int8Features().shaderFloat16 &&
361         (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR || m_data.outputType == VK_COMPONENT_TYPE_FLOAT16_KHR))
362     {
363         TCU_THROW(NotSupportedError, "shaderFloat16 not supported");
364     }
365 
366     std::vector<VkCooperativeMatrixPropertiesKHR> properties =
367         getCooperativeMatrixPropertiesConverted(context, isKhr(m_data.useType));
368     bool supported[2]   = {false, false};
369     const auto isMMA    = isMatrixMulAddOp(m_data.testType);
370     const auto isMMASat = m_data.testType == TT_MATRIXMULADD_SATURATED;
371 
372     for (size_t i = 0; i < properties.size(); ++i)
373     {
374         const VkCooperativeMatrixPropertiesKHR *p = &properties[i];
375 
376         if (p->scope != VK_SCOPE_SUBGROUP_KHR)
377             continue;
378 
379         if (isMMA && isMMASat != static_cast<bool>(p->saturatingAccumulation))
380             continue;
381 
382         if (isMMA)
383         {
384             if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
385                 p->ResultType == m_data.outputType)
386             {
387                 supported[0] = supported[1] = true;
388             }
389         }
390         else
391         {
392             const VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
393 
394             for (uint32_t j = 0; j < 2; ++j)
395             {
396                 switch (m_data.useType)
397                 {
398                 case UT_NV:
399                 {
400                     if (p->AType == types[j] || p->BType == types[j] || p->CType == types[j] ||
401                         p->ResultType == types[j])
402                         supported[j] = true;
403 
404                     break;
405                 }
406                 case UT_KHR_A:
407                 {
408                     if (p->AType == types[j])
409                         supported[j] = true;
410 
411                     break;
412                 }
413                 case UT_KHR_B:
414                 {
415                     if (p->BType == types[j])
416                         supported[j] = true;
417 
418                     break;
419                 }
420                 case UT_KHR_Result:
421                 {
422                     if (p->ResultType == types[j])
423                         supported[j] = true;
424 
425                     break;
426                 }
427                 default:
428                     TCU_THROW(InternalError, "Unsupported use type");
429                 }
430             }
431         }
432     }
433 
434     if (!supported[0] || !supported[1])
435         TCU_THROW(NotSupportedError, "cooperative matrix combination not supported");
436 
437     checkShaderObjectRequirements(context.getInstanceInterface(), context.getPhysicalDevice(),
438                                   m_data.computePipelineConstructionType);
439 }
440 
// Per-component-type GLSL naming and size info. The table is indexed directly
// by VkComponentTypeKHR value (e.g. componentTypeInfo[m_data.inputType]), so
// the entry order must mirror the enum's numeric values — assumed here to be
// FLOAT16=0, FLOAT32, FLOAT64, SINT8..SINT64, UINT8..UINT64; TODO confirm
// against the Vulkan headers if the enum ever changes.
struct
{
    const char *typeName;        // GLSL scalar type name
    const char *coopmatTypeName; // NV-style cooperative matrix type name
    uint32_t bits;               // component width in bits
    bool isSigned;               // true for float and signed-integer types
} componentTypeInfo[] = {
    {"float16_t", "fcoopmatNV", 16, true}, {"float32_t", "fcoopmatNV", 32, true}, {"float64_t", "fcoopmatNV", 64, true},
    {"int8_t", "icoopmatNV", 8, true},     {"int16_t", "icoopmatNV", 16, true},   {"int32_t", "icoopmatNV", 32, true},
    {"int64_t", "icoopmatNV", 64, true},   {"uint8_t", "ucoopmatNV", 8, false},   {"uint16_t", "ucoopmatNV", 16, false},
    {"uint32_t", "ucoopmatNV", 32, false}, {"uint64_t", "ucoopmatNV", 64, false},
};
453 
isFloatType(VkComponentTypeKHR t)454 bool isFloatType(VkComponentTypeKHR t)
455 {
456     switch (t)
457     {
458     case VK_COMPONENT_TYPE_FLOAT16_KHR:
459     case VK_COMPONENT_TYPE_FLOAT32_KHR:
460     case VK_COMPONENT_TYPE_FLOAT64_KHR:
461         return true;
462     default:
463         return false;
464     }
465 }
466 
isSIntType(VkComponentTypeKHR t)467 bool isSIntType(VkComponentTypeKHR t)
468 {
469     switch (t)
470     {
471     case VK_COMPONENT_TYPE_SINT8_KHR:
472     case VK_COMPONENT_TYPE_SINT16_KHR:
473     case VK_COMPONENT_TYPE_SINT32_KHR:
474     case VK_COMPONENT_TYPE_SINT64_KHR:
475         return true;
476     default:
477         return false;
478     }
479 }
480 
initPrograms(SourceCollections & programCollection) const481 void CooperativeMatrixTestCase::initPrograms(SourceCollections &programCollection) const
482 {
483     const char *suffix = (isKhr(m_data.useType) ? "" : "NV");
484     const char *ext    = isKhr(m_data.useType) ? "#extension GL_KHR_cooperative_matrix : enable\n" :
485                                                  "#extension GL_NV_cooperative_matrix : enable\n"
486                                                  "#extension GL_NV_integer_cooperative_matrix : enable\n";
487     const char *sat = (m_data.testType == TT_MATRIXMULADD_SATURATED) ? ", gl_MatrixOperandsSaturatingAccumulation" : "";
488     std::stringstream css;
489     css << "#version 450 core\n";
490     css << "#pragma use_vulkan_memory_model\n";
491     css << "#extension GL_KHR_shader_subgroup_basic : enable\n"
492            "#extension GL_KHR_memory_scope_semantics : enable\n"
493         << ext
494         << "#extension GL_EXT_shader_explicit_arithmetic_types : enable\n"
495            "#extension GL_EXT_buffer_reference : enable\n"
496            "// strides overriden by spec constants\n"
497            "layout(constant_id = 2) const int AStride = 1;\n"
498            "layout(constant_id = 3) const int BStride = 1;\n"
499            "layout(constant_id = 4) const int CStride = 1;\n"
500            "layout(constant_id = 5) const int OStride = 1;\n"
501            "layout(constant_id = 6) const int M = 1;\n"
502            "layout(constant_id = 7) const int N = 1;\n"
503            "layout(constant_id = 8) const int K = 1;\n"
504            "layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;\n";
505 
506     if (m_data.storageClass == SC_BUFFER_VARIABLE_POINTERS || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
507         css << "#pragma use_variable_pointers\n";
508 
509     struct
510     {
511         string rows, cols;
512     } dims[4];
513 
514     if (isMatrixMulAddOp(m_data.testType))
515     {
516         dims[0].rows = "M";
517         dims[0].cols = "K";
518         dims[1].rows = "K";
519         dims[1].cols = "N";
520         dims[2].rows = "M";
521         dims[2].cols = "N";
522         dims[3].rows = "M";
523         dims[3].cols = "N";
524     }
525     else
526     {
527         dims[0].rows = "M";
528         dims[0].cols = "N";
529         dims[1].rows = "M";
530         dims[1].cols = "N";
531         dims[2].rows = "M";
532         dims[2].cols = "N";
533         dims[3].rows = "M";
534         dims[3].cols = "N";
535     }
536 
537     const char *typeStrA = componentTypeInfo[m_data.inputType].typeName;
538     const char *typeStrB = componentTypeInfo[m_data.inputType].typeName;
539     const char *typeStrC = componentTypeInfo[m_data.outputType].typeName;
540     const char *typeStrO = componentTypeInfo[m_data.outputType].typeName;
541     string inputType;
542     string outputType;
543     string divisorA;
544     string divisorB;
545     string divisorC;
546     string divisorO;
547     string *divisors[4] = {&divisorA, &divisorB, &divisorC, &divisorO};
548 
549     if (m_data.testType == TT_MULTICOMPONENT_LOAD)
550     {
551         const char *componentSuffix = m_data.inputComponentCount == 2 ? "vec2" :
552                                       m_data.inputComponentCount == 4 ? "vec4" :
553                                                                         "";
554 
555         inputType = string(1, componentTypeInfo[m_data.inputType].coopmatTypeName[0]) +
556                     de::toString(componentTypeInfo[m_data.inputType].bits) + componentSuffix;
557 
558         typeStrA = inputType.c_str();
559         typeStrB = inputType.c_str();
560         divisorA = m_data.inputComponentCount == 2 ? "/2" : m_data.inputComponentCount == 4 ? "/4" : "";
561         divisorB = divisorA;
562     }
563 
564     if (m_data.testType == TT_MULTICOMPONENT_SAVE)
565     {
566         const char *componentSuffix = m_data.outputComponentCount == 2 ? "vec2" :
567                                       m_data.outputComponentCount == 4 ? "vec4" :
568                                                                          "";
569 
570         outputType = string(1, componentTypeInfo[m_data.outputType].coopmatTypeName[0]) +
571                      de::toString(componentTypeInfo[m_data.outputType].bits) + componentSuffix;
572 
573         typeStrC = outputType.c_str();
574         typeStrO = outputType.c_str();
575         divisorC = m_data.outputComponentCount == 2 ? "/2" : m_data.outputComponentCount == 4 ? "/4" : "";
576         divisorO = divisorC;
577     }
578 
579     css << "const int workgroupsX = " << m_data.workgroupsX << ";\n";
580     css << "const uvec2 subgroupsPerWG = uvec2(" << m_data.subgroupsPerWorkgroupX << ", "
581         << m_data.subgroupsPerWorkgroupY << ");\n";
582 
583     if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
584     {
585         css << "layout(buffer_reference) buffer InputA { " << typeStrA << " x[]; };\n";
586         css << "layout(buffer_reference) buffer InputB { " << typeStrB << " x[]; };\n";
587         css << "layout(buffer_reference) buffer InputC { " << typeStrC << " x[]; };\n";
588         css << "layout(buffer_reference) buffer Output { " << typeStrO << " x[]; };\n";
589         css << "layout(set=0, binding=4) buffer Params { InputA inputA; InputB inputB; InputC inputC; Output outputO; "
590                "} params;\n";
591     }
592     else
593     {
594         css << "layout(set=0, binding=0) coherent buffer InputA { " << typeStrA << " x[]; } inputA;\n";
595         css << "layout(set=0, binding=1) coherent buffer InputB { " << typeStrB << " x[]; } inputB;\n";
596         css << "layout(set=0, binding=2) coherent buffer InputC { " << typeStrC << " x[]; } inputC;\n";
597         css << "layout(set=0, binding=3) coherent buffer Output { " << typeStrO << " x[]; } outputO;\n";
598     }
599 
600     if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
601     {
602         css << "shared " << typeStrA << " sharedA[" << dims[0].rows << " * " << dims[0].cols
603             << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
604         css << "shared " << typeStrB << " sharedB[" << dims[1].rows << " * " << dims[1].cols
605             << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
606         css << "shared " << typeStrC << " sharedC[" << dims[2].rows << " * " << dims[2].cols
607             << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
608         css << "shared " << typeStrO << " sharedO[" << dims[3].rows << " * " << dims[3].cols
609             << " * subgroupsPerWG.x * subgroupsPerWG.y];\n";
610     }
611 
612     std::stringstream matAType, matBType, matCType, outputMatType;
613 
614     if (isKhr(m_data.useType))
615     {
616         const bool useSame   = !isMatrixMulAddOp(m_data.testType);
617         const char *sameType = m_data.useType == UT_KHR_A      ? "gl_MatrixUseA" :
618                                m_data.useType == UT_KHR_B      ? "gl_MatrixUseB" :
619                                m_data.useType == UT_KHR_Result ? "gl_MatrixUseAccumulator" :
620                                                                  "Invalid use";
621         const char *atype    = useSame ? sameType : "gl_MatrixUseA";
622         const char *btype    = useSame ? sameType : "gl_MatrixUseB";
623         const char *ctype    = useSame ? sameType : "gl_MatrixUseAccumulator";
624         const char *rtype    = useSame ? sameType : "gl_MatrixUseAccumulator";
625 
626         matAType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[0].rows
627                  << ", " << dims[0].cols << ", " << atype << ">";
628         matBType << "coopmat<" << componentTypeInfo[m_data.inputType].typeName << ", gl_ScopeSubgroup, " << dims[1].rows
629                  << ", " << dims[1].cols << ", " << btype << ">";
630         matCType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, "
631                  << dims[2].rows << ", " << dims[2].cols << ", " << ctype << ">";
632         outputMatType << "coopmat<" << componentTypeInfo[m_data.outputType].typeName << ", gl_ScopeSubgroup, "
633                       << dims[3].rows << ", " << dims[3].cols << ", " << rtype << ">";
634     }
635     else
636     {
637         matAType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<"
638                  << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[0].rows << ", "
639                  << dims[0].cols << ">";
640         matBType << componentTypeInfo[m_data.inputType].coopmatTypeName << "<"
641                  << componentTypeInfo[m_data.inputType].bits << ", gl_ScopeSubgroup, " << dims[1].rows << ", "
642                  << dims[1].cols << ">";
643         matCType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<"
644                  << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[2].rows << ", "
645                  << dims[2].cols << ">";
646         outputMatType << componentTypeInfo[m_data.outputType].coopmatTypeName << "<"
647                       << componentTypeInfo[m_data.outputType].bits << ", gl_ScopeSubgroup, " << dims[3].rows << ", "
648                       << dims[3].cols << ">";
649     }
650 
651     css << matAType.str() << " matA;\n";
652     css << matBType.str() << " matB;\n";
653     css << matCType.str() << " matC;\n";
654     css << outputMatType.str() << " matO;\n";
655 
656     if (m_data.testType == TT_CONSTANT)
657         css << "const " << outputMatType.str() << " matConst = " << outputMatType.str() << "(1.0);\n";
658 
659     if (m_data.testType == TT_FUNC)
660         css << matAType.str() << " f(" << matAType.str() << " m) { return -m; }\n";
661 
662     css << "void main()\n"
663            "{\n"
664            // matrixID is the x,y index of the matrix owned by this subgroup.
665            "   uvec2 subgroupXY = uvec2(gl_SubgroupID % subgroupsPerWG.x, gl_SubgroupID / subgroupsPerWG.x);\n"
666            "   uvec2 matrixID = uvec2(gl_WorkGroupID.xy) * subgroupsPerWG + subgroupXY;\n";
667 
668     if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
669     {
670         css << "   InputA inputA = params.inputA;\n";
671         css << "   InputB inputB = params.inputB;\n";
672         css << "   InputC inputC = params.inputC;\n";
673         css << "   Output outputO = params.outputO;\n";
674     }
675 
676     string strides[4];
677     for (uint32_t i = 0; i < 4; ++i)
678     {
679         strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) + string(" * ") +
680                      de::toString(m_data.subgroupsPerWorkgroupX * m_data.workgroupsX);
681     }
682 
683     // element<i> is the starting element in buffer memory.
684     // elementS<i> is the starting element in shared memory.
685     css << "   uint element0 = (" << strides[0] << " * " << (m_data.colMajor ? dims[0].cols : dims[0].rows)
686         << " * matrixID.y + " << (m_data.colMajor ? dims[0].rows : dims[0].cols) << " * matrixID.x)" << divisorA
687         << ";\n"
688            "   uint element1 = ("
689         << strides[1] << " * " << (m_data.colMajor ? dims[1].cols : dims[1].rows) << " * matrixID.y + "
690         << (m_data.colMajor ? dims[1].rows : dims[1].cols) << " * matrixID.x)" << divisorB
691         << ";\n"
692            "   uint element2 = ("
693         << strides[2] << " * " << (m_data.colMajor ? dims[2].cols : dims[2].rows) << " * matrixID.y + "
694         << (m_data.colMajor ? dims[2].rows : dims[2].cols) << " * matrixID.x)" << divisorC
695         << ";\n"
696            "   uint element3 = ("
697         << strides[3] << " * " << (m_data.colMajor ? dims[3].cols : dims[3].rows) << " * matrixID.y + "
698         << (m_data.colMajor ? dims[3].rows : dims[3].cols) << " * matrixID.x)" << divisorO
699         << ";\n"
700            "   uint elementS0, elementS1, elementS2, elementS3;\n";
701 
702     // For shared memory tests, copy the matrix from buffer memory into
703     // workgroup memory. For simplicity, do it all on a single thread.
704     if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
705     {
706         const char *name[] = {
707             "sharedA",
708             "sharedB",
709             "sharedC",
710         };
711         const char *inputName[] = {
712             "inputA",
713             "inputB",
714             "inputC",
715         };
716         for (uint32_t m = 0; m < 4; ++m)
717         {
718             string sharedStride = strides[m] + " / workgroupsX";
719             css << "       elementS" << m << " = (" << sharedStride << " * "
720                 << (m_data.colMajor ? dims[m].cols : dims[m].rows) << " * subgroupXY.y + "
721                 << (m_data.colMajor ? dims[m].rows : dims[m].cols) << " * subgroupXY.x)" << *divisors[m] << ";\n";
722         }
723         css << "   if (subgroupElect()) {\n";
724         // copy all three input buffers.
725         for (uint32_t m = 0; m < 3; ++m)
726         {
727             string sharedStride = strides[m] + " / workgroupsX";
728             css << "       for (int i = 0; i < " << dims[m].rows
729                 << "; ++i) {\n"
730                    "       for (int j = 0; j < "
731                 << dims[m].cols
732                 << "; ++j) {\n"
733                    "           int localElementInput = ("
734                 << strides[m] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
735                 << *divisors[m]
736                 << ";\n"
737                    "           int localElementShared = ("
738                 << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j")
739                 << ")" << *divisors[m]
740                 << ";\n"
741                    "           "
742                 << name[m] << "[elementS" << m << " + localElementShared] = " << inputName[m] << ".x[element" << m
743                 << " + localElementInput];\n"
744                    "       }\n"
745                    "       }\n";
746             strides[m] = sharedStride;
747         }
748         css << "   }\n";
749         css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, "
750                "gl_SemanticsAcquireRelease);\n";
751     }
752 
753     const char *colMajorNV = (m_data.colMajor ? "true" : "false");
754     const char *colMajorKHR =
755         (m_data.colMajor ? "gl_CooperativeMatrixLayoutColumnMajor" : "gl_CooperativeMatrixLayoutRowMajor");
756     const char *colMajor = (isKhr(m_data.useType) ? colMajorKHR : colMajorNV);
757 
758     string loadStrides[3] = {strides[0] + divisorA, strides[1] + divisorB, strides[2] + divisorC};
759     // Load with a stride of 0
760     if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
761         loadStrides[0] = loadStrides[1] = loadStrides[2] = "0";
762 
763     if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
764     {
765         css << "   coopMatLoad" << suffix << "(matA, sharedA, elementS0, " << loadStrides[0] << ", " << colMajor
766             << ");\n"
767                "   coopMatLoad"
768             << suffix << "(matB, sharedB, elementS1, " << loadStrides[1] << ", " << colMajor
769             << ");\n"
770                "   coopMatLoad"
771             << suffix << "(matC, sharedC, elementS2, " << loadStrides[2] << ", " << colMajor << ");\n";
772     }
773     else
774     {
775         css << "   coopMatLoad" << suffix << "(matA, inputA.x, element0, " << loadStrides[0] << ", " << colMajor
776             << ");\n"
777                "   coopMatLoad"
778             << suffix << "(matB, inputB.x, element1, " << loadStrides[1] << ", " << colMajor
779             << ");\n"
780                "   coopMatLoad"
781             << suffix << "(matC, inputC.x, element2, " << loadStrides[2] << ", " << colMajor << ");\n";
782     }
783 
784     if (m_data.testType == TT_COMPOSITE_ARRAY || m_data.testType == TT_MATRIXMULADD_ARRAY)
785     {
786         css << "   " << matAType.str() << " matAArr[2];\n    matAArr[1] = matA; matAArr[0] = " << matAType.str()
787             << "(0.0);\n"
788                "   "
789             << matBType.str() << " matBArr[2];\n    matBArr[1] = matB; matBArr[0] = " << matBType.str()
790             << "(0.0);\n"
791                "   "
792             << matCType.str() << " matCArr[2];\n    matCArr[1] = matC; matCArr[0] = " << matCType.str()
793             << "(0.0);\n"
794                "   "
795             << outputMatType.str() << " matOArr[2];\n";
796     }
797 
798     switch (m_data.testType)
799     {
800     default:
801         DE_ASSERT(0);
802         // fall through
803     case TT_LENGTH:
804         css << "   matO = " << outputMatType.str() << "(matO.length());\n";
805         break;
806     case TT_CONSTANT:
807         css << "   matO = matConst;\n";
808         break;
809     case TT_CONVERT:
810         css << "   matO = " << outputMatType.str() << "(matA);\n";
811         break;
812     case TT_COMPOSITE:
813         css << "   " << matAType.str() << " t = " << matAType.str()
814             << "(matB[0]);\n"
815                "   for (int i = 1; i < matA.length(); ++i) {\n"
816                "       matO[i] = matA[i] + matB[i];\n"
817                "   }\n"
818                "   if (matA.length() > 0)\n"
819                "       matO[0] = matA[0] + t[0];\n";
820         break;
821     case TT_COMPOSITE_RVALUE:
822         css << "   for (int i = 1; i < matA.length(); ++i) {\n"
823                "       matO[i] = matA[i] + matB[i];\n"
824                "   }\n"
825                "   "
826             << matAType.str()
827             << " t = matA;\n"
828                "   if (matA.length() > 0) {\n"
829                "       matO[0] = (t += matB)[0];\n"
830                "   }\n";
831         break;
832     case TT_COMPOSITE_ARRAY:
833         css << "   for (int i = 0; i < matA.length(); ++i) {\n"
834                "       matOArr[1][i] = matAArr[1][i] + matBArr[1][i];\n"
835                "   }\n";
836         break;
837     case TT_ADD:
838         css << "   matO = matA + matB;\n";
839         break;
840     case TT_SUB:
841         css << "   matO = matA - matB;\n";
842         break;
843     case TT_DIV:
844         css << "   matO = matA / matB;\n";
845         break;
846     case TT_MUL:
847         css << "   matO = matA * matB;\n";
848         break;
849     case TT_NEGATE:
850         css << "   matO = -matA;\n";
851         break;
852     case TT_FUNC:
853         css << "   matO = f(matA);\n";
854         break;
855     case TT_MATRIXTIMESSCALAR:
856         css << "   matO = (" << typeStrA << "(2.0)*matA)*" << typeStrA << "(3.0);\n";
857         break;
858     case TT_MATRIXMULADD_STRIDE0:
859     case TT_MATRIXMULADD_WRAPPING:
860     case TT_MATRIXMULADD_SATURATED:
861     case TT_MATRIXMULADD:
862         css << "   matO = coopMatMulAdd" << suffix << "(matA, matB, matC" << sat << ");\n";
863         break;
864     case TT_MATRIXMULADD_ARRAY:
865         css << "   matOArr[1] = coopMatMulAdd" << suffix << "(matAArr[1], matBArr[1], matCArr[1]);\n";
866         break;
867     case TT_MULTICOMPONENT_LOAD:
868         css << "   matO = matA;\n";
869         break;
870     case TT_MULTICOMPONENT_SAVE:
871         css << "   matO = matA;\n";
872         break;
873     }
874 
875     if (m_data.testType == TT_COMPOSITE_ARRAY || m_data.testType == TT_MATRIXMULADD_ARRAY)
876     {
877         css << "   matOArr[0] = " << outputMatType.str() << "(0.0);\n";
878         css << "   matO = matOArr[1];\n";
879     }
880 
881     if (m_data.storageClass == SC_WORKGROUP || m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)
882     {
883         string sharedStride = strides[3] + " / workgroupsX";
884         css << "   coopMatStore" << suffix << "(matO, sharedO, elementS3, " << sharedStride << divisorO << ", "
885             << colMajor << ");\n";
886         css << "   controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, "
887                "gl_SemanticsAcquireRelease);\n";
888         css << "   if (subgroupElect()) {\n";
889         css << "       for (int i = 0; i < " << dims[3].rows
890             << "; ++i) {\n"
891                "       for (int j = 0; j < "
892             << dims[3].cols
893             << "; ++j) {\n"
894                "           int localElementInput = ("
895             << strides[3] << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
896             << *divisors[3]
897             << ";\n"
898                "           int localElementShared = ("
899             << sharedStride << " * " << (m_data.colMajor ? "j" : "i") << " + " << (m_data.colMajor ? "i" : "j") << ")"
900             << *divisors[3]
901             << ";\n"
902                "           outputO.x[element3 + localElementInput] = sharedO[elementS3 + localElementShared];\n"
903                "       }\n"
904                "       }\n";
905         css << "   }\n";
906     }
907     else
908     {
909         css << "   coopMatStore" << suffix << "(matO, outputO.x, element3, " << strides[3] << divisorO << ", "
910             << colMajor << ");\n";
911     }
912 
913     css << "}\n";
914 
915     const vk::ShaderBuildOptions buildOptions(programCollection.usedVulkanVersion, vk::SPIRV_VERSION_1_3, 0u);
916 
917     programCollection.glslSources.add("test") << glu::ComputeSource(css.str()) << buildOptions;
918 }
919 
// Factory hook called by the framework: builds the per-execution instance,
// forwarding the test parameters (m_data) captured when the case was created.
TestInstance *CooperativeMatrixTestCase::createInstance(Context &context) const
{
    return new CooperativeMatrixTestInstance(context, m_data);
}
924 
setDataFloat(void * base,VkComponentTypeKHR dt,uint32_t i,float value)925 void setDataFloat(void *base, VkComponentTypeKHR dt, uint32_t i, float value)
926 {
927     if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
928     {
929         ((float *)base)[i] = value;
930     }
931     else
932     {
933         DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
934         ((tcu::float16_t *)base)[i] = tcu::Float16(value).bits();
935     }
936 }
937 
getDataFloat(void * base,VkComponentTypeKHR dt,uint32_t i)938 float getDataFloat(void *base, VkComponentTypeKHR dt, uint32_t i)
939 {
940     if (dt == VK_COMPONENT_TYPE_FLOAT32_KHR)
941     {
942         return ((float *)base)[i];
943     }
944     else
945     {
946         DE_ASSERT(dt == VK_COMPONENT_TYPE_FLOAT16_KHR);
947         return tcu::Float16(((const tcu::float16_t *)base)[i]).asFloat();
948     }
949 }
950 
setDataInt(void * base,VkComponentTypeKHR dt,uint32_t i,uint32_t value)951 void setDataInt(void *base, VkComponentTypeKHR dt, uint32_t i, uint32_t value)
952 {
953     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
954 
955     switch (dt)
956     {
957     case VK_COMPONENT_TYPE_UINT8_KHR:
958         ((uint8_t *)base)[i] = (uint8_t)value;
959         break;
960     case VK_COMPONENT_TYPE_UINT16_KHR:
961         ((uint16_t *)base)[i] = (uint16_t)value;
962         break;
963     case VK_COMPONENT_TYPE_UINT32_KHR:
964         ((uint32_t *)base)[i] = (uint32_t)value;
965         break;
966     case VK_COMPONENT_TYPE_SINT8_KHR:
967         ((int8_t *)base)[i] = (int8_t)value;
968         break;
969     case VK_COMPONENT_TYPE_SINT16_KHR:
970         ((int16_t *)base)[i] = (int16_t)value;
971         break;
972     case VK_COMPONENT_TYPE_SINT32_KHR:
973         ((int32_t *)base)[i] = (int32_t)value;
974         break;
975     default:
976         TCU_THROW(InternalError, "Unsupported type");
977     }
978 }
979 
getDataInt(void * base,VkComponentTypeKHR dt,uint32_t i)980 uint32_t getDataInt(void *base, VkComponentTypeKHR dt, uint32_t i)
981 {
982     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
983 
984     switch (dt)
985     {
986     case VK_COMPONENT_TYPE_UINT8_KHR:
987         return ((uint8_t *)base)[i];
988     case VK_COMPONENT_TYPE_UINT16_KHR:
989         return ((uint16_t *)base)[i];
990     case VK_COMPONENT_TYPE_UINT32_KHR:
991         return ((uint32_t *)base)[i];
992     case VK_COMPONENT_TYPE_SINT8_KHR:
993         return ((int8_t *)base)[i];
994     case VK_COMPONENT_TYPE_SINT16_KHR:
995         return ((int16_t *)base)[i];
996     case VK_COMPONENT_TYPE_SINT32_KHR:
997         return ((int32_t *)base)[i];
998     default:
999         TCU_THROW(InternalError, "Unsupported type");
1000     }
1001 }
1002 
1003 template <typename T>
getDataConvertedToT(void * base,VkComponentTypeKHR dt,uint32_t i)1004 T getDataConvertedToT(void *base, VkComponentTypeKHR dt, uint32_t i)
1005 {
1006     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
1007 
1008     switch (dt)
1009     {
1010     case VK_COMPONENT_TYPE_UINT8_KHR:
1011         return (T)((uint8_t *)base)[i];
1012     case VK_COMPONENT_TYPE_UINT16_KHR:
1013         return (T)((uint16_t *)base)[i];
1014     case VK_COMPONENT_TYPE_UINT32_KHR:
1015         return (T)((uint32_t *)base)[i];
1016     case VK_COMPONENT_TYPE_SINT8_KHR:
1017         return (T)((int8_t *)base)[i];
1018     case VK_COMPONENT_TYPE_SINT16_KHR:
1019         return (T)((int16_t *)base)[i];
1020     case VK_COMPONENT_TYPE_SINT32_KHR:
1021         return (T)((int32_t *)base)[i];
1022     case VK_COMPONENT_TYPE_FLOAT32_KHR:
1023     {
1024         float temp = ((float *)base)[i];
1025         if (std::numeric_limits<T>::min() == 0)
1026             temp = std::max(temp, 0.0f);
1027         return (T)temp;
1028     }
1029     case VK_COMPONENT_TYPE_FLOAT16_KHR:
1030     {
1031         float temp = tcu::Float16(((tcu::float16_t *)base)[i]).asFloat();
1032         if (std::numeric_limits<T>::min() == 0)
1033             temp = std::max(temp, 0.0f);
1034         return (T)temp;
1035     }
1036     default:
1037         TCU_THROW(InternalError, "Unsupported type");
1038     }
1039 }
1040 
// Saturating addition for signed integer types: returns a + b clamped to
// [numeric_limits<T>::min(), numeric_limits<T>::max()] instead of wrapping.
template <typename T>
T satAdd(T a, T b)
{
    const T hi = std::numeric_limits<T>::max();
    const T lo = std::numeric_limits<T>::min();

    // Overflow is only possible when both operands pull in the same
    // direction; the comparisons below never overflow themselves.
    if (a > 0 && b > hi - a)
        return hi;
    if (a <= 0 && b < lo - a)
        return lo;

    return static_cast<T>(a + b);
}
1056 
satAddData(VkComponentTypeKHR dt,uint32_t a,uint32_t b)1057 uint32_t satAddData(VkComponentTypeKHR dt, uint32_t a, uint32_t b)
1058 {
1059     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
1060 
1061     switch (dt)
1062     {
1063     case VK_COMPONENT_TYPE_UINT8_KHR:
1064         return deMinu32(a + b, std::numeric_limits<uint8_t>::max());
1065     case VK_COMPONENT_TYPE_UINT16_KHR:
1066         return deMinu32(a + b, std::numeric_limits<uint16_t>::max());
1067     case VK_COMPONENT_TYPE_UINT32_KHR:
1068         return (a + b >= a) ? a + b : std::numeric_limits<uint32_t>::max();
1069     case VK_COMPONENT_TYPE_SINT8_KHR:
1070         return (uint32_t)satAdd((int8_t)a, (int8_t)b);
1071     case VK_COMPONENT_TYPE_SINT16_KHR:
1072         return (uint32_t)satAdd((int16_t)a, (int16_t)b);
1073     case VK_COMPONENT_TYPE_SINT32_KHR:
1074         return (uint32_t)satAdd((int32_t)a, (int32_t)b);
1075     default:
1076         TCU_THROW(InternalError, "Unsupported type");
1077     }
1078 }
1079 
getLimit(VkComponentTypeKHR dt,bool positive)1080 uint32_t getLimit(VkComponentTypeKHR dt, bool positive)
1081 {
1082     DE_ASSERT(componentTypeInfo[dt].bits <= 32);
1083 
1084     switch (dt)
1085     {
1086     case VK_COMPONENT_TYPE_UINT8_KHR:
1087         return uint32_t(positive ? std::numeric_limits<uint8_t>::max() : std::numeric_limits<uint8_t>::min());
1088     case VK_COMPONENT_TYPE_UINT16_KHR:
1089         return uint32_t(positive ? std::numeric_limits<uint16_t>::max() : std::numeric_limits<uint16_t>::min());
1090     case VK_COMPONENT_TYPE_UINT32_KHR:
1091         return uint32_t(positive ? std::numeric_limits<uint32_t>::max() : std::numeric_limits<uint32_t>::min());
1092     case VK_COMPONENT_TYPE_SINT8_KHR:
1093         return uint32_t(positive ? std::numeric_limits<int8_t>::max() : std::numeric_limits<int8_t>::min());
1094     case VK_COMPONENT_TYPE_SINT16_KHR:
1095         return uint32_t(positive ? std::numeric_limits<int16_t>::max() : std::numeric_limits<int16_t>::min());
1096     case VK_COMPONENT_TYPE_SINT32_KHR:
1097         return uint32_t(positive ? std::numeric_limits<int32_t>::max() : std::numeric_limits<int32_t>::min());
1098     default:
1099         TCU_THROW(InternalError, "Unsupported type");
1100     }
1101 }
1102 
setSingleElementInt(void * data,VkComponentTypeKHR dt,uint32_t start,uint32_t count,uint32_t step,uint32_t at,uint32_t val)1103 void setSingleElementInt(void *data, VkComponentTypeKHR dt, uint32_t start, uint32_t count, uint32_t step, uint32_t at,
1104                          uint32_t val)
1105 {
1106     for (uint32_t i = 0; i < count; i++)
1107         setDataInt(data, dt, start + i * step, (i == at) ? val : 0);
1108 }
1109 
1110 #ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
dumpWholeMatrix(void * data,VkComponentTypeKHR dt,bool colMajor,uint32_t matrixElemCount,uint32_t stride)1111 string dumpWholeMatrix(void *data, VkComponentTypeKHR dt, bool colMajor, uint32_t matrixElemCount, uint32_t stride)
1112 {
1113     const uint32_t rowsCount = colMajor ? stride : matrixElemCount / stride;
1114     const uint32_t colsCount = colMajor ? matrixElemCount / stride : stride;
1115     bool floatType           = isFloatType(dt);
1116     bool sIntType            = isSIntType(dt);
1117     std::stringstream ss;
1118 
1119     DE_ASSERT(rowsCount * colsCount == matrixElemCount);
1120 
1121     for (uint32_t r = 0; r < rowsCount; r++)
1122     {
1123         for (uint32_t c = 0; c < colsCount; c++)
1124         {
1125             const uint32_t i = colMajor ? rowsCount * c + r : colsCount * r + c;
1126 
1127             if (floatType)
1128                 ss << getDataFloat(data, dt, i) << "\t";
1129             else if (sIntType)
1130                 ss << (int32_t)getDataInt(data, dt, i) << "\t";
1131             else
1132                 ss << getDataInt(data, dt, i) << "\t";
1133         }
1134 
1135         ss << std::endl;
1136     }
1137 
1138     return ss.str();
1139 }
1140 #endif
1141 
iterate(void)1142 tcu::TestStatus CooperativeMatrixTestInstance::iterate(void)
1143 {
1144     const DeviceInterface &vk = m_context.getDeviceInterface();
1145     const VkDevice device     = m_context.getDevice();
1146     Allocator &allocator      = m_context.getDefaultAllocator();
1147     MemoryRequirement memoryDeviceAddress =
1148         m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER &&
1149                 m_context.isDeviceFunctionalitySupported("VK_KHR_buffer_device_address") ?
1150             MemoryRequirement::DeviceAddress :
1151             MemoryRequirement::Any;
1152     qpTestResult finalres       = QP_TEST_RESULT_NOT_SUPPORTED;
1153     tcu::TestLog &log           = m_context.getTestContext().getLog();
1154     const bool saturated        = (m_data.testType == TT_MATRIXMULADD_SATURATED);
1155     const uint32_t subgroupSize = getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode);
1156     const float epsilon         = 1.0f / float(1ull << 17); // 131072 is epsilon circa 1e-5
1157     vk::VkPhysicalDeviceProperties vkproperties;
1158 
1159     m_context.getInstanceInterface().getPhysicalDeviceProperties(m_context.getPhysicalDevice(), &vkproperties);
1160 
1161     deRandom rnd;
1162     deRandom_init(&rnd, 1234);
1163 
1164     std::vector<VkCooperativeMatrixPropertiesKHR> properties =
1165         getCooperativeMatrixPropertiesConverted(m_context, isKhr(m_data.useType));
1166 
1167     struct TestTuple
1168     {
1169         TestTuple()
1170         {
1171         }
1172         TestTuple(uint32_t m, uint32_t n, uint32_t k) : M(m), N(n), K(k)
1173         {
1174         }
1175 
1176         bool operator<(const TestTuple &other) const
1177         {
1178             return M < other.M || (M == other.M && N < other.N) || (M == other.M && N == other.N && K < other.K);
1179         }
1180 
1181         uint32_t M, N, K;
1182     };
1183 
1184     vector<TestTuple> testSizes;
1185 
1186     if (isMatrixMulAddOp(m_data.testType))
1187     {
1188         for (size_t i = 0; i < properties.size(); ++i)
1189         {
1190             VkCooperativeMatrixPropertiesKHR *p = &properties[i];
1191 
1192             if (p->AType == m_data.inputType && p->BType == m_data.inputType && p->CType == m_data.outputType &&
1193                 p->ResultType == m_data.outputType && p->scope == VK_SCOPE_SUBGROUP_KHR)
1194             {
1195                 testSizes.push_back(TestTuple(p->MSize, p->NSize, p->KSize));
1196             }
1197         }
1198     }
1199     else
1200     {
1201         set<TestTuple> typeSizes[2];
1202         VkComponentTypeKHR types[2] = {m_data.inputType, m_data.outputType};
1203         const bool aType            = (m_data.useType == UT_KHR_A) || (m_data.useType == UT_NV);
1204         const bool bType            = (m_data.useType == UT_KHR_B) || (m_data.useType == UT_NV);
1205         const bool rType            = (m_data.useType == UT_KHR_Result) || (m_data.useType == UT_NV);
1206 
1207         for (uint32_t i = 0; i < properties.size(); ++i)
1208         {
1209             VkCooperativeMatrixPropertiesKHR *p = &properties[i];
1210 
1211             if (p->scope != VK_SCOPE_SUBGROUP_KHR)
1212                 continue;
1213 
1214             for (uint32_t j = 0; j < 2; ++j)
1215             {
1216                 // For these tests, m_data.M/N are always the matrix size. Check if they match
1217                 // any input or output in the list.
1218                 if (aType && p->AType == types[j])
1219                     typeSizes[j].insert(TestTuple(p->MSize, p->KSize, 0));
1220                 if (bType && p->BType == types[j])
1221                     typeSizes[j].insert(TestTuple(p->KSize, p->NSize, 0));
1222                 if (rType && (p->CType == types[j] || p->ResultType == types[j]))
1223                     typeSizes[j].insert(TestTuple(p->MSize, p->NSize, 0));
1224             }
1225         }
1226         // Test those sizes that are supported for both the input and output type.
1227         std::set_intersection(typeSizes[0].begin(), typeSizes[0].end(), typeSizes[1].begin(), typeSizes[1].end(),
1228                               std::back_inserter(testSizes));
1229     }
1230 
1231     properties.resize(0);
1232 
1233     for (unsigned int s = 0; s < testSizes.size(); ++s)
1234     {
1235         // When testing a multiply, MxNxK is the type of matrix multiply.
1236         // Otherwise, MxN is the size of the input/output matrices
1237         uint32_t M, N, K;
1238         M = testSizes[s].M;
1239         N = testSizes[s].N;
1240         K = testSizes[s].K;
1241 
1242         log << tcu::TestLog::Message << "Testing M = " << M << ", N = " << N << ", K = " << K
1243             << tcu::TestLog::EndMessage;
1244 
1245         struct
1246         {
1247             uint32_t rows, cols;
1248         } dims[4];
1249 
1250         if (isMatrixMulAddOp(m_data.testType))
1251         {
1252             dims[0].rows = M;
1253             dims[0].cols = K;
1254             dims[1].rows = K;
1255             dims[1].cols = N;
1256             dims[2].rows = M;
1257             dims[2].cols = N;
1258             dims[3].rows = M;
1259             dims[3].cols = N;
1260         }
1261         else
1262         {
1263             dims[0].rows = M;
1264             dims[0].cols = N;
1265             dims[1].rows = M;
1266             dims[1].cols = N;
1267             dims[2].rows = M;
1268             dims[2].cols = N;
1269             dims[3].rows = M;
1270             dims[3].cols = N;
1271         }
1272 
1273         VkComponentTypeKHR dataTypes[4];
1274         size_t elementSize[4];
1275         VkDeviceSize bufferSizes[5];
1276         de::MovePtr<BufferWithMemory> buffers[5];
1277         vk::VkDescriptorBufferInfo bufferDescriptors[5];
1278         uint32_t strides[4]; // in elements
1279         uint32_t loadStrides[4];
1280         uint32_t totalElements[4];
1281         size_t sharedMemoryUsage[4];
1282         size_t totalSharedMemoryUsage = 0;
1283 
1284         for (uint32_t i = 0; i < 5; ++i)
1285         {
1286             if (i < 4)
1287             {
1288                 // A/B use input type, C/D use output type
1289                 dataTypes[i]   = (i < 2) ? m_data.inputType : m_data.outputType;
1290                 elementSize[i] = componentTypeInfo[dataTypes[i]].bits / 8;
1291 
1292                 strides[i] = (m_data.colMajor ? dims[i].rows : dims[i].cols) * m_data.subgroupsPerWorkgroupX *
1293                              m_data.workgroupsX;
1294                 loadStrides[i]   = strides[i];
1295                 totalElements[i] = strides[i] * (m_data.colMajor ? dims[i].cols : dims[i].rows) *
1296                                    m_data.subgroupsPerWorkgroupY * m_data.workgroupsY;
1297                 sharedMemoryUsage[i] = dims[i].cols * dims[i].rows * m_data.subgroupsPerWorkgroupX *
1298                                        m_data.subgroupsPerWorkgroupY * elementSize[i] *
1299                                        ((i < 2) ? m_data.inputComponentCount : m_data.outputComponentCount);
1300 
1301                 bufferSizes[i] = totalElements[i] * elementSize[i];
1302 
1303                 // Check there is enough shared memory supported
1304                 if ((m_data.useType != UT_NV) &&
1305                     ((m_data.storageClass == SC_WORKGROUP) || (m_data.storageClass == SC_WORKGROUP_VARIABLE_POINTERS)))
1306                 {
1307                     totalSharedMemoryUsage += sharedMemoryUsage[i];
1308                     if (totalSharedMemoryUsage > vkproperties.limits.maxComputeSharedMemorySize)
1309                         throw tcu::NotSupportedError("Not enough shared memory supported.");
1310                 }
1311             }
1312             else
1313             {
1314                 bufferSizes[4] = sizeof(VkDeviceAddress) * 4;
1315             }
1316 
1317             try
1318             {
1319                 buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1320                     vk, device, allocator,
1321                     makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
1322                                                              VK_BUFFER_USAGE_TRANSFER_DST_BIT |
1323                                                              VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
1324                                                              (memoryDeviceAddress == MemoryRequirement::DeviceAddress ?
1325                                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT :
1326                                                                   0)),
1327                     MemoryRequirement::HostVisible | MemoryRequirement::Cached | MemoryRequirement::Coherent |
1328                         memoryDeviceAddress));
1329             }
1330             catch (const tcu::NotSupportedError &)
1331             {
1332                 buffers[i] = de::MovePtr<BufferWithMemory>(new BufferWithMemory(
1333                     vk, device, allocator,
1334                     makeBufferCreateInfo(bufferSizes[i], VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
1335                                                              VK_BUFFER_USAGE_TRANSFER_DST_BIT |
1336                                                              VK_BUFFER_USAGE_TRANSFER_SRC_BIT |
1337                                                              (memoryDeviceAddress == MemoryRequirement::DeviceAddress ?
1338                                                                   VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT :
1339                                                                   0)),
1340                     MemoryRequirement::HostVisible | memoryDeviceAddress));
1341             }
1342 
1343             bufferDescriptors[i] = makeDescriptorBufferInfo(**buffers[i], 0, bufferSizes[i]);
1344         }
1345 
1346         // Load with a stride of 0
1347         if (m_data.testType == TT_MATRIXMULADD_STRIDE0)
1348             loadStrides[0] = loadStrides[1] = loadStrides[2] = loadStrides[3] = 0;
1349 
1350         void *ptrs[5];
1351         for (uint32_t i = 0; i < 5; ++i)
1352         {
1353             ptrs[i] = buffers[i]->getAllocation().getHostPtr();
1354         }
1355 
1356         vk::DescriptorSetLayoutBuilder layoutBuilder;
1357 
1358         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1359         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1360         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1361         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1362         layoutBuilder.addSingleBinding(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, allShaderStages);
1363 
1364         vk::Unique<vk::VkDescriptorSetLayout> descriptorSetLayout(layoutBuilder.build(vk, device));
1365 
1366         vk::Unique<vk::VkDescriptorPool> descriptorPool(
1367             vk::DescriptorPoolBuilder()
1368                 .addType(VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 5u)
1369                 .build(vk, device, VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u));
1370         vk::Unique<vk::VkDescriptorSet> descriptorSet(
1371             makeDescriptorSet(vk, device, *descriptorPool, *descriptorSetLayout));
1372 
1373         vk::DescriptorSetUpdateBuilder setUpdateBuilder;
1374         if (m_data.storageClass == SC_PHYSICAL_STORAGE_BUFFER)
1375         {
1376             VkBufferDeviceAddressInfo info{
1377                 VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, // VkStructureType  sType;
1378                 DE_NULL,                                      // const void*  pNext;
1379                 0,                                            // VkBuffer            buffer
1380             };
1381             VkDeviceAddress *addrsInMemory = (VkDeviceAddress *)ptrs[4];
1382             for (uint32_t i = 0; i < 4; ++i)
1383             {
1384                 info.buffer          = **buffers[i];
1385                 VkDeviceAddress addr = vk.getBufferDeviceAddress(device, &info);
1386                 addrsInMemory[i]     = addr;
1387             }
1388             setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(4),
1389                                          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[4]);
1390         }
1391         else
1392         {
1393             setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(0),
1394                                          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[0]);
1395             setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(1),
1396                                          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[1]);
1397             setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(2),
1398                                          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[2]);
1399             setUpdateBuilder.writeSingle(*descriptorSet, vk::DescriptorSetUpdateBuilder::Location::binding(3),
1400                                          VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &bufferDescriptors[3]);
1401         }
1402 
1403         setUpdateBuilder.update(vk, device);
1404 
1405         VkPipelineBindPoint bindPoint = VK_PIPELINE_BIND_POINT_COMPUTE;
1406 
1407         const uint32_t specData[9] = {
1408             subgroupSize * m_data.subgroupsPerWorkgroupX,
1409             m_data.subgroupsPerWorkgroupY,
1410             strides[0],
1411             strides[1],
1412             strides[2],
1413             strides[3],
1414             M,
1415             N,
1416             K,
1417         };
1418 
1419         const vk::VkSpecializationMapEntry entries[9] = {
1420             {0, (uint32_t)(sizeof(uint32_t) * 0), sizeof(uint32_t)},
1421             {1, (uint32_t)(sizeof(uint32_t) * 1), sizeof(uint32_t)},
1422             {2, (uint32_t)(sizeof(uint32_t) * 2), sizeof(uint32_t)},
1423             {3, (uint32_t)(sizeof(uint32_t) * 3), sizeof(uint32_t)},
1424             {4, (uint32_t)(sizeof(uint32_t) * 4), sizeof(uint32_t)},
1425             {5, (uint32_t)(sizeof(uint32_t) * 5), sizeof(uint32_t)},
1426             {6, (uint32_t)(sizeof(uint32_t) * 6), sizeof(uint32_t)},
1427             {7, (uint32_t)(sizeof(uint32_t) * 7), sizeof(uint32_t)},
1428             {8, (uint32_t)(sizeof(uint32_t) * 8), sizeof(uint32_t)},
1429         };
1430 
1431         const vk::VkSpecializationInfo specInfo = {
1432             9,                // mapEntryCount
1433             entries,          // pMapEntries
1434             sizeof(specData), // dataSize
1435             specData          // pData
1436         };
1437 
1438         for (uint32_t i = 0; i < 4; ++i)
1439             for (uint32_t j = 0; j < totalElements[i]; ++j)
1440             {
1441                 if (isFloatType(dataTypes[i]))
1442                 {
1443                     if (!isMatrixMulAddOp(m_data.testType))
1444                         setDataFloat(ptrs[i], dataTypes[i], j,
1445                                      ((float)(deRandom_getUint32(&rnd) & 0xff) - 64.0f) / 2.0f);
1446                     else
1447                         setDataFloat(ptrs[i], dataTypes[i], j, ((float)(deRandom_getUint32(&rnd) & 0xf) - 4.0f) / 2.0f);
1448                 }
1449                 else
1450                 {
1451                     if (m_data.testType == TT_MATRIXMULADD_WRAPPING)
1452                     {
1453                         // Choose matrix values that should cause overflow and underflow, to
1454                         // verify wrapping behavior. Use the full range of values for A and B.
1455                         // For matrix C, use values clustered near where the type wraps (zero
1456                         // for unsigned, 2^(N-1) for signed).
1457                         uint32_t bits = componentTypeInfo[dataTypes[i]].bits;
1458                         uint32_t value;
1459                         if (i == 2)
1460                         {
1461                             value = (deRandom_getUint32(&rnd) & 0xff) - 128;
1462                             if (componentTypeInfo[dataTypes[i]].isSigned)
1463                                 value += (1U << (bits - 1));
1464                         }
1465                         else
1466                         {
1467                             uint32_t mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
1468                             value         = deRandom_getUint32(&rnd) & mask;
1469                         }
1470                         setDataInt(ptrs[i], dataTypes[i], j, value);
1471                     }
1472                     else if (m_data.testType == TT_MATRIXMULADD_SATURATED)
1473                     {
1474                         setDataInt(ptrs[i], dataTypes[i], j, 0);
1475                     }
1476                     else
1477                     {
1478                         uint32_t value = (deRandom_getUint32(&rnd) & 0xff) - 128;
1479                         setDataInt(ptrs[i], dataTypes[i], j, value);
1480                     }
1481                 }
1482             }
1483 
1484         if (m_data.testType == TT_MATRIXMULADD_SATURATED)
1485         {
1486             // Set 1st row of A to 1,0,0...
1487             setSingleElementInt(ptrs[0], dataTypes[0], 0, dims[0].cols, (m_data.colMajor ? strides[0] : 1), 0, 1);
1488 
1489             // Set 1st column of B to 1,0,0...
1490             setSingleElementInt(ptrs[1], dataTypes[1], 0, dims[1].rows, (m_data.colMajor ? 1 : strides[1]), 0, 1);
1491 
1492             // Set C element at {0,0} to maximum type value, thus we will have overflow at plus operation in D=A*B+C for this element
1493             setDataInt(ptrs[2], dataTypes[2], 0, getLimit(dataTypes[2], true));
1494 
1495             // Check underflow if all involved elements support negative values
1496             if (isSIntType(dataTypes[1]) && isSIntType(dataTypes[2]) && isSIntType(dataTypes[3]))
1497             {
1498                 // Set 2nd row of A to 0,1,0,0...
1499                 setSingleElementInt(ptrs[0], dataTypes[0], (m_data.colMajor ? 1 : strides[0]), dims[0].cols,
1500                                     (m_data.colMajor ? strides[0] : 1), 1, 1);
1501 
1502                 // Set 2nd column of B to 0,-1,0,0...
1503                 setSingleElementInt(ptrs[1], dataTypes[1], (m_data.colMajor ? strides[1] : 1), dims[1].rows,
1504                                     (m_data.colMajor ? 1 : strides[1]), 1, -1);
1505 
1506                 // Set C element at {1,1} to minimum type value, thus we will have underflow at plus operation in D=A*B+C for this element
1507                 setDataInt(ptrs[2], dataTypes[2], strides[2] + 1, getLimit(dataTypes[2], false));
1508             }
1509         }
1510 
1511         flushAlloc(vk, device, buffers[0]->getAllocation());
1512         flushAlloc(vk, device, buffers[1]->getAllocation());
1513         flushAlloc(vk, device, buffers[2]->getAllocation());
1514         flushAlloc(vk, device, buffers[3]->getAllocation());
1515 
1516         ComputePipelineWrapper pipeline(vk, device, m_data.computePipelineConstructionType,
1517                                         m_context.getBinaryCollection().get("test"));
1518         pipeline.setDescriptorSetLayout(descriptorSetLayout.get());
1519         pipeline.setSpecializationInfo(specInfo);
1520         pipeline.setSubgroupSize(m_data.subgroupSizeMode == SUBGROUP_SIZE_NONE ?
1521                                      0 :
1522                                      getSubgroupSizeFromMode(m_context, m_data.subgroupSizeMode));
1523         pipeline.buildPipeline();
1524 
1525         const VkQueue queue             = m_context.getUniversalQueue();
1526         Move<VkCommandPool> cmdPool     = createCommandPool(vk, device, 0, m_context.getUniversalQueueFamilyIndex());
1527         Move<VkCommandBuffer> cmdBuffer = allocateCommandBuffer(vk, device, *cmdPool, VK_COMMAND_BUFFER_LEVEL_PRIMARY);
1528 
1529         beginCommandBuffer(vk, *cmdBuffer, 0u);
1530 
1531         vk.cmdBindDescriptorSets(*cmdBuffer, bindPoint, pipeline.getPipelineLayout(), 0u, 1, &*descriptorSet, 0u,
1532                                  DE_NULL);
1533         pipeline.bind(*cmdBuffer);
1534 
1535         vk.cmdDispatch(*cmdBuffer, m_data.workgroupsX, m_data.workgroupsY, 1);
1536 
1537         const VkMemoryBarrier barrier = {
1538             VK_STRUCTURE_TYPE_MEMORY_BARRIER, // sType
1539             nullptr,                          // pNext
1540             VK_ACCESS_SHADER_WRITE_BIT,       // srcAccessMask
1541             VK_ACCESS_HOST_READ_BIT,          // dstAccessMask
1542         };
1543         vk.cmdPipelineBarrier(*cmdBuffer, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, VK_PIPELINE_STAGE_HOST_BIT,
1544                               (VkDependencyFlags)0, 1, &barrier, 0, nullptr, 0, nullptr);
1545 
1546         endCommandBuffer(vk, *cmdBuffer);
1547 
1548         submitCommandsAndWait(vk, device, queue, cmdBuffer.get());
1549 
1550         invalidateAlloc(vk, device, buffers[3]->getAllocation());
1551 
1552         qpTestResult res = QP_TEST_RESULT_PASS;
1553 
1554         if (m_data.testType == TT_CONVERT)
1555         {
1556             for (uint32_t i = 0; i < totalElements[3]; ++i)
1557             {
1558                 // Store results as double, which has enough range to hold all the other types exactly.
1559                 double inputA, output;
1560 
1561                 // This loads the data according to dataTypes[0], and then converts to the template parameter type
1562                 switch (dataTypes[3])
1563                 {
1564                 case VK_COMPONENT_TYPE_UINT8_KHR:
1565                     inputA = getDataConvertedToT<uint8_t>(ptrs[0], dataTypes[0], i);
1566                     break;
1567                 case VK_COMPONENT_TYPE_UINT16_KHR:
1568                     inputA = getDataConvertedToT<uint16_t>(ptrs[0], dataTypes[0], i);
1569                     break;
1570                 case VK_COMPONENT_TYPE_UINT32_KHR:
1571                     inputA = getDataConvertedToT<uint32_t>(ptrs[0], dataTypes[0], i);
1572                     break;
1573                 case VK_COMPONENT_TYPE_SINT8_KHR:
1574                     inputA = getDataConvertedToT<int8_t>(ptrs[0], dataTypes[0], i);
1575                     break;
1576                 case VK_COMPONENT_TYPE_SINT16_KHR:
1577                     inputA = getDataConvertedToT<int16_t>(ptrs[0], dataTypes[0], i);
1578                     break;
1579                 case VK_COMPONENT_TYPE_SINT32_KHR:
1580                     inputA = getDataConvertedToT<int32_t>(ptrs[0], dataTypes[0], i);
1581                     break;
1582                 case VK_COMPONENT_TYPE_FLOAT32_KHR:
1583                     inputA = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
1584                     break;
1585                 case VK_COMPONENT_TYPE_FLOAT16_KHR:
1586                 {
1587                     float temp = getDataConvertedToT<float>(ptrs[0], dataTypes[0], i);
1588                     inputA     = tcu::Float16(temp).asDouble();
1589                     break;
1590                 }
1591                 default:
1592                     TCU_THROW(InternalError, "Unexpected type");
1593                 }
1594 
1595                 switch (dataTypes[3])
1596                 {
1597                 case VK_COMPONENT_TYPE_UINT8_KHR:
1598                     output = getDataConvertedToT<uint8_t>(ptrs[3], dataTypes[3], i);
1599                     break;
1600                 case VK_COMPONENT_TYPE_UINT16_KHR:
1601                     output = getDataConvertedToT<uint16_t>(ptrs[3], dataTypes[3], i);
1602                     break;
1603                 case VK_COMPONENT_TYPE_UINT32_KHR:
1604                     output = getDataConvertedToT<uint32_t>(ptrs[3], dataTypes[3], i);
1605                     break;
1606                 case VK_COMPONENT_TYPE_SINT8_KHR:
1607                     output = getDataConvertedToT<int8_t>(ptrs[3], dataTypes[3], i);
1608                     break;
1609                 case VK_COMPONENT_TYPE_SINT16_KHR:
1610                     output = getDataConvertedToT<int16_t>(ptrs[3], dataTypes[3], i);
1611                     break;
1612                 case VK_COMPONENT_TYPE_SINT32_KHR:
1613                     output = getDataConvertedToT<int32_t>(ptrs[3], dataTypes[3], i);
1614                     break;
1615                 case VK_COMPONENT_TYPE_FLOAT32_KHR:
1616                     output = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
1617                     break;
1618                 case VK_COMPONENT_TYPE_FLOAT16_KHR:
1619                 {
1620                     float temp = getDataConvertedToT<float>(ptrs[3], dataTypes[3], i);
1621                     output     = tcu::Float16(temp).asDouble();
1622                     break;
1623                 }
1624                 default:
1625                     TCU_THROW(InternalError, "Unexpected type");
1626                 }
1627 
1628                 if (inputA != output)
1629                 {
1630                     res = QP_TEST_RESULT_FAIL;
1631                     break;
1632                 }
1633             }
1634         }
1635         else if (isFloatType(dataTypes[0]))
1636         {
1637             if (!isMatrixMulAddOp(m_data.testType))
1638             {
1639                 for (uint32_t i = 0; i < totalElements[3]; ++i)
1640                 {
1641                     float inputA = getDataFloat(ptrs[0], dataTypes[0], i);
1642                     float inputB = getDataFloat(ptrs[1], dataTypes[1], i);
1643                     float output = getDataFloat(ptrs[3], dataTypes[3], i);
1644                     switch (m_data.testType)
1645                     {
1646                     case TT_LENGTH:
1647                         if (output < 1.0f || output > (float)(N * M))
1648                             res = QP_TEST_RESULT_FAIL;
1649                         // We expect the matrix to be spread evenly across invocations, it is
1650                         // surprising (but not necessarily illegal) if not
1651                         if (output != (float)(N * M / subgroupSize) && res == QP_TEST_RESULT_PASS)
1652                             res = QP_TEST_RESULT_QUALITY_WARNING;
1653                         break;
1654                     case TT_CONSTANT:
1655                         if (output != 1.0f)
1656                             res = QP_TEST_RESULT_FAIL;
1657                         break;
1658                     case TT_COMPOSITE:
1659                     case TT_COMPOSITE_RVALUE:
1660                     case TT_COMPOSITE_ARRAY:
1661                     case TT_ADD:
1662                         if (output != inputA + inputB)
1663                             res = QP_TEST_RESULT_FAIL;
1664                         break;
1665                     case TT_SUB:
1666                         if (output != inputA - inputB)
1667                             res = QP_TEST_RESULT_FAIL;
1668                         break;
1669                     case TT_DIV:
1670                     {
1671                         float ulp = (m_data.inputType == VK_COMPONENT_TYPE_FLOAT16_KHR) ?
1672                                         1.0f / 1024.0f :
1673                                         1.0f / (8.0f * 1024.0f * 1024.0f);
1674                         // division allows 2.5ulp, but we'll use 3.
1675                         ulp *= 3;
1676                         if (inputB != 0 && fabs(output - inputA / inputB) > ulp * fabs(inputA / inputB))
1677                             res = QP_TEST_RESULT_FAIL;
1678                     }
1679                     break;
1680                     case TT_MUL:
1681                     {
1682                         if (dataTypes[0] == VK_COMPONENT_TYPE_FLOAT16_KHR)
1683                         {
1684                             const float expected32          = inputA * inputB;
1685                             const tcu::float16_t expected16 = tcu::Float16(expected32).bits();
1686                             const float expected            = tcu::Float16(expected16).asFloat();
1687 
1688                             if (output != expected)
1689                                 res = QP_TEST_RESULT_FAIL;
1690                         }
1691                         else
1692                         {
1693                             if (output != inputA * inputB)
1694                                 res = QP_TEST_RESULT_FAIL;
1695                         }
1696                         break;
1697                     }
1698                     case TT_NEGATE:
1699                     case TT_FUNC:
1700                         if (output != -inputA)
1701                             res = QP_TEST_RESULT_FAIL;
1702                         break;
1703                     case TT_MATRIXTIMESSCALAR:
1704                         if (output != 6.0 * inputA)
1705                             res = QP_TEST_RESULT_FAIL;
1706                         break;
1707                     case TT_MULTICOMPONENT_LOAD:
1708                     {
1709                         if (output != inputA)
1710                             res = QP_TEST_RESULT_FAIL;
1711                         break;
1712                     }
1713                     case TT_MULTICOMPONENT_SAVE:
1714                     {
1715                         if (output != inputA)
1716                             res = QP_TEST_RESULT_FAIL;
1717                         break;
1718                     }
1719                     default:
1720                         TCU_THROW(InternalError, "Unimplemented");
1721                     }
1722                 }
1723             }
1724             else
1725             {
1726                 uint32_t ik, kj, ij;
1727                 for (uint32_t mX = 0; mX < m_data.subgroupsPerWorkgroupX * m_data.workgroupsX; ++mX)
1728                 {
1729                     for (uint32_t mY = 0; mY < m_data.subgroupsPerWorkgroupY * m_data.workgroupsY; ++mY)
1730                     {
1731                         for (uint32_t i = 0; i < M; ++i)
1732                         {
1733                             for (uint32_t j = 0; j < N; ++j)
1734                             {
1735                                 float ref = 0;
1736                                 for (uint32_t k = 0; k < K; ++k)
1737                                 {
1738                                     if (m_data.colMajor)
1739                                         ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
1740                                     else
1741                                         ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;
1742 
1743                                     float Aik = getDataFloat(ptrs[0], dataTypes[0], ik);
1744 
1745                                     if (m_data.colMajor)
1746                                         kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
1747                                     else
1748                                         kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;
1749 
1750                                     float Bkj = getDataFloat(ptrs[1], dataTypes[1], kj);
1751 
1752                                     ref += Aik * Bkj;
1753                                 }
1754 
1755                                 if (m_data.colMajor)
1756                                     ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
1757                                 else
1758                                     ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;
1759 
1760                                 float Cij = getDataFloat(ptrs[2], dataTypes[2], ij);
1761 
1762                                 ref += Cij;
1763 
1764                                 // When loading with stride 0, ij for matrix D is different from matrix C
1765                                 if (m_data.colMajor)
1766                                     ij = mX * M + i + strides[2] * (mY * N + j);
1767                                 else
1768                                     ij = mX * N + j + strides[2] * (mY * M + i);
1769 
1770                                 float Dij = getDataFloat(ptrs[3], dataTypes[3], ij);
1771 
1772                                 if (fabs(ref - Dij) > epsilon)
1773                                 {
1774                                     res = QP_TEST_RESULT_FAIL;
1775                                 }
1776                             }
1777                         }
1778                     }
1779                 }
1780             }
1781         }
1782         else
1783         {
1784             if (!isMatrixMulAddOp(m_data.testType))
1785             {
1786                 for (uint32_t i = 0; i < totalElements[3]; ++i)
1787                 {
1788                     uint32_t inputA = getDataInt(ptrs[0], dataTypes[0], i);
1789                     uint32_t inputB = getDataInt(ptrs[1], dataTypes[1], i);
1790                     uint32_t output = getDataInt(ptrs[3], dataTypes[3], i);
1791                     int resultSize  = componentTypeInfo[dataTypes[3]].bits;
1792                     uint32_t mask   = resultSize == 32 ? ~0 : ((1 << resultSize) - 1);
1793                     switch (m_data.testType)
1794                     {
1795                     case TT_LENGTH:
1796                         if (output < 1 || output > N * M)
1797                             res = QP_TEST_RESULT_FAIL;
1798                         // We expect the matrix to be spread evenly across invocations, it is
1799                         // surprising (but not necessarily illegal) if not
1800                         if (output != N * M / subgroupSize && res == QP_TEST_RESULT_PASS)
1801                             res = QP_TEST_RESULT_QUALITY_WARNING;
1802                         break;
1803                     case TT_CONSTANT:
1804                         if (output != 1)
1805                             res = QP_TEST_RESULT_FAIL;
1806                         break;
1807                     case TT_COMPOSITE:
1808                     case TT_COMPOSITE_RVALUE:
1809                     case TT_COMPOSITE_ARRAY:
1810                     case TT_ADD:
1811                         if ((output & mask) != ((inputA + inputB) & mask))
1812                         {
1813                             res = QP_TEST_RESULT_FAIL;
1814                         }
1815                         break;
1816                     case TT_SUB:
1817                         if ((output & mask) != ((inputA - inputB) & mask))
1818                             res = QP_TEST_RESULT_FAIL;
1819                         break;
1820                     case TT_DIV:
1821                     {
1822                         if (isSIntType(dataTypes[3]))
1823                         {
1824                             if (inputB != 0 && ((int32_t)output & mask) != (((int32_t)inputA / (int32_t)inputB) & mask))
1825                                 res = QP_TEST_RESULT_FAIL;
1826                         }
1827                         else
1828                         {
1829                             if (inputB != 0 && output != inputA / inputB)
1830                                 res = QP_TEST_RESULT_FAIL;
1831                         }
1832                     }
1833                     break;
1834                     case TT_MUL:
1835                     {
1836                         if (((int32_t)output & mask) != (((int32_t)inputA * (int32_t)inputB) & mask))
1837                         {
1838                             res = QP_TEST_RESULT_FAIL;
1839                         }
1840 
1841                         break;
1842                     }
1843                     case TT_NEGATE:
1844                     case TT_FUNC:
1845                         if ((output & mask) != ((-(int32_t)inputA) & mask))
1846                             res = QP_TEST_RESULT_FAIL;
1847                         break;
1848                     case TT_MATRIXTIMESSCALAR:
1849                         if ((output & mask) != ((6 * inputA) & mask))
1850                         {
1851                             res = QP_TEST_RESULT_FAIL;
1852                         }
1853                         break;
1854                     case TT_MULTICOMPONENT_LOAD:
1855                     {
1856                         if (output != inputA)
1857                             res = QP_TEST_RESULT_FAIL;
1858                         break;
1859                     }
1860                     case TT_MULTICOMPONENT_SAVE:
1861                     {
1862                         if (output != inputA)
1863                             res = QP_TEST_RESULT_FAIL;
1864                         break;
1865                     }
1866                     default:
1867                         TCU_THROW(InternalError, "Unimplemented");
1868                     }
1869                 }
1870             }
1871             else
1872             {
1873                 uint32_t ik, kj, ij;
1874                 for (uint32_t mX = 0; mX < m_data.subgroupsPerWorkgroupX * m_data.workgroupsX; ++mX)
1875                 {
1876                     for (uint32_t mY = 0; mY < m_data.subgroupsPerWorkgroupY * m_data.workgroupsY; ++mY)
1877                     {
1878                         for (uint32_t i = 0; i < M; ++i)
1879                         {
1880                             for (uint32_t j = 0; j < N; ++j)
1881                             {
1882                                 uint32_t ref = 0;
1883 
1884                                 for (uint32_t k = 0; k < K; ++k)
1885                                 {
1886                                     if (m_data.colMajor)
1887                                         ik = mX * M + i + strides[0] * mY * K + loadStrides[0] * k;
1888                                     else
1889                                         ik = mX * K + k + strides[0] * mY * M + loadStrides[0] * i;
1890 
1891                                     uint32_t Aik = getDataInt(ptrs[0], dataTypes[0], ik);
1892 
1893                                     if (m_data.colMajor)
1894                                         kj = mX * K + k + strides[1] * mY * N + loadStrides[1] * j;
1895                                     else
1896                                         kj = mX * N + j + strides[1] * mY * K + loadStrides[1] * k;
1897 
1898                                     uint32_t Bkj = getDataInt(ptrs[1], dataTypes[1], kj);
1899 
1900                                     ref += Aik * Bkj;
1901                                 }
1902 
1903                                 if (m_data.colMajor)
1904                                     ij = mX * M + i + strides[2] * mY * N + loadStrides[2] * j;
1905                                 else
1906                                     ij = mX * N + j + strides[2] * mY * M + loadStrides[2] * i;
1907 
1908                                 uint32_t Cij = getDataInt(ptrs[2], dataTypes[2], ij);
1909 
1910                                 if (saturated)
1911                                 {
1912                                     ref = satAddData(dataTypes[2], ref, Cij);
1913                                 }
1914                                 else
1915                                 {
1916                                     ref += Cij;
1917                                     // truncate the result to the size of C's type.
1918                                     uint32_t bits = componentTypeInfo[dataTypes[3]].bits;
1919                                     uint32_t mask = (bits == 32) ? 0xFFFFFFFFU : ((1U << bits) - 1U);
1920                                     ref &= mask;
1921                                 }
1922 
1923                                 // When loading with stride 0, ij for matrix D is different from matrix C
1924                                 if (m_data.colMajor)
1925                                     ij = mX * M + i + strides[2] * (mY * N + j);
1926                                 else
1927                                     ij = mX * N + j + strides[2] * (mY * M + i);
1928 
1929                                 uint32_t Dij = getDataInt(ptrs[3], dataTypes[3], ij);
1930 
1931                                 if (ref != Dij)
1932                                 {
1933                                     res = QP_TEST_RESULT_FAIL;
1934                                 }
1935                             }
1936                         }
1937                     }
1938                 }
1939             }
1940         }
1941 
1942         if (res != QP_TEST_RESULT_PASS)
1943         {
1944             finalres = res;
1945 
1946             log << tcu::TestLog::Message << "failed with M = " << M << ", N = " << N << ", K = " << K
1947                 << tcu::TestLog::EndMessage;
1948 
1949 #ifdef COOPERATIVE_MATRIX_EXTENDED_DEBUG
1950             for (int i = 0; i < 4; i++)
1951             {
1952                 const char *matrixNames[] = {"A", "B", "C", "D"};
1953 
1954                 log << tcu::TestLog::Message << "Matrix " << matrixNames[i]
1955                     << "[rows=" << m_data.subgroupsPerWorkgroupY * m_data.workgroupsY * dims[i].rows
1956                     << ", cols=" << m_data.subgroupsPerWorkgroupX * m_data.workgroupsX * dims[i].cols << "]:\n"
1957                     << dumpWholeMatrix(ptrs[i], dataTypes[i], m_data.colMajor, totalElements[i], strides[i])
1958                     << tcu::TestLog::EndMessage;
1959             }
1960 #endif
1961         }
1962         else
1963         {
1964             if (finalres == QP_TEST_RESULT_NOT_SUPPORTED)
1965                 finalres = res;
1966         }
1967     }
1968 
1969     return tcu::TestStatus(finalres, qpGetTestResultName(finalres));
1970 }
1971 
getUseType(UseType useType)1972 const char *getUseType(UseType useType)
1973 {
1974     switch (useType)
1975     {
1976     case UT_NV:
1977         return "nv";
1978     case UT_KHR_A:
1979         return "khr_a";
1980     case UT_KHR_B:
1981         return "khr_b";
1982     case UT_KHR_Result:
1983         return "khr_r";
1984     default:
1985         TCU_THROW(InternalError, "Unknown use type");
1986     }
1987 }
1988 
createCooperativeMatrixTestsInternal(tcu::TestContext & testCtx,vk::ComputePipelineConstructionType computePipelineConstructionType,UseType useType)1989 tcu::TestCaseGroup *createCooperativeMatrixTestsInternal(
1990     tcu::TestContext &testCtx, vk::ComputePipelineConstructionType computePipelineConstructionType, UseType useType)
1991 {
1992     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, getUseType(useType)));
1993 
1994     typedef struct
1995     {
1996         uint32_t value;
1997         const char *name;
1998     } TestGroupCase;
1999 
2000     typedef struct
2001     {
2002         uint32_t value[2];
2003         const char *name;
2004     } TestGroupCase2;
2005 
2006     typedef struct
2007     {
2008         SubgroupSizeMode value;
2009         const char *name;
2010     } SubGroubSizes;
2011 
2012     typedef struct
2013     {
2014         const char *name;
2015         const char *description;
2016         uint32_t componentCount;
2017     } MulticomponentTypes;
2018 
2019     typedef struct
2020     {
2021         const char *name;
2022         const char *description;
2023         TestType testType;
2024     } IOTypes;
2025 
2026     TestGroupCase ttCases[] = {
2027         // OpCooperativeMatrixLength
2028         {TT_LENGTH, "length"},
2029         // OpConstantComposite
2030         {TT_CONSTANT, "constant"},
2031         // OpCompositeConstruct
2032         {TT_COMPOSITE, "composite"},
2033         // OpCompositeExtract
2034         {TT_COMPOSITE_RVALUE, "composite_rvalue"},
2035         // OpFAdd/OpIAdd
2036         {TT_ADD, "add"},
2037         // OpFSub/OpISub
2038         {TT_SUB, "sub"},
2039         // OpFDiv/OpSDiv/OpUDiv
2040         {TT_DIV, "div"},
2041         // OpFMul/OpIMul
2042         {TT_MUL, "mul"},
2043         // OpFNegate/OpSNegate
2044         {TT_NEGATE, "negate"},
2045         // OpMatrixTimesScalar
2046         {TT_MATRIXTIMESSCALAR, "matrixtimesscalar"},
2047         // OpFunctionParameter
2048         {TT_FUNC, "func"},
2049         // OpCooperativeMatrixMulAdd
2050         {TT_MATRIXMULADD, "matrixmuladd"},
2051         // OpCompositeConstruct w/array
2052         {TT_COMPOSITE_ARRAY, "composite_array"},
2053         // OpCooperativeMatrixMulAdd w/array
2054         {TT_MATRIXMULADD_ARRAY, "matrixmuladd_array"},
2055         // OpCooperativeMatrixMulAdd w/saturations
2056         {TT_MATRIXMULADD_SATURATED, "matrixmuladd_saturated"},
2057         // OpCooperativeMatrixMulAdd w/wrapping
2058         {TT_MATRIXMULADD_WRAPPING, "matrixmuladd_wrapping"},
2059         // OpCooperativeMatrixMulAdd w/stride==0
2060         {TT_MATRIXMULADD_STRIDE0, "matrixmuladd_stride0"},
2061     };
2062     TestGroupCase2 dtCases[] = {
2063         // A/B are fp32 C/D are fp32
2064         {{VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR}, "float32_float32"},
2065         // A/B are fp32 C/D are fp16
2066         {{VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR}, "float32_float16"},
2067         // A/B are fp16 C/D are fp32
2068         {{VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR}, "float16_float32"},
2069         // A/B are fp16 C/D are fp16
2070         {{VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT16_KHR}, "float16_float16"},
2071         // A/B are u8 C/D are u8
2072         {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT8_KHR}, "uint8_uint8"},
2073         // A/B are u8 C/D are u32
2074         {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_UINT32_KHR}, "uint8_uint32"},
2075         // A/B are s8 C/D are s8
2076         {{VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT8_KHR}, "sint8_sint8"},
2077         // A/B are s8 C/D are s32
2078         {{VK_COMPONENT_TYPE_SINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "sint8_sint32"},
2079         // A/B are u8 C/D are s32
2080         {{VK_COMPONENT_TYPE_UINT8_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "uint8_sint32"},
2081         // A/B are u32 C/D are u32
2082         {{VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT32_KHR}, "uint32_uint32"},
2083         // A/B are u32 C/D are u8
2084         {{VK_COMPONENT_TYPE_UINT32_KHR, VK_COMPONENT_TYPE_UINT8_KHR}, "uint32_uint8"},
2085         // A/B are s32 C/D are s32
2086         {{VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT32_KHR}, "sint32_sint32"},
2087         // A/B are s32 C/D are s8
2088         {{VK_COMPONENT_TYPE_SINT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR}, "sint32_sint8"},
2089     };
2090     SubGroubSizes sgsCases[] = {
2091         // Default subgroup size
2092         {SUBGROUP_SIZE_NONE, ""},
2093         // Minimum subgroup size
2094         {SUBGROUP_SIZE_MIN, "_min"},
2095         // Maximum subgroup size
2096         {SUBGROUP_SIZE_MAX, "_max"},
2097     };
2098 
2099     TestGroupCase colCases[] = {
2100         {0, "rowmajor"},
2101         {1, "colmajor"},
2102     };
2103 
2104     TestGroupCase scCases[] = {
2105         // SSBO
2106         {SC_BUFFER, "buffer"},
2107         // shared memory
2108         {SC_WORKGROUP, "workgroup"},
2109         // SSBO w/variable pointers
2110         {SC_BUFFER_VARIABLE_POINTERS, "buffer_varptr"},
2111         // shared memory w/variable pointers
2112         {SC_WORKGROUP_VARIABLE_POINTERS, "workgroup_varptr"},
2113         // physical_storage_buffer
2114         {SC_PHYSICAL_STORAGE_BUFFER, "physical_buffer"},
2115     };
2116 
2117     // Types tested for conversions. Excludes 64b types.
2118     VkComponentTypeKHR allTypes[] = {
2119         VK_COMPONENT_TYPE_FLOAT16_KHR, VK_COMPONENT_TYPE_FLOAT32_KHR, VK_COMPONENT_TYPE_SINT8_KHR,
2120         VK_COMPONENT_TYPE_SINT16_KHR,  VK_COMPONENT_TYPE_SINT32_KHR,  VK_COMPONENT_TYPE_UINT8_KHR,
2121         VK_COMPONENT_TYPE_UINT16_KHR,  VK_COMPONENT_TYPE_UINT32_KHR,
2122     };
2123 
2124     // Types tested for load/store from/into multicomponent types
2125     MulticomponentTypes multicomponentTypes[] = {
2126         {"vec2", "2-component vector type as input or output", 2},
2127         {"vec4", "4-component vector type as input or output", 4},
2128     };
2129 
2130     // Types tested for load/store from/into multicomponent types
2131     IOTypes ioTypes[] = {
2132         {"load", "Test multicomponent type as input in load operation", TT_MULTICOMPONENT_LOAD},
2133         {"save", "Test multicomponent type as output in store operation", TT_MULTICOMPONENT_SAVE},
2134     };
2135 
2136     for (int ttNdx = 0; ttNdx < DE_LENGTH_OF_ARRAY(ttCases); ttNdx++)
2137     {
2138         const TestType testType = (TestType)ttCases[ttNdx].value;
2139 
2140         for (int sgsNdx = 0; sgsNdx < DE_LENGTH_OF_ARRAY(sgsCases); sgsNdx++)
2141         {
2142             if (testType != TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE)
2143                 continue;
2144 
2145             if (testType == TT_MATRIXMULADD && sgsCases[sgsNdx].value != SUBGROUP_SIZE_NONE && useType == UT_NV)
2146                 continue;
2147 
2148             const string name = string(ttCases[ttNdx].name) + sgsCases[sgsNdx].name;
2149             de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, name.c_str()));
2150 
2151             for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(dtCases); dtNdx++)
2152             {
2153                 de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, dtCases[dtNdx].name));
2154                 for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
2155                 {
2156                     de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
2157                     for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
2158                     {
2159                         const VkComponentTypeKHR inputType  = (VkComponentTypeKHR)dtCases[dtNdx].value[0];
2160                         const VkComponentTypeKHR outputType = (VkComponentTypeKHR)dtCases[dtNdx].value[1];
2161                         const bool isMatrixMul              = isMatrixMulAddOp(testType);
2162 
2163                         // useType isn't used for matrixmul shaders. Don't generate 3 copies of those tests.
2164                         if (isMatrixMul && (useType == UT_KHR_A || useType == UT_KHR_B))
2165                         {
2166                             continue;
2167                         }
2168 
2169                         // NV extension doesn't support mixing signedness
2170                         if (isMatrixMul && (useType == UT_NV) && isSIntType(inputType) != isSIntType(outputType))
2171                         {
2172                             continue;
2173                         }
2174 
2175                         if (!isMatrixMul && inputType != outputType)
2176                             continue;
2177 
2178                         if (isMatrixMul && componentTypeInfo[inputType].bits > componentTypeInfo[outputType].bits)
2179                             continue;
2180 
2181                         if (testType == TT_MUL && useType == UT_NV)
2182                             continue;
2183 
2184                         if (testType == TT_MATRIXMULADD_SATURATED && (isFloatType(inputType) || useType == UT_NV))
2185                             continue;
2186 
2187                         if (testType == TT_MATRIXMULADD_WRAPPING && (isFloatType(inputType) || useType == UT_NV))
2188                             continue;
2189 
2190                         if (testType == TT_MATRIXMULADD_STRIDE0 && useType == UT_NV)
2191                             continue;
2192 
2193                         if (testType == TT_LENGTH && useType != UT_NV &&
2194                             (outputType == VK_COMPONENT_TYPE_SINT8_KHR || outputType == VK_COMPONENT_TYPE_UINT8_KHR))
2195                             continue;
2196 
2197                         CaseDef c = {
2198                             testType,                           //  TestType testtype;
2199                             2u,                                 //  uint32_t subgroupsPerWorkgroupX;
2200                             2u,                                 //  uint32_t subgroupsPerWorkgroupY;
2201                             4u,                                 //  uint32_t workgroupsX;
2202                             4u,                                 //  uint32_t workgroupsY;
2203                             inputType,                          //  VkComponentTypeKHR inputType;
2204                             outputType,                         //  VkComponentTypeKHR outputType;
2205                             !!colCases[colNdx].value,           //  bool colMajor;
2206                             (StorageClass)scCases[scNdx].value, //  StorageClass storageClass;
2207                             useType,                            //  UseType useType;
2208                             sgsCases[sgsNdx].value,             //  SubgroupSizeMode subgroupSizeMode;
2209                             computePipelineConstructionType, //  vk::ComputePipelineConstructionType computePipelineConstructionType;
2210                             1,                               //  uint32_t inputComponentCount;
2211                             1,                               //  uint32_t outputComponentCount;
2212                         };
2213 
2214                         scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
2215                     }
2216                     dtGroup->addChild(scGroup.release());
2217                 }
2218                 ttGroup->addChild(dtGroup.release());
2219             }
2220             group->addChild(ttGroup.release());
2221         }
2222     }
2223 
2224     {
2225         const string name = string("convert");
2226         const string desc = string("OpFConvert/OpSConvert/OpUConvert/OpBitcast");
2227         de::MovePtr<tcu::TestCaseGroup> ttGroup(new tcu::TestCaseGroup(testCtx, name.c_str()));
2228 
2229         for (int dtNdx1 = 0; dtNdx1 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx1++)
2230         {
2231             for (int dtNdx2 = 0; dtNdx2 < DE_LENGTH_OF_ARRAY(allTypes); dtNdx2++)
2232             {
2233                 const VkComponentTypeKHR inputType  = (VkComponentTypeKHR)allTypes[dtNdx1];
2234                 const VkComponentTypeKHR outputType = (VkComponentTypeKHR)allTypes[dtNdx2];
2235                 const string name2                  = string("input_") + string(componentTypeInfo[inputType].typeName) +
2236                                      string("_output_") + string(componentTypeInfo[outputType].typeName);
2237                 de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name2.c_str()));
2238                 for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
2239                 {
2240                     de::MovePtr<tcu::TestCaseGroup> scGroup(new tcu::TestCaseGroup(testCtx, scCases[scNdx].name));
2241                     for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
2242                     {
2243 
2244                         CaseDef c = {
2245                             TT_CONVERT,                         //  TestType testtype;
2246                             2u,                                 //  uint32_t subgroupsPerWorkgroupX;
2247                             2u,                                 //  uint32_t subgroupsPerWorkgroupY;
2248                             4u,                                 //  uint32_t workgroupsX;
2249                             4u,                                 //  uint32_t workgroupsY;
2250                             inputType,                          //  VkComponentTypeKHR inputType;
2251                             outputType,                         //  VkComponentTypeKHR outputType;
2252                             !!colCases[colNdx].value,           //  bool colMajor;
2253                             (StorageClass)scCases[scNdx].value, //  StorageClass storageClass;
2254                             useType,                            //  UseType useType;
2255                             SUBGROUP_SIZE_NONE,                 //  SubgroupSizeMode subgroupSizeMode;
2256                             computePipelineConstructionType, //  vk::ComputePipelineConstructionType computePipelineConstructionType;
2257                             1,                               //  uint32_t inputComponentCount;
2258                             1,                               //  uint32_t outputComponentCount;
2259                         };
2260 
2261                         scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
2262                     }
2263                     dtGroup->addChild(scGroup.release());
2264                 }
2265                 ttGroup->addChild(dtGroup.release());
2266             }
2267         }
2268         group->addChild(ttGroup.release());
2269     }
2270 
2271     if (useType != UT_NV)
2272     {
2273         de::MovePtr<tcu::TestCaseGroup> ttGroup(
2274             new tcu::TestCaseGroup(testCtx, "multicomponent", "Multicomponent types tests"));
2275         for (int ctNdx = 0; ctNdx < DE_LENGTH_OF_ARRAY(multicomponentTypes); ctNdx++)
2276         {
2277             de::MovePtr<tcu::TestCaseGroup> ctGroup(new tcu::TestCaseGroup(testCtx, multicomponentTypes[ctNdx].name,
2278                                                                            multicomponentTypes[ctNdx].description));
2279             const uint32_t componentCount = multicomponentTypes[ctNdx].componentCount;
2280 
2281             for (int ioNdx = 0; ioNdx < DE_LENGTH_OF_ARRAY(ioTypes); ioNdx++)
2282             {
2283                 de::MovePtr<tcu::TestCaseGroup> ioGroup(
2284                     new tcu::TestCaseGroup(testCtx, ioTypes[ioNdx].name, ioTypes[ioNdx].description));
2285                 const TestType testType             = ioTypes[ioNdx].testType;
2286                 const uint32_t inputComponentCount  = testType == TT_MULTICOMPONENT_LOAD ? componentCount : 1;
2287                 const uint32_t outputComponentCount = testType == TT_MULTICOMPONENT_LOAD ? 1 : componentCount;
2288 
2289                 for (int dtNdx = 0; dtNdx < DE_LENGTH_OF_ARRAY(allTypes); dtNdx++)
2290                 {
2291                     const VkComponentTypeKHR inputType = allTypes[dtNdx];
2292                     const string name                  = componentTypeInfo[inputType].typeName;
2293 
2294                     de::MovePtr<tcu::TestCaseGroup> dtGroup(new tcu::TestCaseGroup(testCtx, name.c_str(), ""));
2295                     for (int scNdx = 0; scNdx < DE_LENGTH_OF_ARRAY(scCases); scNdx++)
2296                     {
2297                         de::MovePtr<tcu::TestCaseGroup> scGroup(
2298                             new tcu::TestCaseGroup(testCtx, scCases[scNdx].name, ""));
2299                         for (int colNdx = 0; colNdx < DE_LENGTH_OF_ARRAY(colCases); colNdx++)
2300                         {
2301                             CaseDef c = {
2302                                 testType,                           //  TestType testtype;
2303                                 2u,                                 //  uint32_t subgroupsPerWorkgroupX;
2304                                 2u,                                 //  uint32_t subgroupsPerWorkgroupY;
2305                                 4u,                                 //  uint32_t workgroupsX;
2306                                 4u,                                 //  uint32_t workgroupsY;
2307                                 inputType,                          //  VkComponentTypeKHR inputType;
2308                                 inputType,                          //  VkComponentTypeKHR outputType;
2309                                 !!colCases[colNdx].value,           //  bool colMajor;
2310                                 (StorageClass)scCases[scNdx].value, //  StorageClass storageClass;
2311                                 useType,                            //  UseType useType;
2312                                 SUBGROUP_SIZE_NONE,                 //  SubgroupSizeMode subgroupSizeMode;
2313                                 computePipelineConstructionType, //  vk::ComputePipelineConstructionType computePipelineConstructionType;
2314                                 inputComponentCount,  //  uint32_t inputComponentCount;
2315                                 outputComponentCount, //  uint32_t outputComponentCount;
2316                             };
2317 
2318                             scGroup->addChild(new CooperativeMatrixTestCase(testCtx, colCases[colNdx].name, c));
2319                         }
2320                         dtGroup->addChild(scGroup.release());
2321                     }
2322                     ioGroup->addChild(dtGroup.release());
2323                 }
2324                 ctGroup->addChild(ioGroup.release());
2325             }
2326             ttGroup->addChild(ctGroup.release());
2327         }
2328         group->addChild(ttGroup.release());
2329     }
2330 
2331     return group.release();
2332 }
2333 
2334 } // namespace
2335 
createCooperativeMatrixTests(tcu::TestContext & testCtx,vk::ComputePipelineConstructionType computePipelineConstructionType)2336 tcu::TestCaseGroup *createCooperativeMatrixTests(tcu::TestContext &testCtx,
2337                                                  vk::ComputePipelineConstructionType computePipelineConstructionType)
2338 {
2339     de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(testCtx, "cooperative_matrix"));
2340 
2341     group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_NV));
2342     group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_A));
2343     group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_B));
2344     group->addChild(createCooperativeMatrixTestsInternal(testCtx, computePipelineConstructionType, UT_KHR_Result));
2345 
2346     return group.release();
2347 }
2348 
2349 } // namespace compute
2350 } // namespace vkt
2351