1 //
2 // Copyright (c) 2022 The Khronos Group Inc.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "subhelpers.h"
18
19 #include <random>
20
// Generates operator<< for the cl_ vector type Ty##VecSize. The elements are
// read from the .s array member and written as a ", "-separated list.
#define OP_OSTREAM(Ty, VecSize)                                                \
    std::ostream& operator<<(std::ostream& os, const Ty##VecSize& val)         \
    {                                                                          \
        for (unsigned i = 0; i < VecSize; i++)                                 \
        {                                                                      \
            if (i != 0) os << ", ";                                            \
            /* unary plus makes char-sized elements print as numbers */        \
            os << +val.s[i];                                                   \
        }                                                                      \
        return os;                                                             \
    }
32
// Generates operator<< for the wrapper type Ty##VecSize (used below with
// subgroups::cl_ types). The wrapper's .data member is streamed, so printing
// is delegated to whichever operator<< matches the type of .data.
#define OP_OSTREAM_SUBGROUP(Ty, VecSize)                                       \
    std::ostream& operator<<(std::ostream& os, const Ty##VecSize& val)         \
    {                                                                          \
        os << val.data;                                                        \
        return os;                                                             \
    }
40
// Define operator<< for all vector sizes of element type Ty: the 2-, 4-, 8-
// and 16-wide cl_ types via OP_OSTREAM, followed by the 3-wide subgroups::
// wrapper via OP_OSTREAM_SUBGROUP.
// NOTE(review): the subgroups:: entry is kept last because its generated body
// streams .data and so presumably relies on one of the cl_ overloads defined
// just before it - confirm against the type of subgroups::Ty##3::data.
#define OP_OSTREAM_ALL_VEC(Ty) \
    OP_OSTREAM(Ty, 2) \
    OP_OSTREAM(Ty, 4) \
    OP_OSTREAM(Ty, 8) \
    OP_OSTREAM(Ty, 16) \
    OP_OSTREAM_SUBGROUP(subgroups::Ty, 3)
48
49 OP_OSTREAM_ALL_VEC(cl_char)
OP_OSTREAM_ALL_VEC(cl_uchar)50 OP_OSTREAM_ALL_VEC(cl_uchar)
51 OP_OSTREAM_ALL_VEC(cl_short)
52 OP_OSTREAM_ALL_VEC(cl_ushort)
53 OP_OSTREAM_ALL_VEC(cl_int)
54 OP_OSTREAM_ALL_VEC(cl_uint)
55 OP_OSTREAM_ALL_VEC(cl_long)
56 OP_OSTREAM_ALL_VEC(cl_ulong)
57 OP_OSTREAM_ALL_VEC(cl_float)
58 OP_OSTREAM_ALL_VEC(cl_double)
59 OP_OSTREAM_ALL_VEC(cl_half)
60 OP_OSTREAM_SUBGROUP(subgroups::cl_half, )
61 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 2)
62 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 4)
63 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 8)
64 OP_OSTREAM_SUBGROUP(subgroups::cl_half, 16)
65
66 bs128 cl_uint4_to_bs128(cl_uint4 v)
67 {
68 return bs128(v.s0) | (bs128(v.s1) << 32) | (bs128(v.s2) << 64)
69 | (bs128(v.s3) << 96);
70 }
71
bs128_to_cl_uint4(bs128 v)72 cl_uint4 bs128_to_cl_uint4(bs128 v)
73 {
74 bs128 bs128_ffffffff = 0xffffffffU;
75
76 cl_uint4 r;
77 r.s0 = ((v >> 0) & bs128_ffffffff).to_ulong();
78 r.s1 = ((v >> 32) & bs128_ffffffff).to_ulong();
79 r.s2 = ((v >> 64) & bs128_ffffffff).to_ulong();
80 r.s3 = ((v >> 96) & bs128_ffffffff).to_ulong();
81
82 return r;
83 }
84
generate_bit_mask(cl_uint subgroup_local_id,const std::string & mask_type,cl_uint max_sub_group_size)85 cl_uint4 generate_bit_mask(cl_uint subgroup_local_id,
86 const std::string &mask_type,
87 cl_uint max_sub_group_size)
88 {
89 bs128 mask128;
90 cl_uint4 mask;
91 cl_uint pos = subgroup_local_id;
92 if (mask_type == "eq") mask128.set(pos);
93 if (mask_type == "le" || mask_type == "lt")
94 {
95 for (cl_uint i = 0; i <= pos; i++) mask128.set(i);
96 if (mask_type == "lt") mask128.reset(pos);
97 }
98 if (mask_type == "ge" || mask_type == "gt")
99 {
100 for (cl_uint i = pos; i < max_sub_group_size; i++) mask128.set(i);
101 if (mask_type == "gt") mask128.reset(pos);
102 }
103
104 // convert std::bitset<128> to uint4
105 auto const uint_mask = bs128{ static_cast<unsigned long>(-1) };
106 mask.s0 = (mask128 & uint_mask).to_ulong();
107 mask128 >>= 32;
108 mask.s1 = (mask128 & uint_mask).to_ulong();
109 mask128 >>= 32;
110 mask.s2 = (mask128 & uint_mask).to_ulong();
111 mask128 >>= 32;
112 mask.s3 = (mask128 & uint_mask).to_ulong();
113
114 return mask;
115 }
116
operation_names(ArithmeticOp operation)117 const char *const operation_names(ArithmeticOp operation)
118 {
119 switch (operation)
120 {
121 case ArithmeticOp::add_: return "add";
122 case ArithmeticOp::max_: return "max";
123 case ArithmeticOp::min_: return "min";
124 case ArithmeticOp::mul_: return "mul";
125 case ArithmeticOp::and_: return "and";
126 case ArithmeticOp::or_: return "or";
127 case ArithmeticOp::xor_: return "xor";
128 case ArithmeticOp::logical_and: return "logical_and";
129 case ArithmeticOp::logical_or: return "logical_or";
130 case ArithmeticOp::logical_xor: return "logical_xor";
131 default: log_error("Unknown operation request\n"); break;
132 }
133 return "";
134 }
135
operation_names(BallotOp operation)136 const char *const operation_names(BallotOp operation)
137 {
138 switch (operation)
139 {
140 case BallotOp::ballot: return "ballot";
141 case BallotOp::inverse_ballot: return "inverse_ballot";
142 case BallotOp::ballot_bit_extract: return "bit_extract";
143 case BallotOp::ballot_bit_count: return "bit_count";
144 case BallotOp::ballot_inclusive_scan: return "inclusive_scan";
145 case BallotOp::ballot_exclusive_scan: return "exclusive_scan";
146 case BallotOp::ballot_find_lsb: return "find_lsb";
147 case BallotOp::ballot_find_msb: return "find_msb";
148 case BallotOp::eq_mask: return "eq";
149 case BallotOp::ge_mask: return "ge";
150 case BallotOp::gt_mask: return "gt";
151 case BallotOp::le_mask: return "le";
152 case BallotOp::lt_mask: return "lt";
153 default: log_error("Unknown operation request\n"); break;
154 }
155 return "";
156 }
157
operation_names(ShuffleOp operation)158 const char *const operation_names(ShuffleOp operation)
159 {
160 switch (operation)
161 {
162 case ShuffleOp::shuffle: return "shuffle";
163 case ShuffleOp::shuffle_up: return "shuffle_up";
164 case ShuffleOp::shuffle_down: return "shuffle_down";
165 case ShuffleOp::shuffle_xor: return "shuffle_xor";
166 case ShuffleOp::rotate: return "rotate";
167 case ShuffleOp::clustered_rotate: return "clustered_rotate";
168 default: log_error("Unknown operation request\n"); break;
169 }
170 return "";
171 }
172
operation_names(NonUniformVoteOp operation)173 const char *const operation_names(NonUniformVoteOp operation)
174 {
175 switch (operation)
176 {
177 case NonUniformVoteOp::all: return "all";
178 case NonUniformVoteOp::all_equal: return "all_equal";
179 case NonUniformVoteOp::any: return "any";
180 case NonUniformVoteOp::elect: return "elect";
181 default: log_error("Unknown operation request\n"); break;
182 }
183 return "";
184 }
185
operation_names(SubgroupsBroadcastOp operation)186 const char *const operation_names(SubgroupsBroadcastOp operation)
187 {
188 switch (operation)
189 {
190 case SubgroupsBroadcastOp::broadcast: return "broadcast";
191 case SubgroupsBroadcastOp::broadcast_first: return "broadcast_first";
192 case SubgroupsBroadcastOp::non_uniform_broadcast:
193 return "non_uniform_broadcast";
194 default: log_error("Unknown operation request\n"); break;
195 }
196 return "";
197 }
198
// Computes the layout of a work-group that holds non_uniform_size work-items
// split into subgroups of subgroup_size items each.
//
// Outputs:
//   number_of_subgroups - count of subgroups needed to hold all items; a
//                         trailing partial subgroup is counted only when one
//                         actually exists.
//   workgroup_size      - set to non_uniform_size.
//   last_subgroup_size  - item count of the trailing partial subgroup, or 0
//                         when non_uniform_size is an exact multiple of
//                         subgroup_size (every subgroup is full).
//
// subgroup_size must be non-zero; it is used as a divisor.
void set_last_workgroup_params(int non_uniform_size, int &number_of_subgroups,
                               int subgroup_size, int &workgroup_size,
                               int &last_subgroup_size)
{
    const int full_subgroups = non_uniform_size / subgroup_size;
    const int remainder = non_uniform_size % subgroup_size;

    // Unconditionally adding 1 would report a phantom, empty subgroup
    // whenever the size divides evenly, so only count the partial subgroup
    // when there is a remainder.
    number_of_subgroups = full_subgroups + (remainder != 0 ? 1 : 0);
    last_subgroup_size = remainder;
    workgroup_size = non_uniform_size;
}
207
fill_and_shuffle_safe_values(std::vector<cl_ulong> & safe_values,size_t sb_size)208 void fill_and_shuffle_safe_values(std::vector<cl_ulong> &safe_values,
209 size_t sb_size)
210 {
211 // max product is 720, cl_half has enough precision for it
212 const std::vector<cl_ulong> non_one_values{ 2, 3, 4, 5, 6 };
213
214 if (sb_size <= non_one_values.size())
215 {
216 safe_values.assign(non_one_values.begin(),
217 non_one_values.begin() + sb_size);
218 }
219 else
220 {
221 safe_values.assign(sb_size, 1);
222 std::copy(non_one_values.begin(), non_one_values.end(),
223 safe_values.begin());
224 }
225
226 std::mt19937 mersenne_twister_engine(10000);
227 std::shuffle(safe_values.begin(), safe_values.end(),
228 mersenne_twister_engine);
229 }
230