1 /*-------------------------------------------------------------------------
2  * Vulkan Conformance Tests
3  * ------------------------
4  *
5  * Copyright (c) 2020 Valve Corporation.
6  * Copyright (c) 2020 The Khronos Group Inc.
7  *
8  * Licensed under the Apache License, Version 2.0 (the "License");
9  * you may not use this file except in compliance with the License.
10  * You may obtain a copy of the License at
11  *
12  *      http://www.apache.org/licenses/LICENSE-2.0
13  *
14  * Unless required by applicable law or agreed to in writing, software
15  * distributed under the License is distributed on an "AS IS" BASIS,
16  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17  * See the License for the specific language governing permissions and
18  * limitations under the License.
19  *
20  *//*!
21  * \file
22  * \brief SPIR-V tests for VK_AMD_shader_trinary_minmax.
23  *//*--------------------------------------------------------------------*/
24 
25 #include "vktSpvAsmTrinaryMinMaxTests.hpp"
26 #include "vktTestCase.hpp"
27 
28 #include "vkQueryUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkBufferWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkCmdUtil.hpp"
35 
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38 #include "tcuMaybe.hpp"
39 
40 #include "deStringUtil.hpp"
41 #include "deRandom.hpp"
42 #include "deMemory.h"
43 
44 #include <string>
45 #include <sstream>
46 #include <map>
47 #include <vector>
48 #include <algorithm>
49 #include <array>
50 #include <memory>
51 
52 namespace vkt
53 {
54 namespace SpirVAssembly
55 {
56 
57 namespace
58 {
59 
// Trinary operation exercised by a test variant.
enum class OperationType
{
    MIN = 0,
    MAX,
    MID,
};
66 
// Base numeric type of the operands.
enum class BaseType
{
    TYPE_INT   = 0,
    TYPE_UINT  = 1,
    TYPE_FLOAT = 2,
};
73 
// Component width. Each enumerator's numeric value is the size in bytes (bit width divided by 8).
enum class TypeSize
{
    SIZE_8BIT  = 8 / 8,
    SIZE_16BIT = 16 / 8,
    SIZE_32BIT = 32 / 8,
    SIZE_64BIT = 64 / 8,
};
82 
// Scalar or vector shape of an operand. Each enumerator's numeric value is the number of components.
enum class AggregationType
{
    SCALAR = 1,
    VEC2,
    VEC3,
    VEC4,
};
91 
92 struct TestParams
93 {
94     OperationType operation;
95     BaseType baseType;
96     TypeSize typeSize;
97     AggregationType aggregation;
98     uint32_t randomSeed;
99 
100     uint32_t operandSize() const;         // In bytes.
101     uint32_t numComponents() const;       // Number of components.
102     uint32_t effectiveComponents() const; // Effective number of components for size calculation.
103     uint32_t componentSize() const;       // In bytes.
104 };
105 
operandSize() const106 uint32_t TestParams::operandSize() const
107 {
108     return (effectiveComponents() * componentSize());
109 }
110 
numComponents() const111 uint32_t TestParams::numComponents() const
112 {
113     return static_cast<uint32_t>(aggregation);
114 }
115 
effectiveComponents() const116 uint32_t TestParams::effectiveComponents() const
117 {
118     return static_cast<uint32_t>((aggregation == AggregationType::VEC3) ? AggregationType::VEC4 : aggregation);
119 }
120 
componentSize() const121 uint32_t TestParams::componentSize() const
122 {
123     return static_cast<uint32_t>(typeSize);
124 }
125 
// Returns the smallest of the three operands according to operator<.
// Operands are examined left to right, so the leftmost of several equivalent minima is returned.
template <class T>
T min3(T op1, T op2, T op3)
{
    const T &firstTwo = std::min(op1, op2);
    return std::min(firstTwo, op3);
}
131 
// Returns the largest of the three operands according to operator<.
// Operands are examined left to right, so the leftmost of several equivalent maxima is returned.
template <class T>
T max3(T op1, T op2, T op3)
{
    const T &firstTwo = std::max(op1, op2);
    return std::max(firstTwo, op3);
}
137 
// Returns the middle of the three operands: they are sorted with operator< and the element in the central
// position is returned.
template <class T>
T mid3(T op1, T op2, T op3)
{
    std::array<T, 3> ordered{{op1, op2, op3}};
    std::sort(ordered.begin(), ordered.end());

    const T &median = ordered[1];
    return median;
}
145 
146 class OperationManager
147 {
148 public:
149     // Operation and component index in case of error.
150     using OperationComponent = std::pair<uint32_t, uint32_t>;
151     using ComparisonError    = tcu::Maybe<OperationComponent>;
152 
153     OperationManager(const TestParams &params);
154     void genInputBuffer(void *bufferPtr, uint32_t numOperations);
155     void calculateResult(void *referenceBuffer, void *inputBuffer, uint32_t numOperations);
156     ComparisonError compareResults(void *referenceBuffer, void *resultsBuffer, uint32_t numOperations);
157 
158 private:
159     using GenerateCompFunc = void (*)(de::Random &, void *); // Write a generated component to the given location.
160 
161     // Generator variants to populate input buffer.
genInt8(de::Random & rnd,void * ptr)162     static void genInt8(de::Random &rnd, void *ptr)
163     {
164         *reinterpret_cast<int8_t *>(ptr) = static_cast<int8_t>(rnd.getUint8());
165     }
genUint8(de::Random & rnd,void * ptr)166     static void genUint8(de::Random &rnd, void *ptr)
167     {
168         *reinterpret_cast<uint8_t *>(ptr) = rnd.getUint8();
169     }
genInt16(de::Random & rnd,void * ptr)170     static void genInt16(de::Random &rnd, void *ptr)
171     {
172         *reinterpret_cast<int16_t *>(ptr) = static_cast<int16_t>(rnd.getUint16());
173     }
genUint16(de::Random & rnd,void * ptr)174     static void genUint16(de::Random &rnd, void *ptr)
175     {
176         *reinterpret_cast<uint16_t *>(ptr) = rnd.getUint16();
177     }
genInt32(de::Random & rnd,void * ptr)178     static void genInt32(de::Random &rnd, void *ptr)
179     {
180         *reinterpret_cast<int32_t *>(ptr) = static_cast<int32_t>(rnd.getUint32());
181     }
genUint32(de::Random & rnd,void * ptr)182     static void genUint32(de::Random &rnd, void *ptr)
183     {
184         *reinterpret_cast<uint32_t *>(ptr) = rnd.getUint32();
185     }
genInt64(de::Random & rnd,void * ptr)186     static void genInt64(de::Random &rnd, void *ptr)
187     {
188         *reinterpret_cast<int64_t *>(ptr) = static_cast<int64_t>(rnd.getUint64());
189     }
genUint64(de::Random & rnd,void * ptr)190     static void genUint64(de::Random &rnd, void *ptr)
191     {
192         *reinterpret_cast<uint64_t *>(ptr) = rnd.getUint64();
193     }
194 
195     // Helper template for float generators.
196     // T must be a tcu::Float instantiation.
197     // Attempts to generate +-Inf once every 10 times and avoid denormals.
198     template <class T>
genFloat(de::Random & rnd,void * ptr)199     static inline void genFloat(de::Random &rnd, void *ptr)
200     {
201         T *valuePtr = reinterpret_cast<T *>(ptr);
202         if (rnd.getInt(1, 10) == 1)
203             *valuePtr = T::inf(rnd.getBool() ? 1 : -1);
204         else
205         {
206             do
207             {
208                 *valuePtr = T{rnd.getDouble(T::largestNormal(-1).asDouble(), T::largestNormal(1).asDouble())};
209             } while (valuePtr->isDenorm());
210         }
211     }
212 
genFloat16(de::Random & rnd,void * ptr)213     static void genFloat16(de::Random &rnd, void *ptr)
214     {
215         genFloat<tcu::Float16>(rnd, ptr);
216     }
genFloat32(de::Random & rnd,void * ptr)217     static void genFloat32(de::Random &rnd, void *ptr)
218     {
219         genFloat<tcu::Float32>(rnd, ptr);
220     }
genFloat64(de::Random & rnd,void * ptr)221     static void genFloat64(de::Random &rnd, void *ptr)
222     {
223         genFloat<tcu::Float64>(rnd, ptr);
224     }
225 
226     // An operation function writes an output value given 3 input values.
227     using OperationFunc = void (*)(void *, const void *, const void *, const void *);
228 
229     // Helper template used below.
230     template <class T, class F>
runOpFunc(F f,void * out,const void * in1,const void * in2,const void * in3)231     static inline void runOpFunc(F f, void *out, const void *in1, const void *in2, const void *in3)
232     {
233         *reinterpret_cast<T *>(out) =
234             f(*reinterpret_cast<const T *>(in1), *reinterpret_cast<const T *>(in2), *reinterpret_cast<const T *>(in3));
235     }
236 
237     // Apply an operation in software to a given group of components and calculate result.
minInt8(void * out,const void * in1,const void * in2,const void * in3)238     static void minInt8(void *out, const void *in1, const void *in2, const void *in3)
239     {
240         runOpFunc<int8_t>(min3<int8_t>, out, in1, in2, in3);
241     }
maxInt8(void * out,const void * in1,const void * in2,const void * in3)242     static void maxInt8(void *out, const void *in1, const void *in2, const void *in3)
243     {
244         runOpFunc<int8_t>(max3<int8_t>, out, in1, in2, in3);
245     }
midInt8(void * out,const void * in1,const void * in2,const void * in3)246     static void midInt8(void *out, const void *in1, const void *in2, const void *in3)
247     {
248         runOpFunc<int8_t>(mid3<int8_t>, out, in1, in2, in3);
249     }
minUint8(void * out,const void * in1,const void * in2,const void * in3)250     static void minUint8(void *out, const void *in1, const void *in2, const void *in3)
251     {
252         runOpFunc<uint8_t>(min3<uint8_t>, out, in1, in2, in3);
253     }
maxUint8(void * out,const void * in1,const void * in2,const void * in3)254     static void maxUint8(void *out, const void *in1, const void *in2, const void *in3)
255     {
256         runOpFunc<uint8_t>(max3<uint8_t>, out, in1, in2, in3);
257     }
midUint8(void * out,const void * in1,const void * in2,const void * in3)258     static void midUint8(void *out, const void *in1, const void *in2, const void *in3)
259     {
260         runOpFunc<uint8_t>(mid3<uint8_t>, out, in1, in2, in3);
261     }
minInt16(void * out,const void * in1,const void * in2,const void * in3)262     static void minInt16(void *out, const void *in1, const void *in2, const void *in3)
263     {
264         runOpFunc<int16_t>(min3<int16_t>, out, in1, in2, in3);
265     }
maxInt16(void * out,const void * in1,const void * in2,const void * in3)266     static void maxInt16(void *out, const void *in1, const void *in2, const void *in3)
267     {
268         runOpFunc<int16_t>(max3<int16_t>, out, in1, in2, in3);
269     }
midInt16(void * out,const void * in1,const void * in2,const void * in3)270     static void midInt16(void *out, const void *in1, const void *in2, const void *in3)
271     {
272         runOpFunc<int16_t>(mid3<int16_t>, out, in1, in2, in3);
273     }
minUint16(void * out,const void * in1,const void * in2,const void * in3)274     static void minUint16(void *out, const void *in1, const void *in2, const void *in3)
275     {
276         runOpFunc<uint16_t>(min3<uint16_t>, out, in1, in2, in3);
277     }
maxUint16(void * out,const void * in1,const void * in2,const void * in3)278     static void maxUint16(void *out, const void *in1, const void *in2, const void *in3)
279     {
280         runOpFunc<uint16_t>(max3<uint16_t>, out, in1, in2, in3);
281     }
midUint16(void * out,const void * in1,const void * in2,const void * in3)282     static void midUint16(void *out, const void *in1, const void *in2, const void *in3)
283     {
284         runOpFunc<uint16_t>(mid3<uint16_t>, out, in1, in2, in3);
285     }
minInt32(void * out,const void * in1,const void * in2,const void * in3)286     static void minInt32(void *out, const void *in1, const void *in2, const void *in3)
287     {
288         runOpFunc<int32_t>(min3<int32_t>, out, in1, in2, in3);
289     }
maxInt32(void * out,const void * in1,const void * in2,const void * in3)290     static void maxInt32(void *out, const void *in1, const void *in2, const void *in3)
291     {
292         runOpFunc<int32_t>(max3<int32_t>, out, in1, in2, in3);
293     }
midInt32(void * out,const void * in1,const void * in2,const void * in3)294     static void midInt32(void *out, const void *in1, const void *in2, const void *in3)
295     {
296         runOpFunc<int32_t>(mid3<int32_t>, out, in1, in2, in3);
297     }
minUint32(void * out,const void * in1,const void * in2,const void * in3)298     static void minUint32(void *out, const void *in1, const void *in2, const void *in3)
299     {
300         runOpFunc<uint32_t>(min3<uint32_t>, out, in1, in2, in3);
301     }
maxUint32(void * out,const void * in1,const void * in2,const void * in3)302     static void maxUint32(void *out, const void *in1, const void *in2, const void *in3)
303     {
304         runOpFunc<uint32_t>(max3<uint32_t>, out, in1, in2, in3);
305     }
midUint32(void * out,const void * in1,const void * in2,const void * in3)306     static void midUint32(void *out, const void *in1, const void *in2, const void *in3)
307     {
308         runOpFunc<uint32_t>(mid3<uint32_t>, out, in1, in2, in3);
309     }
minInt64(void * out,const void * in1,const void * in2,const void * in3)310     static void minInt64(void *out, const void *in1, const void *in2, const void *in3)
311     {
312         runOpFunc<int64_t>(min3<int64_t>, out, in1, in2, in3);
313     }
maxInt64(void * out,const void * in1,const void * in2,const void * in3)314     static void maxInt64(void *out, const void *in1, const void *in2, const void *in3)
315     {
316         runOpFunc<int64_t>(max3<int64_t>, out, in1, in2, in3);
317     }
midInt64(void * out,const void * in1,const void * in2,const void * in3)318     static void midInt64(void *out, const void *in1, const void *in2, const void *in3)
319     {
320         runOpFunc<int64_t>(mid3<int64_t>, out, in1, in2, in3);
321     }
minUint64(void * out,const void * in1,const void * in2,const void * in3)322     static void minUint64(void *out, const void *in1, const void *in2, const void *in3)
323     {
324         runOpFunc<uint64_t>(min3<uint64_t>, out, in1, in2, in3);
325     }
maxUint64(void * out,const void * in1,const void * in2,const void * in3)326     static void maxUint64(void *out, const void *in1, const void *in2, const void *in3)
327     {
328         runOpFunc<uint64_t>(max3<uint64_t>, out, in1, in2, in3);
329     }
midUint64(void * out,const void * in1,const void * in2,const void * in3)330     static void midUint64(void *out, const void *in1, const void *in2, const void *in3)
331     {
332         runOpFunc<uint64_t>(mid3<uint64_t>, out, in1, in2, in3);
333     }
minFloat16(void * out,const void * in1,const void * in2,const void * in3)334     static void minFloat16(void *out, const void *in1, const void *in2, const void *in3)
335     {
336         runOpFunc<tcu::Float16>(min3<tcu::Float16>, out, in1, in2, in3);
337     }
maxFloat16(void * out,const void * in1,const void * in2,const void * in3)338     static void maxFloat16(void *out, const void *in1, const void *in2, const void *in3)
339     {
340         runOpFunc<tcu::Float16>(max3<tcu::Float16>, out, in1, in2, in3);
341     }
midFloat16(void * out,const void * in1,const void * in2,const void * in3)342     static void midFloat16(void *out, const void *in1, const void *in2, const void *in3)
343     {
344         runOpFunc<tcu::Float16>(mid3<tcu::Float16>, out, in1, in2, in3);
345     }
minFloat32(void * out,const void * in1,const void * in2,const void * in3)346     static void minFloat32(void *out, const void *in1, const void *in2, const void *in3)
347     {
348         runOpFunc<tcu::Float32>(min3<tcu::Float32>, out, in1, in2, in3);
349     }
maxFloat32(void * out,const void * in1,const void * in2,const void * in3)350     static void maxFloat32(void *out, const void *in1, const void *in2, const void *in3)
351     {
352         runOpFunc<tcu::Float32>(max3<tcu::Float32>, out, in1, in2, in3);
353     }
midFloat32(void * out,const void * in1,const void * in2,const void * in3)354     static void midFloat32(void *out, const void *in1, const void *in2, const void *in3)
355     {
356         runOpFunc<tcu::Float32>(mid3<tcu::Float32>, out, in1, in2, in3);
357     }
minFloat64(void * out,const void * in1,const void * in2,const void * in3)358     static void minFloat64(void *out, const void *in1, const void *in2, const void *in3)
359     {
360         runOpFunc<tcu::Float64>(min3<tcu::Float64>, out, in1, in2, in3);
361     }
maxFloat64(void * out,const void * in1,const void * in2,const void * in3)362     static void maxFloat64(void *out, const void *in1, const void *in2, const void *in3)
363     {
364         runOpFunc<tcu::Float64>(max3<tcu::Float64>, out, in1, in2, in3);
365     }
midFloat64(void * out,const void * in1,const void * in2,const void * in3)366     static void midFloat64(void *out, const void *in1, const void *in2, const void *in3)
367     {
368         runOpFunc<tcu::Float64>(mid3<tcu::Float64>, out, in1, in2, in3);
369     }
370 
371     // Case for accessing the functions map.
372     struct Case
373     {
374         BaseType type;
375         TypeSize size;
376         OperationType operation;
377 
378         // This is required for sorting in the map.
operator <vkt::SpirVAssembly::__anon44d5838c0111::OperationManager::Case379         bool operator<(const Case &other) const
380         {
381             return (toArray() < other.toArray());
382         }
383 
384     private:
toArrayvkt::SpirVAssembly::__anon44d5838c0111::OperationManager::Case385         std::array<int, 3> toArray() const
386         {
387             return std::array<int, 3>{{static_cast<int>(type), static_cast<int>(size), static_cast<int>(operation)}};
388         }
389     };
390 
391     // Helper map to correctly choose the right generator and operation function for the specific case being tested.
392     using FuncPair = std::pair<GenerateCompFunc, OperationFunc>;
393     using CaseMap  = std::map<Case, FuncPair>;
394 
395     static const CaseMap kFunctionsMap;
396 
397     GenerateCompFunc m_chosenGenerator;
398     OperationFunc m_chosenOperation;
399     de::Random m_random;
400 
401     const uint32_t m_operandSize;
402     const uint32_t m_numComponents;
403     const uint32_t m_componentSize;
404 };
405 
406 // This map is used to choose how to generate inputs for each case and which operation to run on the CPU to calculate the reference
407 // results for the generated inputs.
408 const OperationManager::CaseMap OperationManager::kFunctionsMap = {
409     {{BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MIN}, {genInt8, minInt8}},
410     {{BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MAX}, {genInt8, maxInt8}},
411     {{BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MID}, {genInt8, midInt8}},
412     {{BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MIN}, {genInt16, minInt16}},
413     {{BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MAX}, {genInt16, maxInt16}},
414     {{BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MID}, {genInt16, midInt16}},
415     {{BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MIN}, {genInt32, minInt32}},
416     {{BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MAX}, {genInt32, maxInt32}},
417     {{BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MID}, {genInt32, midInt32}},
418     {{BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MIN}, {genInt64, minInt64}},
419     {{BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MAX}, {genInt64, maxInt64}},
420     {{BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MID}, {genInt64, midInt64}},
421     {{BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MIN}, {genUint8, minUint8}},
422     {{BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MAX}, {genUint8, maxUint8}},
423     {{BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MID}, {genUint8, midUint8}},
424     {{BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MIN}, {genUint16, minUint16}},
425     {{BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MAX}, {genUint16, maxUint16}},
426     {{BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MID}, {genUint16, midUint16}},
427     {{BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MIN}, {genUint32, minUint32}},
428     {{BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MAX}, {genUint32, maxUint32}},
429     {{BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MID}, {genUint32, midUint32}},
430     {{BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MIN}, {genUint64, minUint64}},
431     {{BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MAX}, {genUint64, maxUint64}},
432     {{BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MID}, {genUint64, midUint64}},
433     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MIN}, {genFloat16, minFloat16}},
434     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MAX}, {genFloat16, maxFloat16}},
435     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MID}, {genFloat16, midFloat16}},
436     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MIN}, {genFloat32, minFloat32}},
437     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MAX}, {genFloat32, maxFloat32}},
438     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MID}, {genFloat32, midFloat32}},
439     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MIN}, {genFloat64, minFloat64}},
440     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MAX}, {genFloat64, maxFloat64}},
441     {{BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MID}, {genFloat64, midFloat64}},
442 };
443 
OperationManager(const TestParams & params)444 OperationManager::OperationManager(const TestParams &params)
445     : m_chosenGenerator{nullptr}
446     , m_chosenOperation{nullptr}
447     , m_random{params.randomSeed}
448     , m_operandSize{params.operandSize()}
449     , m_numComponents{params.numComponents()}
450     , m_componentSize{params.componentSize()}
451 {
452     // Choose generator and CPU operation from the map.
453     const Case paramCase{params.baseType, params.typeSize, params.operation};
454     const auto iter = kFunctionsMap.find(paramCase);
455 
456     DE_ASSERT(iter != kFunctionsMap.end());
457     m_chosenGenerator = iter->second.first;
458     m_chosenOperation = iter->second.second;
459 }
460 
461 // See TrinaryMinMaxCase::initPrograms for a description of the input buffer format.
462 // Generates inputs with the chosen generator.
genInputBuffer(void * bufferPtr,uint32_t numOperations)463 void OperationManager::genInputBuffer(void *bufferPtr, uint32_t numOperations)
464 {
465     const uint32_t numOperands = numOperations * 3u;
466     char *byteBuffer           = reinterpret_cast<char *>(bufferPtr);
467 
468     for (uint32_t opIdx = 0u; opIdx < numOperands; ++opIdx)
469     {
470         char *compPtr = byteBuffer;
471         for (uint32_t compIdx = 0u; compIdx < m_numComponents; ++compIdx)
472         {
473             m_chosenGenerator(m_random, reinterpret_cast<void *>(compPtr));
474             compPtr += m_componentSize;
475         }
476         byteBuffer += m_operandSize;
477     }
478 }
479 
480 // See TrinaryMinMaxCase::initPrograms for a description of the input and output buffer formats.
481 // Calculates reference results on the CPU using the chosen operation and the input buffer.
calculateResult(void * referenceBuffer,void * inputBuffer,uint32_t numOperations)482 void OperationManager::calculateResult(void *referenceBuffer, void *inputBuffer, uint32_t numOperations)
483 {
484     char *outputByte = reinterpret_cast<char *>(referenceBuffer);
485     char *inputByte  = reinterpret_cast<char *>(inputBuffer);
486 
487     for (uint32_t opIdx = 0u; opIdx < numOperations; ++opIdx)
488     {
489         char *res = outputByte;
490         char *op1 = inputByte;
491         char *op2 = inputByte + m_operandSize;
492         char *op3 = inputByte + m_operandSize * 2u;
493 
494         for (uint32_t compIdx = 0u; compIdx < m_numComponents; ++compIdx)
495         {
496             m_chosenOperation(reinterpret_cast<void *>(res), reinterpret_cast<void *>(op1),
497                               reinterpret_cast<void *>(op2), reinterpret_cast<void *>(op3));
498 
499             res += m_componentSize;
500             op1 += m_componentSize;
501             op2 += m_componentSize;
502             op3 += m_componentSize;
503         }
504 
505         outputByte += m_operandSize;
506         inputByte += m_operandSize * 3u;
507     }
508 }
509 
510 // See TrinaryMinMaxCase::initPrograms for a description of the output buffer format.
compareResults(void * referenceBuffer,void * resultsBuffer,uint32_t numOperations)511 OperationManager::ComparisonError OperationManager::compareResults(void *referenceBuffer, void *resultsBuffer,
512                                                                    uint32_t numOperations)
513 {
514     char *referenceBytes = reinterpret_cast<char *>(referenceBuffer);
515     char *resultsBytes   = reinterpret_cast<char *>(resultsBuffer);
516 
517     for (uint32_t opIdx = 0u; opIdx < numOperations; ++opIdx)
518     {
519         char *refCompBytes = referenceBytes;
520         char *resCompBytes = resultsBytes;
521 
522         for (uint32_t compIdx = 0u; compIdx < m_numComponents; ++compIdx)
523         {
524             if (deMemCmp(refCompBytes, resCompBytes, m_componentSize) != 0)
525                 return tcu::just(OperationComponent(opIdx, compIdx));
526             refCompBytes += m_componentSize;
527             resCompBytes += m_componentSize;
528         }
529         referenceBytes += m_operandSize;
530         resultsBytes += m_operandSize;
531     }
532 
533     return tcu::Nothing;
534 }
535 
536 class TrinaryMinMaxCase : public vkt::TestCase
537 {
538 public:
539     using ReplacementsMap = std::map<std::string, std::string>;
540 
541     TrinaryMinMaxCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params);
~TrinaryMinMaxCase(void)542     virtual ~TrinaryMinMaxCase(void)
543     {
544     }
545 
546     virtual void initPrograms(vk::SourceCollections &programCollection) const;
547     virtual TestInstance *createInstance(Context &context) const;
548     virtual void checkSupport(Context &context) const;
549     ReplacementsMap getSpirVReplacements(void) const;
550 
551     static const uint32_t kArraySize;
552 
553 private:
554     TestParams m_params;
555 };
556 
557 const uint32_t TrinaryMinMaxCase::kArraySize = 100u;
558 
559 class TrinaryMinMaxInstance : public vkt::TestInstance
560 {
561 public:
562     TrinaryMinMaxInstance(Context &context, const TestParams &params);
~TrinaryMinMaxInstance(void)563     virtual ~TrinaryMinMaxInstance(void)
564     {
565     }
566 
567     virtual tcu::TestStatus iterate(void);
568 
569 private:
570     TestParams m_params;
571 };
572 
TrinaryMinMaxCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & params)573 TrinaryMinMaxCase::TrinaryMinMaxCase(tcu::TestContext &testCtx, const std::string &name, const TestParams &params)
574     : vkt::TestCase(testCtx, name)
575     , m_params(params)
576 {
577 }
578 
createInstance(Context & context) const579 TestInstance *TrinaryMinMaxCase::createInstance(Context &context) const
580 {
581     return new TrinaryMinMaxInstance{context, m_params};
582 }
583 
checkSupport(Context & context) const584 void TrinaryMinMaxCase::checkSupport(Context &context) const
585 {
586     // These are always required.
587     context.requireInstanceFunctionality("VK_KHR_get_physical_device_properties2");
588     context.requireDeviceFunctionality("VK_KHR_storage_buffer_storage_class");
589     context.requireDeviceFunctionality("VK_AMD_shader_trinary_minmax");
590 
591     const auto devFeatures          = context.getDeviceFeatures();
592     const auto storage16BitFeatures = context.get16BitStorageFeatures();
593     const auto storage8BitFeatures  = context.get8BitStorageFeatures();
594     const auto shaderFeatures       = context.getShaderFloat16Int8Features();
595 
596     // Storage features.
597     if (m_params.typeSize == TypeSize::SIZE_8BIT)
598     {
599         // We will be using 8-bit types in storage buffers.
600         context.requireDeviceFunctionality("VK_KHR_8bit_storage");
601         if (!storage8BitFeatures.storageBuffer8BitAccess)
602             TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
603     }
604     else if (m_params.typeSize == TypeSize::SIZE_16BIT)
605     {
606         // We will be using 16-bit types in storage buffers.
607         context.requireDeviceFunctionality("VK_KHR_16bit_storage");
608         if (!storage16BitFeatures.storageBuffer16BitAccess)
609             TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
610     }
611 
612     // Shader type features.
613     if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
614     {
615         if (m_params.typeSize == TypeSize::SIZE_8BIT && !shaderFeatures.shaderInt8)
616             TCU_THROW(NotSupportedError, "8-bit integers not supported in shaders");
617         else if (m_params.typeSize == TypeSize::SIZE_16BIT && !devFeatures.shaderInt16)
618             TCU_THROW(NotSupportedError, "16-bit integers not supported in shaders");
619         else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderInt64)
620             TCU_THROW(NotSupportedError, "64-bit integers not supported in shaders");
621     }
622     else // BaseType::TYPE_FLOAT
623     {
624         DE_ASSERT(m_params.typeSize != TypeSize::SIZE_8BIT);
625         if (m_params.typeSize == TypeSize::SIZE_16BIT && !shaderFeatures.shaderFloat16)
626             TCU_THROW(NotSupportedError, "16-bit floats not supported in shaders");
627         else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderFloat64)
628             TCU_THROW(NotSupportedError, "64-bit floats not supported in shaders");
629     }
630 }
631 
getSpirVReplacements(void) const632 TrinaryMinMaxCase::ReplacementsMap TrinaryMinMaxCase::getSpirVReplacements(void) const
633 {
634     ReplacementsMap replacements;
635 
636     // Capabilities and extensions.
637     if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
638     {
639         if (m_params.typeSize == TypeSize::SIZE_8BIT)
640             replacements["CAPABILITIES"] += "OpCapability Int8\n";
641         else if (m_params.typeSize == TypeSize::SIZE_16BIT)
642             replacements["CAPABILITIES"] += "OpCapability Int16\n";
643         else if (m_params.typeSize == TypeSize::SIZE_64BIT)
644             replacements["CAPABILITIES"] += "OpCapability Int64\n";
645     }
646     else // BaseType::TYPE_FLOAT
647     {
648         if (m_params.typeSize == TypeSize::SIZE_16BIT)
649             replacements["CAPABILITIES"] += "OpCapability Float16\n";
650         else if (m_params.typeSize == TypeSize::SIZE_64BIT)
651             replacements["CAPABILITIES"] += "OpCapability Float64\n";
652     }
653 
654     if (m_params.typeSize == TypeSize::SIZE_8BIT)
655     {
656         replacements["CAPABILITIES"] += "OpCapability StorageBuffer8BitAccess\n";
657         replacements["EXTENSIONS"] += "OpExtension \"SPV_KHR_8bit_storage\"\n";
658     }
659     else if (m_params.typeSize == TypeSize::SIZE_16BIT)
660     {
661         replacements["CAPABILITIES"] += "OpCapability StorageBuffer16BitAccess\n";
662         replacements["EXTENSIONS"] += "OpExtension \"SPV_KHR_16bit_storage\"\n";
663     }
664 
665     // Operand size in bytes.
666     const uint32_t opSize               = m_params.operandSize();
667     replacements["OPERAND_SIZE"]        = de::toString(opSize);
668     replacements["OPERAND_SIZE_2TIMES"] = de::toString(opSize * 2u);
669     replacements["OPERAND_SIZE_3TIMES"] = de::toString(opSize * 3u);
670 
671     // Array size.
672     replacements["ARRAY_SIZE"] = de::toString(kArraySize);
673 
674     // Types and operand type: define the base integer or float type and the vector type if needed, then set the operand type replacement.
675     const std::string vecSize = de::toString(m_params.numComponents());
676     const std::string bitSize = de::toString(m_params.componentSize() * 8u);
677 
678     if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
679     {
680         const std::string signBit    = (m_params.baseType == BaseType::TYPE_INT ? "1" : "0");
681         const std::string typePrefix = (m_params.baseType == BaseType::TYPE_UINT ? "u" : "");
682         std::string baseTypeName;
683 
684         // 32-bit integers are already defined in the default shader text.
685         if (m_params.typeSize != TypeSize::SIZE_32BIT)
686         {
687             baseTypeName = typePrefix + "int" + bitSize + "_t";
688             replacements["TYPES"] += "%" + baseTypeName + " = OpTypeInt " + bitSize + " " + signBit + "\n";
689         }
690         else
691         {
692             baseTypeName = typePrefix + "int";
693         }
694 
695         if (m_params.aggregation == AggregationType::SCALAR)
696         {
697             replacements["OPERAND_TYPE"] = "%" + baseTypeName;
698         }
699         else
700         {
701             const std::string typeName = "%v" + vecSize + baseTypeName;
702             // %v3uint is already defined in the default shader text.
703             if (m_params.baseType != BaseType::TYPE_UINT || m_params.typeSize != TypeSize::SIZE_32BIT ||
704                 m_params.aggregation != AggregationType::VEC3)
705             {
706                 replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
707             }
708             replacements["OPERAND_TYPE"] = typeName;
709         }
710     }
711     else // BaseType::TYPE_FLOAT
712     {
713         const std::string baseTypeName = "float" + bitSize + "_t";
714         replacements["TYPES"] += "%" + baseTypeName + " = OpTypeFloat " + bitSize + "\n";
715 
716         if (m_params.aggregation == AggregationType::SCALAR)
717         {
718             replacements["OPERAND_TYPE"] = "%" + baseTypeName;
719         }
720         else
721         {
722             const std::string typeName = "%v" + vecSize + baseTypeName;
723             replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
724             replacements["OPERAND_TYPE"] = typeName;
725         }
726     }
727 
728     // Operation name.
729     const static std::vector<std::string> opTypeStr = {"Min", "Max", "Mid"};
730     const static std::vector<std::string> opPrefix  = {"S", "U", "F"};
731     replacements["OPERATION_NAME"] =
732         opPrefix[static_cast<int>(m_params.baseType)] + opTypeStr[static_cast<int>(m_params.operation)] + "3AMD";
733 
734     return replacements;
735 }
736 
initPrograms(vk::SourceCollections & programCollection) const737 void TrinaryMinMaxCase::initPrograms(vk::SourceCollections &programCollection) const
738 {
739     // The shader below uses an input buffer at set 0 binding 0 and an output buffer at set 0 binding 1. Their structure is similar
740     // to the code below:
741     //
742     //      struct Operands {
743     //              <type> op1;
744     //              <type> op2;
745     //              <type> op3;
746     //      };
747     //
748     //      layout (set=0, binding=0, std430) buffer InputBlock {
749     //              Operands operands[<arraysize>];
750     //      };
751     //
752     //      layout (set=0, binding=1, std430) buffer OutputBlock {
753     //              <type> result[<arraysize>];
754     //      };
755     //
756     // Where <type> can be int8_t, uint32_t, float, etc. So in the input buffer the operands are "grouped" per operation and can
757     // have several components each and the output buffer contains an array of results, one per trio of input operands.
758 
759     std::ostringstream shaderStr;
760     shaderStr << "; SPIR-V\n"
761               << "; Version: 1.5\n"
762               << "                            OpCapability Shader\n"
763               << "${CAPABILITIES:opt}"
764               << "                            OpExtension \"SPV_KHR_storage_buffer_storage_class\"\n"
765               << "                            OpExtension \"SPV_AMD_shader_trinary_minmax\"\n"
766               << "${EXTENSIONS:opt}"
767               << "                  %std450 = OpExtInstImport \"GLSL.std.450\"\n"
768               << "                 %trinary = OpExtInstImport \"SPV_AMD_shader_trinary_minmax\"\n"
769               << "                            OpMemoryModel Logical GLSL450\n"
770               << "                            OpEntryPoint GLCompute %main \"main\" %gl_GlobalInvocationID "
771                  "%output_buffer %input_buffer\n"
772               << "                            OpExecutionMode %main LocalSize 1 1 1\n"
773               << "                            OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId\n"
774               << "                            OpDecorate %results_array_t ArrayStride ${OPERAND_SIZE}\n"
775               << "                            OpMemberDecorate %OutputBlock 0 Offset 0\n"
776               << "                            OpDecorate %OutputBlock Block\n"
777               << "                            OpDecorate %output_buffer DescriptorSet 0\n"
778               << "                            OpDecorate %output_buffer Binding 1\n"
779               << "                            OpMemberDecorate %Operands 0 Offset 0\n"
780               << "                            OpMemberDecorate %Operands 1 Offset ${OPERAND_SIZE}\n"
781               << "                            OpMemberDecorate %Operands 2 Offset ${OPERAND_SIZE_2TIMES}\n"
782               << "                            OpDecorate %_arr_Operands_arraysize ArrayStride ${OPERAND_SIZE_3TIMES}\n"
783               << "                            OpMemberDecorate %InputBlock 0 Offset 0\n"
784               << "                            OpDecorate %InputBlock Block\n"
785               << "                            OpDecorate %input_buffer DescriptorSet 0\n"
786               << "                            OpDecorate %input_buffer Binding 0\n"
787               << "                            OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize\n"
788               << "                    %void = OpTypeVoid\n"
789               << "                %voidfunc = OpTypeFunction %void\n"
790               << "                     %int = OpTypeInt 32 1\n"
791               << "                    %uint = OpTypeInt 32 0\n"
792               << "                  %v3uint = OpTypeVector %uint 3\n"
793               << "${TYPES:opt}"
794               << "                   %int_0 = OpConstant %int 0\n"
795               << "                   %int_1 = OpConstant %int 1\n"
796               << "                   %int_2 = OpConstant %int 2\n"
797               << "                  %uint_1 = OpConstant %uint 1\n"
798               << "                  %uint_0 = OpConstant %uint 0\n"
799               << "               %arraysize = OpConstant %uint ${ARRAY_SIZE}\n"
800               << "      %_ptr_Function_uint = OpTypePointer Function %uint\n"
801               << "       %_ptr_Input_v3uint = OpTypePointer Input %v3uint\n"
802               << "   %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input\n"
803               << "         %_ptr_Input_uint = OpTypePointer Input %uint\n"
804               << "         %results_array_t = OpTypeArray ${OPERAND_TYPE} %arraysize\n"
805               << "                %Operands = OpTypeStruct ${OPERAND_TYPE} ${OPERAND_TYPE} ${OPERAND_TYPE}\n"
806               << " %_arr_Operands_arraysize = OpTypeArray %Operands %arraysize\n"
807               << "             %OutputBlock = OpTypeStruct %results_array_t\n"
808               << "              %InputBlock = OpTypeStruct %_arr_Operands_arraysize\n"
809               << "%_ptr_Uniform_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
810               << " %_ptr_Uniform_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
811               << "           %output_buffer = OpVariable %_ptr_Uniform_OutputBlock StorageBuffer\n"
812               << "            %input_buffer = OpVariable %_ptr_Uniform_InputBlock StorageBuffer\n"
813               << "              %optype_ptr = OpTypePointer StorageBuffer ${OPERAND_TYPE}\n"
814               << "        %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1\n"
815               << "                    %main = OpFunction %void None %voidfunc\n"
816               << "               %mainlabel = OpLabel\n"
817               << "                 %gidxptr = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0\n"
818               << "                     %idx = OpLoad %uint %gidxptr\n"
819               << "                  %op1ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_0\n"
820               << "                     %op1 = OpLoad ${OPERAND_TYPE} %op1ptr\n"
821               << "                  %op2ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_1\n"
822               << "                     %op2 = OpLoad ${OPERAND_TYPE} %op2ptr\n"
823               << "                  %op3ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_2\n"
824               << "                     %op3 = OpLoad ${OPERAND_TYPE} %op3ptr\n"
825               << "                  %result = OpExtInst ${OPERAND_TYPE} %trinary ${OPERATION_NAME} %op1 %op2 %op3\n"
826               << "               %resultptr = OpAccessChain %optype_ptr %output_buffer %int_0 %idx\n"
827               << "                            OpStore %resultptr %result\n"
828               << "                            OpReturn\n"
829               << "                            OpFunctionEnd\n";
830 
831     const tcu::StringTemplate shaderTemplate{shaderStr.str()};
832     const vk::SpirVAsmBuildOptions buildOptions{VK_MAKE_API_VERSION(0, 1, 2, 0), vk::SPIRV_VERSION_1_5};
833 
834     programCollection.spirvAsmSources.add("comp", &buildOptions) << shaderTemplate.specialize(getSpirVReplacements());
835 }
836 
TrinaryMinMaxInstance(Context & context,const TestParams & params)837 TrinaryMinMaxInstance::TrinaryMinMaxInstance(Context &context, const TestParams &params)
838     : vkt::TestInstance(context)
839     , m_params(params)
840 {
841 }
842 
iterate(void)843 tcu::TestStatus TrinaryMinMaxInstance::iterate(void)
844 {
845     const auto &vkd       = m_context.getDeviceInterface();
846     const auto device     = m_context.getDevice();
847     auto &allocator       = m_context.getDefaultAllocator();
848     const auto queue      = m_context.getUniversalQueue();
849     const auto queueIndex = m_context.getUniversalQueueFamilyIndex();
850 
851     constexpr auto kNumOperations = TrinaryMinMaxCase::kArraySize;
852 
853     const vk::VkDeviceSize kInputBufferSize =
854         static_cast<vk::VkDeviceSize>(kNumOperations * 3u * m_params.operandSize());
855     const vk::VkDeviceSize kOutputBufferSize =
856         static_cast<vk::VkDeviceSize>(kNumOperations * m_params.operandSize()); // Single output per operation.
857 
858     // Create input, output and reference buffers.
859     auto inputBufferInfo  = vk::makeBufferCreateInfo(kInputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
860     auto outputBufferInfo = vk::makeBufferCreateInfo(kOutputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
861 
862     vk::BufferWithMemory inputBuffer{vkd, device, allocator, inputBufferInfo, vk::MemoryRequirement::HostVisible};
863     vk::BufferWithMemory outputBuffer{vkd, device, allocator, outputBufferInfo, vk::MemoryRequirement::HostVisible};
864     std::unique_ptr<char[]> referenceBuffer{new char[static_cast<size_t>(kOutputBufferSize)]};
865 
866     // Fill buffers with initial contents.
867     auto &inputAlloc  = inputBuffer.getAllocation();
868     auto &outputAlloc = outputBuffer.getAllocation();
869 
870     void *inputBufferPtr     = static_cast<uint8_t *>(inputAlloc.getHostPtr()) + inputAlloc.getOffset();
871     void *outputBufferPtr    = static_cast<uint8_t *>(outputAlloc.getHostPtr()) + outputAlloc.getOffset();
872     void *referenceBufferPtr = referenceBuffer.get();
873 
874     deMemset(inputBufferPtr, 0, static_cast<size_t>(kInputBufferSize));
875     deMemset(outputBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
876     deMemset(referenceBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
877 
878     // Generate input buffer and calculate reference results.
879     OperationManager opMan{m_params};
880     opMan.genInputBuffer(inputBufferPtr, kNumOperations);
881     opMan.calculateResult(referenceBufferPtr, inputBufferPtr, kNumOperations);
882 
883     // Flush buffer memory before starting.
884     vk::flushAlloc(vkd, device, inputAlloc);
885     vk::flushAlloc(vkd, device, outputAlloc);
886 
887     // Descriptor set layout.
888     vk::DescriptorSetLayoutBuilder layoutBuilder;
889     layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
890     layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
891     auto descriptorSetLayout = layoutBuilder.build(vkd, device);
892 
893     // Descriptor pool.
894     vk::DescriptorPoolBuilder poolBuilder;
895     poolBuilder.addType(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
896     auto descriptorPool = poolBuilder.build(vkd, device, vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
897 
898     // Descriptor set.
899     const auto descriptorSet = vk::makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
900 
901     // Update descriptor set using the buffers.
902     const auto inputBufferDescriptorInfo  = vk::makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
903     const auto outputBufferDescriptorInfo = vk::makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
904 
905     vk::DescriptorSetUpdateBuilder updateBuilder;
906     updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(0u),
907                               vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
908     updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(1u),
909                               vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
910     updateBuilder.update(vkd, device);
911 
912     // Create compute pipeline.
913     auto shaderModule   = vk::createShaderModule(vkd, device, m_context.getBinaryCollection().get("comp"), 0u);
914     auto pipelineLayout = vk::makePipelineLayout(vkd, device, descriptorSetLayout.get());
915 
916     const vk::VkComputePipelineCreateInfo pipelineCreateInfo = {
917         vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
918         nullptr,
919         0u, // flags
920         {
921             // compute shader
922             vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
923             nullptr,                                                 // const void* pNext;
924             0u,                                                      // VkPipelineShaderStageCreateFlags flags;
925             vk::VK_SHADER_STAGE_COMPUTE_BIT,                         // VkShaderStageFlagBits stage;
926             shaderModule.get(),                                      // VkShaderModule module;
927             "main",                                                  // const char* pName;
928             nullptr,                                                 // const VkSpecializationInfo* pSpecializationInfo;
929         },
930         pipelineLayout.get(), // layout
931         DE_NULL,              // basePipelineHandle
932         0,                    // basePipelineIndex
933     };
934     auto pipeline = vk::createComputePipeline(vkd, device, DE_NULL, &pipelineCreateInfo);
935 
936     // Synchronization barriers.
937     auto inputBufferHostToDevBarrier = vk::makeBufferMemoryBarrier(
938         vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_READ_BIT, inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
939     auto outputBufferHostToDevBarrier = vk::makeBufferMemoryBarrier(
940         vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_WRITE_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
941     auto outputBufferDevToHostBarrier = vk::makeBufferMemoryBarrier(
942         vk::VK_ACCESS_SHADER_WRITE_BIT, vk::VK_ACCESS_HOST_READ_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
943 
944     // Command buffer.
945     auto cmdPool      = vk::makeCommandPool(vkd, device, queueIndex);
946     auto cmdBufferPtr = vk::allocateCommandBuffer(vkd, device, cmdPool.get(), vk::VK_COMMAND_BUFFER_LEVEL_PRIMARY);
947     auto cmdBuffer    = cmdBufferPtr.get();
948 
949     // Record and submit commands.
950     vk::beginCommandBuffer(vkd, cmdBuffer);
951     vkd.cmdBindPipeline(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get());
952     vkd.cmdBindDescriptorSets(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout.get(), 0, 1u,
953                               &descriptorSet.get(), 0u, nullptr);
954     vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u,
955                            nullptr, 1u, &inputBufferHostToDevBarrier, 0u, nullptr);
956     vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u,
957                            nullptr, 1u, &outputBufferHostToDevBarrier, 0u, nullptr);
958     vkd.cmdDispatch(cmdBuffer, kNumOperations, 1u, 1u);
959     vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, vk::VK_PIPELINE_STAGE_HOST_BIT, 0u, 0u,
960                            nullptr, 1u, &outputBufferDevToHostBarrier, 0u, nullptr);
961     vk::endCommandBuffer(vkd, cmdBuffer);
962     vk::submitCommandsAndWait(vkd, device, queue, cmdBuffer);
963 
964     // Verify output buffer contents.
965     vk::invalidateAlloc(vkd, device, outputAlloc);
966 
967     const auto error = opMan.compareResults(referenceBufferPtr, outputBufferPtr, kNumOperations);
968 
969     if (!error)
970         return tcu::TestStatus::pass("Pass");
971 
972     std::ostringstream msg;
973     msg << "Value mismatch at operation " << error.get().first << " in component " << error.get().second;
974     return tcu::TestStatus::fail(msg.str());
975 }
976 
977 } // namespace
978 
createTrinaryMinMaxGroup(tcu::TestContext & testCtx)979 tcu::TestCaseGroup *createTrinaryMinMaxGroup(tcu::TestContext &testCtx)
980 {
981     uint32_t seed = 0xFEE768FCu;
982     de::MovePtr<tcu::TestCaseGroup> group{new tcu::TestCaseGroup{testCtx, "amd_trinary_minmax"}};
983 
984     static const std::vector<std::pair<OperationType, std::string>> operationTypes = {
985         {OperationType::MIN, "min3"},
986         {OperationType::MAX, "max3"},
987         {OperationType::MID, "mid3"},
988     };
989 
990     static const std::vector<std::pair<BaseType, std::string>> baseTypes = {
991         {BaseType::TYPE_INT, "i"},
992         {BaseType::TYPE_UINT, "u"},
993         {BaseType::TYPE_FLOAT, "f"},
994     };
995 
996     static const std::vector<std::pair<TypeSize, std::string>> typeSizes = {
997         {TypeSize::SIZE_8BIT, "8"},
998         {TypeSize::SIZE_16BIT, "16"},
999         {TypeSize::SIZE_32BIT, "32"},
1000         {TypeSize::SIZE_64BIT, "64"},
1001     };
1002 
1003     static const std::vector<std::pair<AggregationType, std::string>> aggregationTypes = {
1004         {AggregationType::SCALAR, "scalar"},
1005         {AggregationType::VEC2, "vec2"},
1006         {AggregationType::VEC3, "vec3"},
1007         {AggregationType::VEC4, "vec4"},
1008     };
1009 
1010     for (const auto &opType : operationTypes)
1011     {
1012         de::MovePtr<tcu::TestCaseGroup> opGroup{new tcu::TestCaseGroup{testCtx, opType.second.c_str()}};
1013 
1014         for (const auto &baseType : baseTypes)
1015             for (const auto &typeSize : typeSizes)
1016             {
1017                 // There are no 8-bit floats.
1018                 if (baseType.first == BaseType::TYPE_FLOAT && typeSize.first == TypeSize::SIZE_8BIT)
1019                     continue;
1020 
1021                 const std::string typeName = baseType.second + typeSize.second;
1022 
1023                 de::MovePtr<tcu::TestCaseGroup> typeGroup{new tcu::TestCaseGroup{testCtx, typeName.c_str()}};
1024 
1025                 for (const auto &aggType : aggregationTypes)
1026                 {
1027                     const TestParams params = {
1028                         opType.first,   // OperationType operation;
1029                         baseType.first, // BaseType baseType;
1030                         typeSize.first, // TypeSize typeSize;
1031                         aggType.first,  // AggregationType aggregation;
1032                         seed++,         // uint32_t randomSeed;
1033                     };
1034                     typeGroup->addChild(new TrinaryMinMaxCase{testCtx, aggType.second, params});
1035                 }
1036 
1037                 opGroup->addChild(typeGroup.release());
1038             }
1039 
1040         group->addChild(opGroup.release());
1041     }
1042 
1043     return group.release();
1044 }
1045 
1046 } // namespace SpirVAssembly
1047 } // namespace vkt
1048