1 /*-------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2020 Valve Corporation.
6 * Copyright (c) 2020 The Khronos Group Inc.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 * http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 *//*!
21 * \file
22 * \brief SPIR-V tests for VK_AMD_shader_trinary_minmax.
23 *//*--------------------------------------------------------------------*/
24
25 #include "vktSpvAsmTrinaryMinMaxTests.hpp"
26 #include "vktTestCase.hpp"
27
28 #include "vkQueryUtil.hpp"
29 #include "vkObjUtil.hpp"
30 #include "vkBufferWithMemory.hpp"
31 #include "vkBuilderUtil.hpp"
32 #include "vkTypeUtil.hpp"
33 #include "vkBarrierUtil.hpp"
34 #include "vkCmdUtil.hpp"
35
36 #include "tcuStringTemplate.hpp"
37 #include "tcuFloat.hpp"
38 #include "tcuMaybe.hpp"
39
40 #include "deStringUtil.hpp"
41 #include "deRandom.hpp"
42 #include "deMemory.h"
43
44 #include <string>
45 #include <sstream>
46 #include <map>
47 #include <vector>
48 #include <algorithm>
49 #include <array>
50 #include <memory>
51
52 namespace vkt
53 {
54 namespace SpirVAssembly
55 {
56
57 namespace
58 {
59
// Trinary operation under test.
// The numeric values matter: the enumerator is cast to an integer and used as an index when the SPIR-V instruction
// name is built in TrinaryMinMaxCase::getSpirVReplacements, and as a sort key component in OperationManager::Case.
enum class OperationType
{
    MIN = 0, // Smallest of the three operands.
    MAX,     // Largest of the three operands (value 1).
    MID,     // Middle one of the three operands (value 2).
};
66
// Base numeric type of the operands.
// The numeric values matter: the enumerator is cast to an integer and used as an index when the SPIR-V instruction
// name prefix is chosen in TrinaryMinMaxCase::getSpirVReplacements.
enum class BaseType
{
    TYPE_INT   = 0, // Signed integer.
    TYPE_UINT  = 1, // Unsigned integer.
    TYPE_FLOAT = 2, // Floating point.
};
73
// Width of a single component. Each enumerator's numeric value is the component size in bytes (bit width / 8), which
// is what TestParams::componentSize() returns.
enum class TypeSize
{
    SIZE_8BIT  = 8 / 8,
    SIZE_16BIT = 16 / 8,
    SIZE_32BIT = 32 / 8,
    SIZE_64BIT = 64 / 8,
};
82
// Scalar or vector shape of each operand. Each enumerator's numeric value is its number of components, which is what
// TestParams::numComponents() returns.
enum class AggregationType
{
    SCALAR = 1,
    VEC2, // 2 components.
    VEC3, // 3 components.
    VEC4, // 4 components.
};
91
92 struct TestParams
93 {
94 OperationType operation;
95 BaseType baseType;
96 TypeSize typeSize;
97 AggregationType aggregation;
98 uint32_t randomSeed;
99
100 uint32_t operandSize() const; // In bytes.
101 uint32_t numComponents() const; // Number of components.
102 uint32_t effectiveComponents() const; // Effective number of components for size calculation.
103 uint32_t componentSize() const; // In bytes.
104 };
105
operandSize() const106 uint32_t TestParams::operandSize() const
107 {
108 return (effectiveComponents() * componentSize());
109 }
110
numComponents() const111 uint32_t TestParams::numComponents() const
112 {
113 return static_cast<uint32_t>(aggregation);
114 }
115
effectiveComponents() const116 uint32_t TestParams::effectiveComponents() const
117 {
118 return static_cast<uint32_t>((aggregation == AggregationType::VEC3) ? AggregationType::VEC4 : aggregation);
119 }
120
componentSize() const121 uint32_t TestParams::componentSize() const
122 {
123 return static_cast<uint32_t>(typeSize);
124 }
125
// Returns the smallest of the three values according to operator<. When several values are equivalent, the leftmost
// one is returned.
template <class T>
T min3(T op1, T op2, T op3)
{
    const T &smallestOfLastTwo = std::min(op2, op3);
    return std::min(op1, smallestOfLastTwo);
}
131
// Returns the largest of the three values according to operator<. When several values are equivalent, the leftmost
// one is returned.
template <class T>
T max3(T op1, T op2, T op3)
{
    const T &largestOfLastTwo = std::max(op2, op3);
    return std::max(op1, largestOfLastTwo);
}
137
// Returns the middle one of the three values: the operands are sorted in ascending order with operator< and the
// element in the central position is returned.
template <class T>
T mid3(T op1, T op2, T op3)
{
    T ordered[3] = {op1, op2, op3};
    std::sort(ordered, ordered + 3);
    return ordered[1];
}
145
146 class OperationManager
147 {
148 public:
149 // Operation and component index in case of error.
150 using OperationComponent = std::pair<uint32_t, uint32_t>;
151 using ComparisonError = tcu::Maybe<OperationComponent>;
152
153 OperationManager(const TestParams ¶ms);
154 void genInputBuffer(void *bufferPtr, uint32_t numOperations);
155 void calculateResult(void *referenceBuffer, void *inputBuffer, uint32_t numOperations);
156 ComparisonError compareResults(void *referenceBuffer, void *resultsBuffer, uint32_t numOperations);
157
158 private:
159 using GenerateCompFunc = void (*)(de::Random &, void *); // Write a generated component to the given location.
160
161 // Generator variants to populate input buffer.
genInt8(de::Random & rnd,void * ptr)162 static void genInt8(de::Random &rnd, void *ptr)
163 {
164 *reinterpret_cast<int8_t *>(ptr) = static_cast<int8_t>(rnd.getUint8());
165 }
genUint8(de::Random & rnd,void * ptr)166 static void genUint8(de::Random &rnd, void *ptr)
167 {
168 *reinterpret_cast<uint8_t *>(ptr) = rnd.getUint8();
169 }
genInt16(de::Random & rnd,void * ptr)170 static void genInt16(de::Random &rnd, void *ptr)
171 {
172 *reinterpret_cast<int16_t *>(ptr) = static_cast<int16_t>(rnd.getUint16());
173 }
genUint16(de::Random & rnd,void * ptr)174 static void genUint16(de::Random &rnd, void *ptr)
175 {
176 *reinterpret_cast<uint16_t *>(ptr) = rnd.getUint16();
177 }
genInt32(de::Random & rnd,void * ptr)178 static void genInt32(de::Random &rnd, void *ptr)
179 {
180 *reinterpret_cast<int32_t *>(ptr) = static_cast<int32_t>(rnd.getUint32());
181 }
genUint32(de::Random & rnd,void * ptr)182 static void genUint32(de::Random &rnd, void *ptr)
183 {
184 *reinterpret_cast<uint32_t *>(ptr) = rnd.getUint32();
185 }
genInt64(de::Random & rnd,void * ptr)186 static void genInt64(de::Random &rnd, void *ptr)
187 {
188 *reinterpret_cast<int64_t *>(ptr) = static_cast<int64_t>(rnd.getUint64());
189 }
genUint64(de::Random & rnd,void * ptr)190 static void genUint64(de::Random &rnd, void *ptr)
191 {
192 *reinterpret_cast<uint64_t *>(ptr) = rnd.getUint64();
193 }
194
195 // Helper template for float generators.
196 // T must be a tcu::Float instantiation.
197 // Attempts to generate +-Inf once every 10 times and avoid denormals.
198 template <class T>
genFloat(de::Random & rnd,void * ptr)199 static inline void genFloat(de::Random &rnd, void *ptr)
200 {
201 T *valuePtr = reinterpret_cast<T *>(ptr);
202 if (rnd.getInt(1, 10) == 1)
203 *valuePtr = T::inf(rnd.getBool() ? 1 : -1);
204 else
205 {
206 do
207 {
208 *valuePtr = T{rnd.getDouble(T::largestNormal(-1).asDouble(), T::largestNormal(1).asDouble())};
209 } while (valuePtr->isDenorm());
210 }
211 }
212
genFloat16(de::Random & rnd,void * ptr)213 static void genFloat16(de::Random &rnd, void *ptr)
214 {
215 genFloat<tcu::Float16>(rnd, ptr);
216 }
genFloat32(de::Random & rnd,void * ptr)217 static void genFloat32(de::Random &rnd, void *ptr)
218 {
219 genFloat<tcu::Float32>(rnd, ptr);
220 }
genFloat64(de::Random & rnd,void * ptr)221 static void genFloat64(de::Random &rnd, void *ptr)
222 {
223 genFloat<tcu::Float64>(rnd, ptr);
224 }
225
226 // An operation function writes an output value given 3 input values.
227 using OperationFunc = void (*)(void *, const void *, const void *, const void *);
228
229 // Helper template used below.
230 template <class T, class F>
runOpFunc(F f,void * out,const void * in1,const void * in2,const void * in3)231 static inline void runOpFunc(F f, void *out, const void *in1, const void *in2, const void *in3)
232 {
233 *reinterpret_cast<T *>(out) =
234 f(*reinterpret_cast<const T *>(in1), *reinterpret_cast<const T *>(in2), *reinterpret_cast<const T *>(in3));
235 }
236
237 // Apply an operation in software to a given group of components and calculate result.
minInt8(void * out,const void * in1,const void * in2,const void * in3)238 static void minInt8(void *out, const void *in1, const void *in2, const void *in3)
239 {
240 runOpFunc<int8_t>(min3<int8_t>, out, in1, in2, in3);
241 }
maxInt8(void * out,const void * in1,const void * in2,const void * in3)242 static void maxInt8(void *out, const void *in1, const void *in2, const void *in3)
243 {
244 runOpFunc<int8_t>(max3<int8_t>, out, in1, in2, in3);
245 }
midInt8(void * out,const void * in1,const void * in2,const void * in3)246 static void midInt8(void *out, const void *in1, const void *in2, const void *in3)
247 {
248 runOpFunc<int8_t>(mid3<int8_t>, out, in1, in2, in3);
249 }
minUint8(void * out,const void * in1,const void * in2,const void * in3)250 static void minUint8(void *out, const void *in1, const void *in2, const void *in3)
251 {
252 runOpFunc<uint8_t>(min3<uint8_t>, out, in1, in2, in3);
253 }
maxUint8(void * out,const void * in1,const void * in2,const void * in3)254 static void maxUint8(void *out, const void *in1, const void *in2, const void *in3)
255 {
256 runOpFunc<uint8_t>(max3<uint8_t>, out, in1, in2, in3);
257 }
midUint8(void * out,const void * in1,const void * in2,const void * in3)258 static void midUint8(void *out, const void *in1, const void *in2, const void *in3)
259 {
260 runOpFunc<uint8_t>(mid3<uint8_t>, out, in1, in2, in3);
261 }
minInt16(void * out,const void * in1,const void * in2,const void * in3)262 static void minInt16(void *out, const void *in1, const void *in2, const void *in3)
263 {
264 runOpFunc<int16_t>(min3<int16_t>, out, in1, in2, in3);
265 }
maxInt16(void * out,const void * in1,const void * in2,const void * in3)266 static void maxInt16(void *out, const void *in1, const void *in2, const void *in3)
267 {
268 runOpFunc<int16_t>(max3<int16_t>, out, in1, in2, in3);
269 }
midInt16(void * out,const void * in1,const void * in2,const void * in3)270 static void midInt16(void *out, const void *in1, const void *in2, const void *in3)
271 {
272 runOpFunc<int16_t>(mid3<int16_t>, out, in1, in2, in3);
273 }
minUint16(void * out,const void * in1,const void * in2,const void * in3)274 static void minUint16(void *out, const void *in1, const void *in2, const void *in3)
275 {
276 runOpFunc<uint16_t>(min3<uint16_t>, out, in1, in2, in3);
277 }
maxUint16(void * out,const void * in1,const void * in2,const void * in3)278 static void maxUint16(void *out, const void *in1, const void *in2, const void *in3)
279 {
280 runOpFunc<uint16_t>(max3<uint16_t>, out, in1, in2, in3);
281 }
midUint16(void * out,const void * in1,const void * in2,const void * in3)282 static void midUint16(void *out, const void *in1, const void *in2, const void *in3)
283 {
284 runOpFunc<uint16_t>(mid3<uint16_t>, out, in1, in2, in3);
285 }
minInt32(void * out,const void * in1,const void * in2,const void * in3)286 static void minInt32(void *out, const void *in1, const void *in2, const void *in3)
287 {
288 runOpFunc<int32_t>(min3<int32_t>, out, in1, in2, in3);
289 }
maxInt32(void * out,const void * in1,const void * in2,const void * in3)290 static void maxInt32(void *out, const void *in1, const void *in2, const void *in3)
291 {
292 runOpFunc<int32_t>(max3<int32_t>, out, in1, in2, in3);
293 }
midInt32(void * out,const void * in1,const void * in2,const void * in3)294 static void midInt32(void *out, const void *in1, const void *in2, const void *in3)
295 {
296 runOpFunc<int32_t>(mid3<int32_t>, out, in1, in2, in3);
297 }
minUint32(void * out,const void * in1,const void * in2,const void * in3)298 static void minUint32(void *out, const void *in1, const void *in2, const void *in3)
299 {
300 runOpFunc<uint32_t>(min3<uint32_t>, out, in1, in2, in3);
301 }
maxUint32(void * out,const void * in1,const void * in2,const void * in3)302 static void maxUint32(void *out, const void *in1, const void *in2, const void *in3)
303 {
304 runOpFunc<uint32_t>(max3<uint32_t>, out, in1, in2, in3);
305 }
midUint32(void * out,const void * in1,const void * in2,const void * in3)306 static void midUint32(void *out, const void *in1, const void *in2, const void *in3)
307 {
308 runOpFunc<uint32_t>(mid3<uint32_t>, out, in1, in2, in3);
309 }
minInt64(void * out,const void * in1,const void * in2,const void * in3)310 static void minInt64(void *out, const void *in1, const void *in2, const void *in3)
311 {
312 runOpFunc<int64_t>(min3<int64_t>, out, in1, in2, in3);
313 }
maxInt64(void * out,const void * in1,const void * in2,const void * in3)314 static void maxInt64(void *out, const void *in1, const void *in2, const void *in3)
315 {
316 runOpFunc<int64_t>(max3<int64_t>, out, in1, in2, in3);
317 }
midInt64(void * out,const void * in1,const void * in2,const void * in3)318 static void midInt64(void *out, const void *in1, const void *in2, const void *in3)
319 {
320 runOpFunc<int64_t>(mid3<int64_t>, out, in1, in2, in3);
321 }
minUint64(void * out,const void * in1,const void * in2,const void * in3)322 static void minUint64(void *out, const void *in1, const void *in2, const void *in3)
323 {
324 runOpFunc<uint64_t>(min3<uint64_t>, out, in1, in2, in3);
325 }
maxUint64(void * out,const void * in1,const void * in2,const void * in3)326 static void maxUint64(void *out, const void *in1, const void *in2, const void *in3)
327 {
328 runOpFunc<uint64_t>(max3<uint64_t>, out, in1, in2, in3);
329 }
midUint64(void * out,const void * in1,const void * in2,const void * in3)330 static void midUint64(void *out, const void *in1, const void *in2, const void *in3)
331 {
332 runOpFunc<uint64_t>(mid3<uint64_t>, out, in1, in2, in3);
333 }
minFloat16(void * out,const void * in1,const void * in2,const void * in3)334 static void minFloat16(void *out, const void *in1, const void *in2, const void *in3)
335 {
336 runOpFunc<tcu::Float16>(min3<tcu::Float16>, out, in1, in2, in3);
337 }
maxFloat16(void * out,const void * in1,const void * in2,const void * in3)338 static void maxFloat16(void *out, const void *in1, const void *in2, const void *in3)
339 {
340 runOpFunc<tcu::Float16>(max3<tcu::Float16>, out, in1, in2, in3);
341 }
midFloat16(void * out,const void * in1,const void * in2,const void * in3)342 static void midFloat16(void *out, const void *in1, const void *in2, const void *in3)
343 {
344 runOpFunc<tcu::Float16>(mid3<tcu::Float16>, out, in1, in2, in3);
345 }
minFloat32(void * out,const void * in1,const void * in2,const void * in3)346 static void minFloat32(void *out, const void *in1, const void *in2, const void *in3)
347 {
348 runOpFunc<tcu::Float32>(min3<tcu::Float32>, out, in1, in2, in3);
349 }
maxFloat32(void * out,const void * in1,const void * in2,const void * in3)350 static void maxFloat32(void *out, const void *in1, const void *in2, const void *in3)
351 {
352 runOpFunc<tcu::Float32>(max3<tcu::Float32>, out, in1, in2, in3);
353 }
midFloat32(void * out,const void * in1,const void * in2,const void * in3)354 static void midFloat32(void *out, const void *in1, const void *in2, const void *in3)
355 {
356 runOpFunc<tcu::Float32>(mid3<tcu::Float32>, out, in1, in2, in3);
357 }
minFloat64(void * out,const void * in1,const void * in2,const void * in3)358 static void minFloat64(void *out, const void *in1, const void *in2, const void *in3)
359 {
360 runOpFunc<tcu::Float64>(min3<tcu::Float64>, out, in1, in2, in3);
361 }
maxFloat64(void * out,const void * in1,const void * in2,const void * in3)362 static void maxFloat64(void *out, const void *in1, const void *in2, const void *in3)
363 {
364 runOpFunc<tcu::Float64>(max3<tcu::Float64>, out, in1, in2, in3);
365 }
midFloat64(void * out,const void * in1,const void * in2,const void * in3)366 static void midFloat64(void *out, const void *in1, const void *in2, const void *in3)
367 {
368 runOpFunc<tcu::Float64>(mid3<tcu::Float64>, out, in1, in2, in3);
369 }
370
371 // Case for accessing the functions map.
372 struct Case
373 {
374 BaseType type;
375 TypeSize size;
376 OperationType operation;
377
378 // This is required for sorting in the map.
operator <vkt::SpirVAssembly::__anon44d5838c0111::OperationManager::Case379 bool operator<(const Case &other) const
380 {
381 return (toArray() < other.toArray());
382 }
383
384 private:
toArrayvkt::SpirVAssembly::__anon44d5838c0111::OperationManager::Case385 std::array<int, 3> toArray() const
386 {
387 return std::array<int, 3>{{static_cast<int>(type), static_cast<int>(size), static_cast<int>(operation)}};
388 }
389 };
390
391 // Helper map to correctly choose the right generator and operation function for the specific case being tested.
392 using FuncPair = std::pair<GenerateCompFunc, OperationFunc>;
393 using CaseMap = std::map<Case, FuncPair>;
394
395 static const CaseMap kFunctionsMap;
396
397 GenerateCompFunc m_chosenGenerator;
398 OperationFunc m_chosenOperation;
399 de::Random m_random;
400
401 const uint32_t m_operandSize;
402 const uint32_t m_numComponents;
403 const uint32_t m_componentSize;
404 };
405
406 // This map is used to choose how to generate inputs for each case and which operation to run on the CPU to calculate the reference
407 // results for the generated inputs.
408 const OperationManager::CaseMap OperationManager::kFunctionsMap = {
409 {{BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MIN}, {genInt8, minInt8}},
410 {{BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MAX}, {genInt8, maxInt8}},
411 {{BaseType::TYPE_INT, TypeSize::SIZE_8BIT, OperationType::MID}, {genInt8, midInt8}},
412 {{BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MIN}, {genInt16, minInt16}},
413 {{BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MAX}, {genInt16, maxInt16}},
414 {{BaseType::TYPE_INT, TypeSize::SIZE_16BIT, OperationType::MID}, {genInt16, midInt16}},
415 {{BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MIN}, {genInt32, minInt32}},
416 {{BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MAX}, {genInt32, maxInt32}},
417 {{BaseType::TYPE_INT, TypeSize::SIZE_32BIT, OperationType::MID}, {genInt32, midInt32}},
418 {{BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MIN}, {genInt64, minInt64}},
419 {{BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MAX}, {genInt64, maxInt64}},
420 {{BaseType::TYPE_INT, TypeSize::SIZE_64BIT, OperationType::MID}, {genInt64, midInt64}},
421 {{BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MIN}, {genUint8, minUint8}},
422 {{BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MAX}, {genUint8, maxUint8}},
423 {{BaseType::TYPE_UINT, TypeSize::SIZE_8BIT, OperationType::MID}, {genUint8, midUint8}},
424 {{BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MIN}, {genUint16, minUint16}},
425 {{BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MAX}, {genUint16, maxUint16}},
426 {{BaseType::TYPE_UINT, TypeSize::SIZE_16BIT, OperationType::MID}, {genUint16, midUint16}},
427 {{BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MIN}, {genUint32, minUint32}},
428 {{BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MAX}, {genUint32, maxUint32}},
429 {{BaseType::TYPE_UINT, TypeSize::SIZE_32BIT, OperationType::MID}, {genUint32, midUint32}},
430 {{BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MIN}, {genUint64, minUint64}},
431 {{BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MAX}, {genUint64, maxUint64}},
432 {{BaseType::TYPE_UINT, TypeSize::SIZE_64BIT, OperationType::MID}, {genUint64, midUint64}},
433 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MIN}, {genFloat16, minFloat16}},
434 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MAX}, {genFloat16, maxFloat16}},
435 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_16BIT, OperationType::MID}, {genFloat16, midFloat16}},
436 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MIN}, {genFloat32, minFloat32}},
437 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MAX}, {genFloat32, maxFloat32}},
438 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_32BIT, OperationType::MID}, {genFloat32, midFloat32}},
439 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MIN}, {genFloat64, minFloat64}},
440 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MAX}, {genFloat64, maxFloat64}},
441 {{BaseType::TYPE_FLOAT, TypeSize::SIZE_64BIT, OperationType::MID}, {genFloat64, midFloat64}},
442 };
443
OperationManager(const TestParams & params)444 OperationManager::OperationManager(const TestParams ¶ms)
445 : m_chosenGenerator{nullptr}
446 , m_chosenOperation{nullptr}
447 , m_random{params.randomSeed}
448 , m_operandSize{params.operandSize()}
449 , m_numComponents{params.numComponents()}
450 , m_componentSize{params.componentSize()}
451 {
452 // Choose generator and CPU operation from the map.
453 const Case paramCase{params.baseType, params.typeSize, params.operation};
454 const auto iter = kFunctionsMap.find(paramCase);
455
456 DE_ASSERT(iter != kFunctionsMap.end());
457 m_chosenGenerator = iter->second.first;
458 m_chosenOperation = iter->second.second;
459 }
460
461 // See TrinaryMinMaxCase::initPrograms for a description of the input buffer format.
462 // Generates inputs with the chosen generator.
genInputBuffer(void * bufferPtr,uint32_t numOperations)463 void OperationManager::genInputBuffer(void *bufferPtr, uint32_t numOperations)
464 {
465 const uint32_t numOperands = numOperations * 3u;
466 char *byteBuffer = reinterpret_cast<char *>(bufferPtr);
467
468 for (uint32_t opIdx = 0u; opIdx < numOperands; ++opIdx)
469 {
470 char *compPtr = byteBuffer;
471 for (uint32_t compIdx = 0u; compIdx < m_numComponents; ++compIdx)
472 {
473 m_chosenGenerator(m_random, reinterpret_cast<void *>(compPtr));
474 compPtr += m_componentSize;
475 }
476 byteBuffer += m_operandSize;
477 }
478 }
479
480 // See TrinaryMinMaxCase::initPrograms for a description of the input and output buffer formats.
481 // Calculates reference results on the CPU using the chosen operation and the input buffer.
calculateResult(void * referenceBuffer,void * inputBuffer,uint32_t numOperations)482 void OperationManager::calculateResult(void *referenceBuffer, void *inputBuffer, uint32_t numOperations)
483 {
484 char *outputByte = reinterpret_cast<char *>(referenceBuffer);
485 char *inputByte = reinterpret_cast<char *>(inputBuffer);
486
487 for (uint32_t opIdx = 0u; opIdx < numOperations; ++opIdx)
488 {
489 char *res = outputByte;
490 char *op1 = inputByte;
491 char *op2 = inputByte + m_operandSize;
492 char *op3 = inputByte + m_operandSize * 2u;
493
494 for (uint32_t compIdx = 0u; compIdx < m_numComponents; ++compIdx)
495 {
496 m_chosenOperation(reinterpret_cast<void *>(res), reinterpret_cast<void *>(op1),
497 reinterpret_cast<void *>(op2), reinterpret_cast<void *>(op3));
498
499 res += m_componentSize;
500 op1 += m_componentSize;
501 op2 += m_componentSize;
502 op3 += m_componentSize;
503 }
504
505 outputByte += m_operandSize;
506 inputByte += m_operandSize * 3u;
507 }
508 }
509
510 // See TrinaryMinMaxCase::initPrograms for a description of the output buffer format.
compareResults(void * referenceBuffer,void * resultsBuffer,uint32_t numOperations)511 OperationManager::ComparisonError OperationManager::compareResults(void *referenceBuffer, void *resultsBuffer,
512 uint32_t numOperations)
513 {
514 char *referenceBytes = reinterpret_cast<char *>(referenceBuffer);
515 char *resultsBytes = reinterpret_cast<char *>(resultsBuffer);
516
517 for (uint32_t opIdx = 0u; opIdx < numOperations; ++opIdx)
518 {
519 char *refCompBytes = referenceBytes;
520 char *resCompBytes = resultsBytes;
521
522 for (uint32_t compIdx = 0u; compIdx < m_numComponents; ++compIdx)
523 {
524 if (deMemCmp(refCompBytes, resCompBytes, m_componentSize) != 0)
525 return tcu::just(OperationComponent(opIdx, compIdx));
526 refCompBytes += m_componentSize;
527 resCompBytes += m_componentSize;
528 }
529 referenceBytes += m_operandSize;
530 resultsBytes += m_operandSize;
531 }
532
533 return tcu::Nothing;
534 }
535
536 class TrinaryMinMaxCase : public vkt::TestCase
537 {
538 public:
539 using ReplacementsMap = std::map<std::string, std::string>;
540
541 TrinaryMinMaxCase(tcu::TestContext &testCtx, const std::string &name, const TestParams ¶ms);
~TrinaryMinMaxCase(void)542 virtual ~TrinaryMinMaxCase(void)
543 {
544 }
545
546 virtual void initPrograms(vk::SourceCollections &programCollection) const;
547 virtual TestInstance *createInstance(Context &context) const;
548 virtual void checkSupport(Context &context) const;
549 ReplacementsMap getSpirVReplacements(void) const;
550
551 static const uint32_t kArraySize;
552
553 private:
554 TestParams m_params;
555 };
556
557 const uint32_t TrinaryMinMaxCase::kArraySize = 100u;
558
559 class TrinaryMinMaxInstance : public vkt::TestInstance
560 {
561 public:
562 TrinaryMinMaxInstance(Context &context, const TestParams ¶ms);
~TrinaryMinMaxInstance(void)563 virtual ~TrinaryMinMaxInstance(void)
564 {
565 }
566
567 virtual tcu::TestStatus iterate(void);
568
569 private:
570 TestParams m_params;
571 };
572
TrinaryMinMaxCase(tcu::TestContext & testCtx,const std::string & name,const TestParams & params)573 TrinaryMinMaxCase::TrinaryMinMaxCase(tcu::TestContext &testCtx, const std::string &name, const TestParams ¶ms)
574 : vkt::TestCase(testCtx, name)
575 , m_params(params)
576 {
577 }
578
createInstance(Context & context) const579 TestInstance *TrinaryMinMaxCase::createInstance(Context &context) const
580 {
581 return new TrinaryMinMaxInstance{context, m_params};
582 }
583
checkSupport(Context & context) const584 void TrinaryMinMaxCase::checkSupport(Context &context) const
585 {
586 // These are always required.
587 context.requireInstanceFunctionality("VK_KHR_get_physical_device_properties2");
588 context.requireDeviceFunctionality("VK_KHR_storage_buffer_storage_class");
589 context.requireDeviceFunctionality("VK_AMD_shader_trinary_minmax");
590
591 const auto devFeatures = context.getDeviceFeatures();
592 const auto storage16BitFeatures = context.get16BitStorageFeatures();
593 const auto storage8BitFeatures = context.get8BitStorageFeatures();
594 const auto shaderFeatures = context.getShaderFloat16Int8Features();
595
596 // Storage features.
597 if (m_params.typeSize == TypeSize::SIZE_8BIT)
598 {
599 // We will be using 8-bit types in storage buffers.
600 context.requireDeviceFunctionality("VK_KHR_8bit_storage");
601 if (!storage8BitFeatures.storageBuffer8BitAccess)
602 TCU_THROW(NotSupportedError, "8-bit storage buffer access not supported");
603 }
604 else if (m_params.typeSize == TypeSize::SIZE_16BIT)
605 {
606 // We will be using 16-bit types in storage buffers.
607 context.requireDeviceFunctionality("VK_KHR_16bit_storage");
608 if (!storage16BitFeatures.storageBuffer16BitAccess)
609 TCU_THROW(NotSupportedError, "16-bit storage buffer access not supported");
610 }
611
612 // Shader type features.
613 if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
614 {
615 if (m_params.typeSize == TypeSize::SIZE_8BIT && !shaderFeatures.shaderInt8)
616 TCU_THROW(NotSupportedError, "8-bit integers not supported in shaders");
617 else if (m_params.typeSize == TypeSize::SIZE_16BIT && !devFeatures.shaderInt16)
618 TCU_THROW(NotSupportedError, "16-bit integers not supported in shaders");
619 else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderInt64)
620 TCU_THROW(NotSupportedError, "64-bit integers not supported in shaders");
621 }
622 else // BaseType::TYPE_FLOAT
623 {
624 DE_ASSERT(m_params.typeSize != TypeSize::SIZE_8BIT);
625 if (m_params.typeSize == TypeSize::SIZE_16BIT && !shaderFeatures.shaderFloat16)
626 TCU_THROW(NotSupportedError, "16-bit floats not supported in shaders");
627 else if (m_params.typeSize == TypeSize::SIZE_64BIT && !devFeatures.shaderFloat64)
628 TCU_THROW(NotSupportedError, "64-bit floats not supported in shaders");
629 }
630 }
631
getSpirVReplacements(void) const632 TrinaryMinMaxCase::ReplacementsMap TrinaryMinMaxCase::getSpirVReplacements(void) const
633 {
634 ReplacementsMap replacements;
635
636 // Capabilities and extensions.
637 if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
638 {
639 if (m_params.typeSize == TypeSize::SIZE_8BIT)
640 replacements["CAPABILITIES"] += "OpCapability Int8\n";
641 else if (m_params.typeSize == TypeSize::SIZE_16BIT)
642 replacements["CAPABILITIES"] += "OpCapability Int16\n";
643 else if (m_params.typeSize == TypeSize::SIZE_64BIT)
644 replacements["CAPABILITIES"] += "OpCapability Int64\n";
645 }
646 else // BaseType::TYPE_FLOAT
647 {
648 if (m_params.typeSize == TypeSize::SIZE_16BIT)
649 replacements["CAPABILITIES"] += "OpCapability Float16\n";
650 else if (m_params.typeSize == TypeSize::SIZE_64BIT)
651 replacements["CAPABILITIES"] += "OpCapability Float64\n";
652 }
653
654 if (m_params.typeSize == TypeSize::SIZE_8BIT)
655 {
656 replacements["CAPABILITIES"] += "OpCapability StorageBuffer8BitAccess\n";
657 replacements["EXTENSIONS"] += "OpExtension \"SPV_KHR_8bit_storage\"\n";
658 }
659 else if (m_params.typeSize == TypeSize::SIZE_16BIT)
660 {
661 replacements["CAPABILITIES"] += "OpCapability StorageBuffer16BitAccess\n";
662 replacements["EXTENSIONS"] += "OpExtension \"SPV_KHR_16bit_storage\"\n";
663 }
664
665 // Operand size in bytes.
666 const uint32_t opSize = m_params.operandSize();
667 replacements["OPERAND_SIZE"] = de::toString(opSize);
668 replacements["OPERAND_SIZE_2TIMES"] = de::toString(opSize * 2u);
669 replacements["OPERAND_SIZE_3TIMES"] = de::toString(opSize * 3u);
670
671 // Array size.
672 replacements["ARRAY_SIZE"] = de::toString(kArraySize);
673
674 // Types and operand type: define the base integer or float type and the vector type if needed, then set the operand type replacement.
675 const std::string vecSize = de::toString(m_params.numComponents());
676 const std::string bitSize = de::toString(m_params.componentSize() * 8u);
677
678 if (m_params.baseType == BaseType::TYPE_INT || m_params.baseType == BaseType::TYPE_UINT)
679 {
680 const std::string signBit = (m_params.baseType == BaseType::TYPE_INT ? "1" : "0");
681 const std::string typePrefix = (m_params.baseType == BaseType::TYPE_UINT ? "u" : "");
682 std::string baseTypeName;
683
684 // 32-bit integers are already defined in the default shader text.
685 if (m_params.typeSize != TypeSize::SIZE_32BIT)
686 {
687 baseTypeName = typePrefix + "int" + bitSize + "_t";
688 replacements["TYPES"] += "%" + baseTypeName + " = OpTypeInt " + bitSize + " " + signBit + "\n";
689 }
690 else
691 {
692 baseTypeName = typePrefix + "int";
693 }
694
695 if (m_params.aggregation == AggregationType::SCALAR)
696 {
697 replacements["OPERAND_TYPE"] = "%" + baseTypeName;
698 }
699 else
700 {
701 const std::string typeName = "%v" + vecSize + baseTypeName;
702 // %v3uint is already defined in the default shader text.
703 if (m_params.baseType != BaseType::TYPE_UINT || m_params.typeSize != TypeSize::SIZE_32BIT ||
704 m_params.aggregation != AggregationType::VEC3)
705 {
706 replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
707 }
708 replacements["OPERAND_TYPE"] = typeName;
709 }
710 }
711 else // BaseType::TYPE_FLOAT
712 {
713 const std::string baseTypeName = "float" + bitSize + "_t";
714 replacements["TYPES"] += "%" + baseTypeName + " = OpTypeFloat " + bitSize + "\n";
715
716 if (m_params.aggregation == AggregationType::SCALAR)
717 {
718 replacements["OPERAND_TYPE"] = "%" + baseTypeName;
719 }
720 else
721 {
722 const std::string typeName = "%v" + vecSize + baseTypeName;
723 replacements["TYPES"] += typeName + " = OpTypeVector %" + baseTypeName + " " + vecSize + "\n";
724 replacements["OPERAND_TYPE"] = typeName;
725 }
726 }
727
728 // Operation name.
729 const static std::vector<std::string> opTypeStr = {"Min", "Max", "Mid"};
730 const static std::vector<std::string> opPrefix = {"S", "U", "F"};
731 replacements["OPERATION_NAME"] =
732 opPrefix[static_cast<int>(m_params.baseType)] + opTypeStr[static_cast<int>(m_params.operation)] + "3AMD";
733
734 return replacements;
735 }
736
initPrograms(vk::SourceCollections & programCollection) const737 void TrinaryMinMaxCase::initPrograms(vk::SourceCollections &programCollection) const
738 {
739 // The shader below uses an input buffer at set 0 binding 0 and an output buffer at set 0 binding 1. Their structure is similar
740 // to the code below:
741 //
742 // struct Operands {
743 // <type> op1;
744 // <type> op2;
745 // <type> op3;
746 // };
747 //
748 // layout (set=0, binding=0, std430) buffer InputBlock {
749 // Operands operands[<arraysize>];
750 // };
751 //
752 // layout (set=0, binding=1, std430) buffer OutputBlock {
753 // <type> result[<arraysize>];
754 // };
755 //
756 // Where <type> can be int8_t, uint32_t, float, etc. So in the input buffer the operands are "grouped" per operation and can
757 // have several components each and the output buffer contains an array of results, one per trio of input operands.
758
759 std::ostringstream shaderStr;
760 shaderStr << "; SPIR-V\n"
761 << "; Version: 1.5\n"
762 << " OpCapability Shader\n"
763 << "${CAPABILITIES:opt}"
764 << " OpExtension \"SPV_KHR_storage_buffer_storage_class\"\n"
765 << " OpExtension \"SPV_AMD_shader_trinary_minmax\"\n"
766 << "${EXTENSIONS:opt}"
767 << " %std450 = OpExtInstImport \"GLSL.std.450\"\n"
768 << " %trinary = OpExtInstImport \"SPV_AMD_shader_trinary_minmax\"\n"
769 << " OpMemoryModel Logical GLSL450\n"
770 << " OpEntryPoint GLCompute %main \"main\" %gl_GlobalInvocationID "
771 "%output_buffer %input_buffer\n"
772 << " OpExecutionMode %main LocalSize 1 1 1\n"
773 << " OpDecorate %gl_GlobalInvocationID BuiltIn GlobalInvocationId\n"
774 << " OpDecorate %results_array_t ArrayStride ${OPERAND_SIZE}\n"
775 << " OpMemberDecorate %OutputBlock 0 Offset 0\n"
776 << " OpDecorate %OutputBlock Block\n"
777 << " OpDecorate %output_buffer DescriptorSet 0\n"
778 << " OpDecorate %output_buffer Binding 1\n"
779 << " OpMemberDecorate %Operands 0 Offset 0\n"
780 << " OpMemberDecorate %Operands 1 Offset ${OPERAND_SIZE}\n"
781 << " OpMemberDecorate %Operands 2 Offset ${OPERAND_SIZE_2TIMES}\n"
782 << " OpDecorate %_arr_Operands_arraysize ArrayStride ${OPERAND_SIZE_3TIMES}\n"
783 << " OpMemberDecorate %InputBlock 0 Offset 0\n"
784 << " OpDecorate %InputBlock Block\n"
785 << " OpDecorate %input_buffer DescriptorSet 0\n"
786 << " OpDecorate %input_buffer Binding 0\n"
787 << " OpDecorate %gl_WorkGroupSize BuiltIn WorkgroupSize\n"
788 << " %void = OpTypeVoid\n"
789 << " %voidfunc = OpTypeFunction %void\n"
790 << " %int = OpTypeInt 32 1\n"
791 << " %uint = OpTypeInt 32 0\n"
792 << " %v3uint = OpTypeVector %uint 3\n"
793 << "${TYPES:opt}"
794 << " %int_0 = OpConstant %int 0\n"
795 << " %int_1 = OpConstant %int 1\n"
796 << " %int_2 = OpConstant %int 2\n"
797 << " %uint_1 = OpConstant %uint 1\n"
798 << " %uint_0 = OpConstant %uint 0\n"
799 << " %arraysize = OpConstant %uint ${ARRAY_SIZE}\n"
800 << " %_ptr_Function_uint = OpTypePointer Function %uint\n"
801 << " %_ptr_Input_v3uint = OpTypePointer Input %v3uint\n"
802 << " %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input\n"
803 << " %_ptr_Input_uint = OpTypePointer Input %uint\n"
804 << " %results_array_t = OpTypeArray ${OPERAND_TYPE} %arraysize\n"
805 << " %Operands = OpTypeStruct ${OPERAND_TYPE} ${OPERAND_TYPE} ${OPERAND_TYPE}\n"
806 << " %_arr_Operands_arraysize = OpTypeArray %Operands %arraysize\n"
807 << " %OutputBlock = OpTypeStruct %results_array_t\n"
808 << " %InputBlock = OpTypeStruct %_arr_Operands_arraysize\n"
809 << "%_ptr_Uniform_OutputBlock = OpTypePointer StorageBuffer %OutputBlock\n"
810 << " %_ptr_Uniform_InputBlock = OpTypePointer StorageBuffer %InputBlock\n"
811 << " %output_buffer = OpVariable %_ptr_Uniform_OutputBlock StorageBuffer\n"
812 << " %input_buffer = OpVariable %_ptr_Uniform_InputBlock StorageBuffer\n"
813 << " %optype_ptr = OpTypePointer StorageBuffer ${OPERAND_TYPE}\n"
814 << " %gl_WorkGroupSize = OpConstantComposite %v3uint %uint_1 %uint_1 %uint_1\n"
815 << " %main = OpFunction %void None %voidfunc\n"
816 << " %mainlabel = OpLabel\n"
817 << " %gidxptr = OpAccessChain %_ptr_Input_uint %gl_GlobalInvocationID %uint_0\n"
818 << " %idx = OpLoad %uint %gidxptr\n"
819 << " %op1ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_0\n"
820 << " %op1 = OpLoad ${OPERAND_TYPE} %op1ptr\n"
821 << " %op2ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_1\n"
822 << " %op2 = OpLoad ${OPERAND_TYPE} %op2ptr\n"
823 << " %op3ptr = OpAccessChain %optype_ptr %input_buffer %int_0 %idx %int_2\n"
824 << " %op3 = OpLoad ${OPERAND_TYPE} %op3ptr\n"
825 << " %result = OpExtInst ${OPERAND_TYPE} %trinary ${OPERATION_NAME} %op1 %op2 %op3\n"
826 << " %resultptr = OpAccessChain %optype_ptr %output_buffer %int_0 %idx\n"
827 << " OpStore %resultptr %result\n"
828 << " OpReturn\n"
829 << " OpFunctionEnd\n";
830
831 const tcu::StringTemplate shaderTemplate{shaderStr.str()};
832 const vk::SpirVAsmBuildOptions buildOptions{VK_MAKE_API_VERSION(0, 1, 2, 0), vk::SPIRV_VERSION_1_5};
833
834 programCollection.spirvAsmSources.add("comp", &buildOptions) << shaderTemplate.specialize(getSpirVReplacements());
835 }
836
TrinaryMinMaxInstance(Context & context,const TestParams & params)837 TrinaryMinMaxInstance::TrinaryMinMaxInstance(Context &context, const TestParams ¶ms)
838 : vkt::TestInstance(context)
839 , m_params(params)
840 {
841 }
842
iterate(void)843 tcu::TestStatus TrinaryMinMaxInstance::iterate(void)
844 {
845 const auto &vkd = m_context.getDeviceInterface();
846 const auto device = m_context.getDevice();
847 auto &allocator = m_context.getDefaultAllocator();
848 const auto queue = m_context.getUniversalQueue();
849 const auto queueIndex = m_context.getUniversalQueueFamilyIndex();
850
851 constexpr auto kNumOperations = TrinaryMinMaxCase::kArraySize;
852
853 const vk::VkDeviceSize kInputBufferSize =
854 static_cast<vk::VkDeviceSize>(kNumOperations * 3u * m_params.operandSize());
855 const vk::VkDeviceSize kOutputBufferSize =
856 static_cast<vk::VkDeviceSize>(kNumOperations * m_params.operandSize()); // Single output per operation.
857
858 // Create input, output and reference buffers.
859 auto inputBufferInfo = vk::makeBufferCreateInfo(kInputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
860 auto outputBufferInfo = vk::makeBufferCreateInfo(kOutputBufferSize, vk::VK_BUFFER_USAGE_STORAGE_BUFFER_BIT);
861
862 vk::BufferWithMemory inputBuffer{vkd, device, allocator, inputBufferInfo, vk::MemoryRequirement::HostVisible};
863 vk::BufferWithMemory outputBuffer{vkd, device, allocator, outputBufferInfo, vk::MemoryRequirement::HostVisible};
864 std::unique_ptr<char[]> referenceBuffer{new char[static_cast<size_t>(kOutputBufferSize)]};
865
866 // Fill buffers with initial contents.
867 auto &inputAlloc = inputBuffer.getAllocation();
868 auto &outputAlloc = outputBuffer.getAllocation();
869
870 void *inputBufferPtr = static_cast<uint8_t *>(inputAlloc.getHostPtr()) + inputAlloc.getOffset();
871 void *outputBufferPtr = static_cast<uint8_t *>(outputAlloc.getHostPtr()) + outputAlloc.getOffset();
872 void *referenceBufferPtr = referenceBuffer.get();
873
874 deMemset(inputBufferPtr, 0, static_cast<size_t>(kInputBufferSize));
875 deMemset(outputBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
876 deMemset(referenceBufferPtr, 0, static_cast<size_t>(kOutputBufferSize));
877
878 // Generate input buffer and calculate reference results.
879 OperationManager opMan{m_params};
880 opMan.genInputBuffer(inputBufferPtr, kNumOperations);
881 opMan.calculateResult(referenceBufferPtr, inputBufferPtr, kNumOperations);
882
883 // Flush buffer memory before starting.
884 vk::flushAlloc(vkd, device, inputAlloc);
885 vk::flushAlloc(vkd, device, outputAlloc);
886
887 // Descriptor set layout.
888 vk::DescriptorSetLayoutBuilder layoutBuilder;
889 layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
890 layoutBuilder.addSingleBinding(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, vk::VK_SHADER_STAGE_COMPUTE_BIT);
891 auto descriptorSetLayout = layoutBuilder.build(vkd, device);
892
893 // Descriptor pool.
894 vk::DescriptorPoolBuilder poolBuilder;
895 poolBuilder.addType(vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 2u);
896 auto descriptorPool = poolBuilder.build(vkd, device, vk::VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, 1u);
897
898 // Descriptor set.
899 const auto descriptorSet = vk::makeDescriptorSet(vkd, device, descriptorPool.get(), descriptorSetLayout.get());
900
901 // Update descriptor set using the buffers.
902 const auto inputBufferDescriptorInfo = vk::makeDescriptorBufferInfo(inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
903 const auto outputBufferDescriptorInfo = vk::makeDescriptorBufferInfo(outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
904
905 vk::DescriptorSetUpdateBuilder updateBuilder;
906 updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(0u),
907 vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &inputBufferDescriptorInfo);
908 updateBuilder.writeSingle(descriptorSet.get(), vk::DescriptorSetUpdateBuilder::Location::binding(1u),
909 vk::VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, &outputBufferDescriptorInfo);
910 updateBuilder.update(vkd, device);
911
912 // Create compute pipeline.
913 auto shaderModule = vk::createShaderModule(vkd, device, m_context.getBinaryCollection().get("comp"), 0u);
914 auto pipelineLayout = vk::makePipelineLayout(vkd, device, descriptorSetLayout.get());
915
916 const vk::VkComputePipelineCreateInfo pipelineCreateInfo = {
917 vk::VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
918 nullptr,
919 0u, // flags
920 {
921 // compute shader
922 vk::VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // VkStructureType sType;
923 nullptr, // const void* pNext;
924 0u, // VkPipelineShaderStageCreateFlags flags;
925 vk::VK_SHADER_STAGE_COMPUTE_BIT, // VkShaderStageFlagBits stage;
926 shaderModule.get(), // VkShaderModule module;
927 "main", // const char* pName;
928 nullptr, // const VkSpecializationInfo* pSpecializationInfo;
929 },
930 pipelineLayout.get(), // layout
931 DE_NULL, // basePipelineHandle
932 0, // basePipelineIndex
933 };
934 auto pipeline = vk::createComputePipeline(vkd, device, DE_NULL, &pipelineCreateInfo);
935
936 // Synchronization barriers.
937 auto inputBufferHostToDevBarrier = vk::makeBufferMemoryBarrier(
938 vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_READ_BIT, inputBuffer.get(), 0ull, VK_WHOLE_SIZE);
939 auto outputBufferHostToDevBarrier = vk::makeBufferMemoryBarrier(
940 vk::VK_ACCESS_HOST_WRITE_BIT, vk::VK_ACCESS_SHADER_WRITE_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
941 auto outputBufferDevToHostBarrier = vk::makeBufferMemoryBarrier(
942 vk::VK_ACCESS_SHADER_WRITE_BIT, vk::VK_ACCESS_HOST_READ_BIT, outputBuffer.get(), 0ull, VK_WHOLE_SIZE);
943
944 // Command buffer.
945 auto cmdPool = vk::makeCommandPool(vkd, device, queueIndex);
946 auto cmdBufferPtr = vk::allocateCommandBuffer(vkd, device, cmdPool.get(), vk::VK_COMMAND_BUFFER_LEVEL_PRIMARY);
947 auto cmdBuffer = cmdBufferPtr.get();
948
949 // Record and submit commands.
950 vk::beginCommandBuffer(vkd, cmdBuffer);
951 vkd.cmdBindPipeline(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipeline.get());
952 vkd.cmdBindDescriptorSets(cmdBuffer, vk::VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout.get(), 0, 1u,
953 &descriptorSet.get(), 0u, nullptr);
954 vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u,
955 nullptr, 1u, &inputBufferHostToDevBarrier, 0u, nullptr);
956 vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_HOST_BIT, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0u, 0u,
957 nullptr, 1u, &outputBufferHostToDevBarrier, 0u, nullptr);
958 vkd.cmdDispatch(cmdBuffer, kNumOperations, 1u, 1u);
959 vkd.cmdPipelineBarrier(cmdBuffer, vk::VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, vk::VK_PIPELINE_STAGE_HOST_BIT, 0u, 0u,
960 nullptr, 1u, &outputBufferDevToHostBarrier, 0u, nullptr);
961 vk::endCommandBuffer(vkd, cmdBuffer);
962 vk::submitCommandsAndWait(vkd, device, queue, cmdBuffer);
963
964 // Verify output buffer contents.
965 vk::invalidateAlloc(vkd, device, outputAlloc);
966
967 const auto error = opMan.compareResults(referenceBufferPtr, outputBufferPtr, kNumOperations);
968
969 if (!error)
970 return tcu::TestStatus::pass("Pass");
971
972 std::ostringstream msg;
973 msg << "Value mismatch at operation " << error.get().first << " in component " << error.get().second;
974 return tcu::TestStatus::fail(msg.str());
975 }
976
977 } // namespace
978
createTrinaryMinMaxGroup(tcu::TestContext & testCtx)979 tcu::TestCaseGroup *createTrinaryMinMaxGroup(tcu::TestContext &testCtx)
980 {
981 uint32_t seed = 0xFEE768FCu;
982 de::MovePtr<tcu::TestCaseGroup> group{new tcu::TestCaseGroup{testCtx, "amd_trinary_minmax"}};
983
984 static const std::vector<std::pair<OperationType, std::string>> operationTypes = {
985 {OperationType::MIN, "min3"},
986 {OperationType::MAX, "max3"},
987 {OperationType::MID, "mid3"},
988 };
989
990 static const std::vector<std::pair<BaseType, std::string>> baseTypes = {
991 {BaseType::TYPE_INT, "i"},
992 {BaseType::TYPE_UINT, "u"},
993 {BaseType::TYPE_FLOAT, "f"},
994 };
995
996 static const std::vector<std::pair<TypeSize, std::string>> typeSizes = {
997 {TypeSize::SIZE_8BIT, "8"},
998 {TypeSize::SIZE_16BIT, "16"},
999 {TypeSize::SIZE_32BIT, "32"},
1000 {TypeSize::SIZE_64BIT, "64"},
1001 };
1002
1003 static const std::vector<std::pair<AggregationType, std::string>> aggregationTypes = {
1004 {AggregationType::SCALAR, "scalar"},
1005 {AggregationType::VEC2, "vec2"},
1006 {AggregationType::VEC3, "vec3"},
1007 {AggregationType::VEC4, "vec4"},
1008 };
1009
1010 for (const auto &opType : operationTypes)
1011 {
1012 de::MovePtr<tcu::TestCaseGroup> opGroup{new tcu::TestCaseGroup{testCtx, opType.second.c_str()}};
1013
1014 for (const auto &baseType : baseTypes)
1015 for (const auto &typeSize : typeSizes)
1016 {
1017 // There are no 8-bit floats.
1018 if (baseType.first == BaseType::TYPE_FLOAT && typeSize.first == TypeSize::SIZE_8BIT)
1019 continue;
1020
1021 const std::string typeName = baseType.second + typeSize.second;
1022
1023 de::MovePtr<tcu::TestCaseGroup> typeGroup{new tcu::TestCaseGroup{testCtx, typeName.c_str()}};
1024
1025 for (const auto &aggType : aggregationTypes)
1026 {
1027 const TestParams params = {
1028 opType.first, // OperationType operation;
1029 baseType.first, // BaseType baseType;
1030 typeSize.first, // TypeSize typeSize;
1031 aggType.first, // AggregationType aggregation;
1032 seed++, // uint32_t randomSeed;
1033 };
1034 typeGroup->addChild(new TrinaryMinMaxCase{testCtx, aggType.second, params});
1035 }
1036
1037 opGroup->addChild(typeGroup.release());
1038 }
1039
1040 group->addChild(opGroup.release());
1041 }
1042
1043 return group.release();
1044 }
1045
1046 } // namespace SpirVAssembly
1047 } // namespace vkt
1048