//
// Copyright (c) 2022 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
16*6467f958SSadaf Ebrahimi
17*6467f958SSadaf Ebrahimi #include "subhelpers.h"
18*6467f958SSadaf Ebrahimi
19*6467f958SSadaf Ebrahimi #include <random>
20*6467f958SSadaf Ebrahimi
21*6467f958SSadaf Ebrahimi // Define operator<< for cl_ types, accessing the .s member.
22*6467f958SSadaf Ebrahimi #define OP_OSTREAM(Ty, VecSize) \
23*6467f958SSadaf Ebrahimi std::ostream& operator<<(std::ostream& os, const Ty##VecSize& val) \
24*6467f958SSadaf Ebrahimi { \
25*6467f958SSadaf Ebrahimi os << +val.s[0]; /* unary plus forces char to be printed as number */ \
26*6467f958SSadaf Ebrahimi for (unsigned i = 1; i < VecSize; i++) \
27*6467f958SSadaf Ebrahimi { \
28*6467f958SSadaf Ebrahimi os << ", " << +val.s[i]; \
29*6467f958SSadaf Ebrahimi } \
30*6467f958SSadaf Ebrahimi return os; \
31*6467f958SSadaf Ebrahimi }
32*6467f958SSadaf Ebrahimi
33*6467f958SSadaf Ebrahimi // Define operator<< for subgroups::cl_ types, accessing the .data member and
34*6467f958SSadaf Ebrahimi // forwarding to operator<< for the cl_ types.
35*6467f958SSadaf Ebrahimi #define OP_OSTREAM_SUBGROUP(Ty, VecSize) \
36*6467f958SSadaf Ebrahimi std::ostream& operator<<(std::ostream& os, const Ty##VecSize& val) \
37*6467f958SSadaf Ebrahimi { \
38*6467f958SSadaf Ebrahimi return os << val.data; \
39*6467f958SSadaf Ebrahimi }
40*6467f958SSadaf Ebrahimi
41*6467f958SSadaf Ebrahimi // Define operator<< for all vector sizes.
42*6467f958SSadaf Ebrahimi #define OP_OSTREAM_ALL_VEC(Ty) \
43*6467f958SSadaf Ebrahimi OP_OSTREAM(Ty, 2) \
44*6467f958SSadaf Ebrahimi OP_OSTREAM(Ty, 4) \
45*6467f958SSadaf Ebrahimi OP_OSTREAM(Ty, 8) \
46*6467f958SSadaf Ebrahimi OP_OSTREAM(Ty, 16) \
47*6467f958SSadaf Ebrahimi OP_OSTREAM_SUBGROUP(subgroups::Ty, 3)
48*6467f958SSadaf Ebrahimi
49*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_char)
OP_OSTREAM_ALL_VEC(cl_uchar)50*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_uchar)
51*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_short)
52*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_ushort)
53*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_int)
54*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_uint)
55*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_long)
56*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_ulong)
57*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_float)
58*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_double)
59*6467f958SSadaf Ebrahimi OP_OSTREAM_ALL_VEC(cl_half)
60*6467f958SSadaf Ebrahimi OP_OSTREAM_SUBGROUP(subgroups::cl_half, )
61*6467f958SSadaf Ebrahimi OP_OSTREAM_SUBGROUP(subgroups::cl_half, 2)
62*6467f958SSadaf Ebrahimi OP_OSTREAM_SUBGROUP(subgroups::cl_half, 4)
63*6467f958SSadaf Ebrahimi OP_OSTREAM_SUBGROUP(subgroups::cl_half, 8)
64*6467f958SSadaf Ebrahimi OP_OSTREAM_SUBGROUP(subgroups::cl_half, 16)
65*6467f958SSadaf Ebrahimi
66*6467f958SSadaf Ebrahimi bs128 cl_uint4_to_bs128(cl_uint4 v)
67*6467f958SSadaf Ebrahimi {
68*6467f958SSadaf Ebrahimi return bs128(v.s0) | (bs128(v.s1) << 32) | (bs128(v.s2) << 64)
69*6467f958SSadaf Ebrahimi | (bs128(v.s3) << 96);
70*6467f958SSadaf Ebrahimi }
71*6467f958SSadaf Ebrahimi
bs128_to_cl_uint4(bs128 v)72*6467f958SSadaf Ebrahimi cl_uint4 bs128_to_cl_uint4(bs128 v)
73*6467f958SSadaf Ebrahimi {
74*6467f958SSadaf Ebrahimi bs128 bs128_ffffffff = 0xffffffffU;
75*6467f958SSadaf Ebrahimi
76*6467f958SSadaf Ebrahimi cl_uint4 r;
77*6467f958SSadaf Ebrahimi r.s0 = ((v >> 0) & bs128_ffffffff).to_ulong();
78*6467f958SSadaf Ebrahimi r.s1 = ((v >> 32) & bs128_ffffffff).to_ulong();
79*6467f958SSadaf Ebrahimi r.s2 = ((v >> 64) & bs128_ffffffff).to_ulong();
80*6467f958SSadaf Ebrahimi r.s3 = ((v >> 96) & bs128_ffffffff).to_ulong();
81*6467f958SSadaf Ebrahimi
82*6467f958SSadaf Ebrahimi return r;
83*6467f958SSadaf Ebrahimi }
84*6467f958SSadaf Ebrahimi
generate_bit_mask(cl_uint subgroup_local_id,const std::string & mask_type,cl_uint max_sub_group_size)85*6467f958SSadaf Ebrahimi cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
86*6467f958SSadaf Ebrahimi const std::string &mask_type,
87*6467f958SSadaf Ebrahimi cl_uint max_sub_group_size)
88*6467f958SSadaf Ebrahimi {
89*6467f958SSadaf Ebrahimi bs128 mask128;
90*6467f958SSadaf Ebrahimi cl_uint4 mask;
91*6467f958SSadaf Ebrahimi cl_uint pos = subgroup_local_id;
92*6467f958SSadaf Ebrahimi if (mask_type == "eq") mask128.set(pos);
93*6467f958SSadaf Ebrahimi if (mask_type == "le" || mask_type == "lt")
94*6467f958SSadaf Ebrahimi {
95*6467f958SSadaf Ebrahimi for (cl_uint i = 0; i <= pos; i++) mask128.set(i);
96*6467f958SSadaf Ebrahimi if (mask_type == "lt") mask128.reset(pos);
97*6467f958SSadaf Ebrahimi }
98*6467f958SSadaf Ebrahimi if (mask_type == "ge" || mask_type == "gt")
99*6467f958SSadaf Ebrahimi {
100*6467f958SSadaf Ebrahimi for (cl_uint i = pos; i < max_sub_group_size; i++) mask128.set(i);
101*6467f958SSadaf Ebrahimi if (mask_type == "gt") mask128.reset(pos);
102*6467f958SSadaf Ebrahimi }
103*6467f958SSadaf Ebrahimi
104*6467f958SSadaf Ebrahimi // convert std::bitset<128> to uint4
105*6467f958SSadaf Ebrahimi auto const uint_mask = bs128{ static_cast<unsigned long>(-1) };
106*6467f958SSadaf Ebrahimi mask.s0 = (mask128 & uint_mask).to_ulong();
107*6467f958SSadaf Ebrahimi mask128 >>= 32;
108*6467f958SSadaf Ebrahimi mask.s1 = (mask128 & uint_mask).to_ulong();
109*6467f958SSadaf Ebrahimi mask128 >>= 32;
110*6467f958SSadaf Ebrahimi mask.s2 = (mask128 & uint_mask).to_ulong();
111*6467f958SSadaf Ebrahimi mask128 >>= 32;
112*6467f958SSadaf Ebrahimi mask.s3 = (mask128 & uint_mask).to_ulong();
113*6467f958SSadaf Ebrahimi
114*6467f958SSadaf Ebrahimi return mask;
115*6467f958SSadaf Ebrahimi }
116*6467f958SSadaf Ebrahimi
operation_names(ArithmeticOp operation)117*6467f958SSadaf Ebrahimi const char *const operation_names(ArithmeticOp operation)
118*6467f958SSadaf Ebrahimi {
119*6467f958SSadaf Ebrahimi switch (operation)
120*6467f958SSadaf Ebrahimi {
121*6467f958SSadaf Ebrahimi case ArithmeticOp::add_: return "add";
122*6467f958SSadaf Ebrahimi case ArithmeticOp::max_: return "max";
123*6467f958SSadaf Ebrahimi case ArithmeticOp::min_: return "min";
124*6467f958SSadaf Ebrahimi case ArithmeticOp::mul_: return "mul";
125*6467f958SSadaf Ebrahimi case ArithmeticOp::and_: return "and";
126*6467f958SSadaf Ebrahimi case ArithmeticOp::or_: return "or";
127*6467f958SSadaf Ebrahimi case ArithmeticOp::xor_: return "xor";
128*6467f958SSadaf Ebrahimi case ArithmeticOp::logical_and: return "logical_and";
129*6467f958SSadaf Ebrahimi case ArithmeticOp::logical_or: return "logical_or";
130*6467f958SSadaf Ebrahimi case ArithmeticOp::logical_xor: return "logical_xor";
131*6467f958SSadaf Ebrahimi default: log_error("Unknown operation request\n"); break;
132*6467f958SSadaf Ebrahimi }
133*6467f958SSadaf Ebrahimi return "";
134*6467f958SSadaf Ebrahimi }
135*6467f958SSadaf Ebrahimi
operation_names(BallotOp operation)136*6467f958SSadaf Ebrahimi const char *const operation_names(BallotOp operation)
137*6467f958SSadaf Ebrahimi {
138*6467f958SSadaf Ebrahimi switch (operation)
139*6467f958SSadaf Ebrahimi {
140*6467f958SSadaf Ebrahimi case BallotOp::ballot: return "ballot";
141*6467f958SSadaf Ebrahimi case BallotOp::inverse_ballot: return "inverse_ballot";
142*6467f958SSadaf Ebrahimi case BallotOp::ballot_bit_extract: return "bit_extract";
143*6467f958SSadaf Ebrahimi case BallotOp::ballot_bit_count: return "bit_count";
144*6467f958SSadaf Ebrahimi case BallotOp::ballot_inclusive_scan: return "inclusive_scan";
145*6467f958SSadaf Ebrahimi case BallotOp::ballot_exclusive_scan: return "exclusive_scan";
146*6467f958SSadaf Ebrahimi case BallotOp::ballot_find_lsb: return "find_lsb";
147*6467f958SSadaf Ebrahimi case BallotOp::ballot_find_msb: return "find_msb";
148*6467f958SSadaf Ebrahimi case BallotOp::eq_mask: return "eq";
149*6467f958SSadaf Ebrahimi case BallotOp::ge_mask: return "ge";
150*6467f958SSadaf Ebrahimi case BallotOp::gt_mask: return "gt";
151*6467f958SSadaf Ebrahimi case BallotOp::le_mask: return "le";
152*6467f958SSadaf Ebrahimi case BallotOp::lt_mask: return "lt";
153*6467f958SSadaf Ebrahimi default: log_error("Unknown operation request\n"); break;
154*6467f958SSadaf Ebrahimi }
155*6467f958SSadaf Ebrahimi return "";
156*6467f958SSadaf Ebrahimi }
157*6467f958SSadaf Ebrahimi
operation_names(ShuffleOp operation)158*6467f958SSadaf Ebrahimi const char *const operation_names(ShuffleOp operation)
159*6467f958SSadaf Ebrahimi {
160*6467f958SSadaf Ebrahimi switch (operation)
161*6467f958SSadaf Ebrahimi {
162*6467f958SSadaf Ebrahimi case ShuffleOp::shuffle: return "shuffle";
163*6467f958SSadaf Ebrahimi case ShuffleOp::shuffle_up: return "shuffle_up";
164*6467f958SSadaf Ebrahimi case ShuffleOp::shuffle_down: return "shuffle_down";
165*6467f958SSadaf Ebrahimi case ShuffleOp::shuffle_xor: return "shuffle_xor";
166*6467f958SSadaf Ebrahimi case ShuffleOp::rotate: return "rotate";
167*6467f958SSadaf Ebrahimi case ShuffleOp::clustered_rotate: return "clustered_rotate";
168*6467f958SSadaf Ebrahimi default: log_error("Unknown operation request\n"); break;
169*6467f958SSadaf Ebrahimi }
170*6467f958SSadaf Ebrahimi return "";
171*6467f958SSadaf Ebrahimi }
172*6467f958SSadaf Ebrahimi
operation_names(NonUniformVoteOp operation)173*6467f958SSadaf Ebrahimi const char *const operation_names(NonUniformVoteOp operation)
174*6467f958SSadaf Ebrahimi {
175*6467f958SSadaf Ebrahimi switch (operation)
176*6467f958SSadaf Ebrahimi {
177*6467f958SSadaf Ebrahimi case NonUniformVoteOp::all: return "all";
178*6467f958SSadaf Ebrahimi case NonUniformVoteOp::all_equal: return "all_equal";
179*6467f958SSadaf Ebrahimi case NonUniformVoteOp::any: return "any";
180*6467f958SSadaf Ebrahimi case NonUniformVoteOp::elect: return "elect";
181*6467f958SSadaf Ebrahimi default: log_error("Unknown operation request\n"); break;
182*6467f958SSadaf Ebrahimi }
183*6467f958SSadaf Ebrahimi return "";
184*6467f958SSadaf Ebrahimi }
185*6467f958SSadaf Ebrahimi
operation_names(SubgroupsBroadcastOp operation)186*6467f958SSadaf Ebrahimi const char *const operation_names(SubgroupsBroadcastOp operation)
187*6467f958SSadaf Ebrahimi {
188*6467f958SSadaf Ebrahimi switch (operation)
189*6467f958SSadaf Ebrahimi {
190*6467f958SSadaf Ebrahimi case SubgroupsBroadcastOp::broadcast: return "broadcast";
191*6467f958SSadaf Ebrahimi case SubgroupsBroadcastOp::broadcast_first: return "broadcast_first";
192*6467f958SSadaf Ebrahimi case SubgroupsBroadcastOp::non_uniform_broadcast:
193*6467f958SSadaf Ebrahimi return "non_uniform_broadcast";
194*6467f958SSadaf Ebrahimi default: log_error("Unknown operation request\n"); break;
195*6467f958SSadaf Ebrahimi }
196*6467f958SSadaf Ebrahimi return "";
197*6467f958SSadaf Ebrahimi }
198*6467f958SSadaf Ebrahimi
// Derives the sub-group layout of a work-group that holds non_uniform_size
// work items (presumably the last, non-uniform work-group, going by the
// function name - confirm against callers).
//
//   non_uniform_size     [in]  number of work items in the work-group
//   subgroup_size        [in]  sub-group size; used as a divisor, so it
//                              must not be zero
//   number_of_subgroups  [out] non_uniform_size / subgroup_size, rounded up
//   workgroup_size       [out] set to non_uniform_size
//   last_subgroup_size   [out] non_uniform_size % subgroup_size; this is 0
//                              when non_uniform_size is an exact multiple of
//                              subgroup_size
void set_last_workgroup_params(int non_uniform_size, int &number_of_subgroups,
                               int subgroup_size, int &workgroup_size,
                               int &last_subgroup_size)
{
    const int full_subgroups = non_uniform_size / subgroup_size;
    const int leftover_items = non_uniform_size % subgroup_size;

    // Only count an extra sub-group when there really are leftover work
    // items.  Unconditionally adding one reported an additional, empty
    // sub-group whenever non_uniform_size was an exact multiple of
    // subgroup_size.
    number_of_subgroups = full_subgroups + (leftover_items != 0 ? 1 : 0);
    last_subgroup_size = leftover_items;
    workgroup_size = non_uniform_size;
}
207*6467f958SSadaf Ebrahimi
fill_and_shuffle_safe_values(std::vector<cl_ulong> & safe_values,size_t sb_size)208*6467f958SSadaf Ebrahimi void fill_and_shuffle_safe_values(std::vector<cl_ulong> &safe_values,
209*6467f958SSadaf Ebrahimi size_t sb_size)
210*6467f958SSadaf Ebrahimi {
211*6467f958SSadaf Ebrahimi // max product is 720, cl_half has enough precision for it
212*6467f958SSadaf Ebrahimi const std::vector<cl_ulong> non_one_values{ 2, 3, 4, 5, 6 };
213*6467f958SSadaf Ebrahimi
214*6467f958SSadaf Ebrahimi if (sb_size <= non_one_values.size())
215*6467f958SSadaf Ebrahimi {
216*6467f958SSadaf Ebrahimi safe_values.assign(non_one_values.begin(),
217*6467f958SSadaf Ebrahimi non_one_values.begin() + sb_size);
218*6467f958SSadaf Ebrahimi }
219*6467f958SSadaf Ebrahimi else
220*6467f958SSadaf Ebrahimi {
221*6467f958SSadaf Ebrahimi safe_values.assign(sb_size, 1);
222*6467f958SSadaf Ebrahimi std::copy(non_one_values.begin(), non_one_values.end(),
223*6467f958SSadaf Ebrahimi safe_values.begin());
224*6467f958SSadaf Ebrahimi }
225*6467f958SSadaf Ebrahimi
226*6467f958SSadaf Ebrahimi std::mt19937 mersenne_twister_engine(10000);
227*6467f958SSadaf Ebrahimi std::shuffle(safe_values.begin(), safe_values.end(),
228*6467f958SSadaf Ebrahimi mersenne_twister_engine);
229*6467f958SSadaf Ebrahimi }
230