xref: /aosp_15_r20/external/mesa3d/src/compiler/spirv/vtn_opencl.c (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright © 2018 Red Hat
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * Authors:
24  *    Rob Clark ([email protected])
25  */
26 
27 #include "math.h"
28 #include "nir/nir_builtin_builder.h"
29 
30 #include "util/u_printf.h"
31 #include "vtn_private.h"
32 #include "OpenCL.std.h"
33 
/* Signature of per-opcode handler callbacks used by handle_instr():
 * given the opcode, its NIR source defs and their vtn types, return the
 * result def, or NULL for ops that produce no value. */
typedef nir_def *(*nir_handler)(struct vtn_builder *b,
                                    uint32_t opcode,
                                    unsigned num_srcs, nir_def **srcs,
                                    struct vtn_type **src_types,
                                    const struct vtn_type *dest_type);
/* Map a SPIR-V storage class to the numeric address space used in the
 * Itanium-style "U3AS<n>" mangling qualifier (LLVM/SPIR convention).
 * Returns -1 for storage classes with no address-space equivalent. */
static int to_llvm_address_space(SpvStorageClass mode)
{
   if (mode == SpvStorageClassPrivate || mode == SpvStorageClassFunction)
      return 0;
   if (mode == SpvStorageClassCrossWorkgroup)
      return 1;
   if (mode == SpvStorageClassUniform || mode == SpvStorageClassUniformConstant)
      return 2;
   if (mode == SpvStorageClassWorkgroup)
      return 3;
   if (mode == SpvStorageClassGeneric)
      return 4;
   return -1;
}
53 
54 
55 static void
vtn_opencl_mangle(const char * in_name,uint32_t const_mask,int ntypes,struct vtn_type ** src_types,char ** outstring)56 vtn_opencl_mangle(const char *in_name,
57                   uint32_t const_mask,
58                   int ntypes, struct vtn_type **src_types,
59                   char **outstring)
60 {
61    char local_name[256] = "";
62    char *args_str = local_name + sprintf(local_name, "_Z%zu%s", strlen(in_name), in_name);
63 
64    for (unsigned i = 0; i < ntypes; ++i) {
65       const struct glsl_type *type = src_types[i]->type;
66       enum vtn_base_type base_type = src_types[i]->base_type;
67       if (src_types[i]->base_type == vtn_base_type_pointer) {
68          *(args_str++) = 'P';
69          int address_space = to_llvm_address_space(src_types[i]->storage_class);
70          if (address_space > 0)
71             args_str += sprintf(args_str, "U3AS%d", address_space);
72 
73          type = src_types[i]->pointed->type;
74          base_type = src_types[i]->pointed->base_type;
75       }
76 
77       if (const_mask & (1 << i))
78          *(args_str++) = 'K';
79 
80       unsigned num_elements = glsl_get_components(type);
81       if (num_elements > 1) {
82          /* Vectors are not treated as built-ins for mangling, so check for substitution.
83           * In theory, we'd need to know which substitution value this is. In practice,
84           * the functions we need from libclc only support 1
85           */
86          bool substitution = false;
87          for (unsigned j = 0; j < i; ++j) {
88             const struct glsl_type *other_type = src_types[j]->base_type == vtn_base_type_pointer ?
89                src_types[j]->pointed->type : src_types[j]->type;
90             if (type == other_type) {
91                substitution = true;
92                break;
93             }
94          }
95 
96          if (substitution) {
97             args_str += sprintf(args_str, "S_");
98             continue;
99          } else
100             args_str += sprintf(args_str, "Dv%d_", num_elements);
101       }
102 
103       const char *suffix = NULL;
104       switch (base_type) {
105       case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
106       case vtn_base_type_event: suffix = "9ocl_event"; break;
107       default: {
108          const char *primitives[] = {
109             [GLSL_TYPE_UINT] = "j",
110             [GLSL_TYPE_INT] = "i",
111             [GLSL_TYPE_FLOAT] = "f",
112             [GLSL_TYPE_FLOAT16] = "Dh",
113             [GLSL_TYPE_DOUBLE] = "d",
114             [GLSL_TYPE_UINT8] = "h",
115             [GLSL_TYPE_INT8] = "c",
116             [GLSL_TYPE_UINT16] = "t",
117             [GLSL_TYPE_INT16] = "s",
118             [GLSL_TYPE_UINT64] = "m",
119             [GLSL_TYPE_INT64] = "l",
120             [GLSL_TYPE_BOOL] = "b",
121             [GLSL_TYPE_ERROR] = NULL,
122          };
123          enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
124          assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
125          suffix = primitives[glsl_base_type];
126          break;
127       }
128       }
129       args_str += sprintf(args_str, "%s", suffix);
130    }
131 
132    *outstring = strdup(local_name);
133 }
134 
/* Mangle @name for the given argument types and look the function up,
 * first in the current shader, then in the libclc shader (creating a
 * parameter-compatible local declaration in the latter case).  Fails the
 * SPIR-V parse if no match is found. */
static nir_function *mangle_and_find(struct vtn_builder *b,
                                     const char *name,
                                     uint32_t const_mask,
                                     uint32_t num_srcs,
                                     struct vtn_type **src_types)
{
   char *mname;

   vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);

   /* try and find in current shader first. */
   nir_function *found = nir_shader_get_function_for_name(b->shader, mname);

   /* if not found here find in clc shader and create a decl mirroring it */
   if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
      found = nir_shader_get_function_for_name(b->options->clc_shader, mname);
      if (found) {
         /* Create a declaration in this shader so call instructions can
          * reference it; the params are copied from the clc function. */
         nir_function *decl = nir_function_create(b->shader, mname);
         decl->num_params = found->num_params;
         decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
         for (unsigned i = 0; i < decl->num_params; i++) {
            decl->params[i] = found->params[i];
         }
         found = decl;
      }
   }
   if (!found)
      vtn_fail("Can't find clc function %s\n", mname);
   /* NOTE(review): if vtn_fail() unwinds non-locally, mname leaks on the
    * failure path -- presumably acceptable since compilation aborts. */
   free(mname);
   return found;
}
166 
/* Emit a NIR call to the mangled libclc function @name.
 *
 * If dest_type is non-NULL, a local temporary is created and its deref is
 * passed as the first (return-value) parameter, per the libclc calling
 * convention; *ret_deref_ptr then points at it (NULL otherwise).
 * Returns false only if the function lookup fails (mangle_and_find()
 * raises vtn_fail() in that case, so callers normally see true). */
static bool call_mangled_function(struct vtn_builder *b,
                                  const char *name,
                                  uint32_t const_mask,
                                  uint32_t num_srcs,
                                  struct vtn_type **src_types,
                                  const struct vtn_type *dest_type,
                                  nir_def **srcs,
                                  nir_deref_instr **ret_deref_ptr)
{
   nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types);
   if (!found)
      return false;

   nir_call_instr *call = nir_call_instr_create(b->shader, found);

   nir_deref_instr *ret_deref = NULL;
   uint32_t param_idx = 0;
   if (dest_type) {
      /* Return value goes through a by-reference out parameter. */
      nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl,
                                                        glsl_get_bare_type(dest_type->type),
                                                        "return_tmp");
      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->def);
   }

   /* Actual arguments follow the (optional) return slot. */
   for (unsigned i = 0; i < num_srcs; i++)
      call->params[param_idx++] = nir_src_for_ssa(srcs[i]);
   nir_builder_instr_insert(&b->nb, &call->instr);

   *ret_deref_ptr = ret_deref;
   return true;
}
199 
/* Common dispatch helper: gather the SSA defs and vtn types of the
 * num_srcs operand ids in w_src, invoke @handler, and push the result
 * (if any) as the SSA value of the destination pair w_dest
 * (w_dest[0] = result type id, w_dest[1] = result id). */
static void
handle_instr(struct vtn_builder *b, uint32_t opcode,
             const uint32_t *w_src, unsigned num_srcs, const uint32_t *w_dest, nir_handler handler)
{
   struct vtn_type *dest_type = w_dest ? vtn_get_type(b, w_dest[0]) : NULL;

   nir_def *srcs[5] = { NULL };
   struct vtn_type *src_types[5] = { NULL };
   vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
   for (unsigned i = 0; i < num_srcs; i++) {
      struct vtn_value *val = vtn_untyped_value(b, w_src[i]);
      struct vtn_ssa_value *ssa = vtn_ssa_value(b, w_src[i]);
      srcs[i] = ssa->def;
      src_types[i] = val->type;
   }

   nir_def *result = handler(b, opcode, num_srcs, srcs, src_types, dest_type);
   if (result) {
      vtn_push_nir_ssa(b, w_dest[1], result);
   } else {
      /* A handler returning nothing is only valid for void ops. */
      vtn_assert(dest_type == NULL);
   }
}
223 
/* Map an OpenCL.std extended-instruction opcode onto the single NIR ALU
 * op that implements it; vtn_fail()s for opcodes with no 1:1 mapping
 * (those are handled elsewhere, e.g. via libclc). */
static nir_op
nir_alu_op_for_opencl_opcode(struct vtn_builder *b,
                             enum OpenCLstd_Entrypoints opcode)
{
   switch (opcode) {
   case OpenCLstd_Fabs: return nir_op_fabs;
   case OpenCLstd_SAbs: return nir_op_iabs;
   case OpenCLstd_SAdd_sat: return nir_op_iadd_sat;
   case OpenCLstd_UAdd_sat: return nir_op_uadd_sat;
   case OpenCLstd_Ceil: return nir_op_fceil;
   case OpenCLstd_Floor: return nir_op_ffloor;
   case OpenCLstd_SHadd: return nir_op_ihadd;
   case OpenCLstd_UHadd: return nir_op_uhadd;
   case OpenCLstd_Fmax: return nir_op_fmax;
   case OpenCLstd_SMax: return nir_op_imax;
   case OpenCLstd_UMax: return nir_op_umax;
   case OpenCLstd_Fmin: return nir_op_fmin;
   case OpenCLstd_SMin: return nir_op_imin;
   case OpenCLstd_UMin: return nir_op_umin;
   case OpenCLstd_Mix: return nir_op_flrp;
   case OpenCLstd_Native_cos: return nir_op_fcos;
   case OpenCLstd_Native_divide: return nir_op_fdiv;
   case OpenCLstd_Native_exp2: return nir_op_fexp2;
   case OpenCLstd_Native_log2: return nir_op_flog2;
   case OpenCLstd_Native_powr: return nir_op_fpow;
   case OpenCLstd_Native_recip: return nir_op_frcp;
   case OpenCLstd_Native_rsqrt: return nir_op_frsq;
   case OpenCLstd_Native_sin: return nir_op_fsin;
   case OpenCLstd_Native_sqrt: return nir_op_fsqrt;
   case OpenCLstd_SMul_hi: return nir_op_imul_high;
   case OpenCLstd_UMul_hi: return nir_op_umul_high;
   case OpenCLstd_Popcount: return nir_op_bit_count;
   case OpenCLstd_SRhadd: return nir_op_irhadd;
   case OpenCLstd_URhadd: return nir_op_urhadd;
   case OpenCLstd_Rsqrt: return nir_op_frsq;
   case OpenCLstd_Sign: return nir_op_fsign;
   case OpenCLstd_Sqrt: return nir_op_fsqrt;
   case OpenCLstd_SSub_sat: return nir_op_isub_sat;
   case OpenCLstd_USub_sat: return nir_op_usub_sat;
   case OpenCLstd_Trunc: return nir_op_ftrunc;
   case OpenCLstd_Rint: return nir_op_fround_even;
   case OpenCLstd_Half_divide: return nir_op_fdiv;
   case OpenCLstd_Half_recip: return nir_op_frcp;
   /* uhm... UAbs of an unsigned value is the identity. */
   case OpenCLstd_UAbs: return nir_op_mov;
   default:
      vtn_fail("No NIR equivalent");
   }
}
273 
/* Handler for OpenCL ops that map directly onto one NIR ALU op.  Unused
 * trailing sources are passed as NULL (extras are ignored for ops taking
 * fewer sources -- presumably; see nir_build_alu).  The bit_count result
 * is converted back to the destination bit size. */
static nir_def *
handle_alu(struct vtn_builder *b, uint32_t opcode,
           unsigned num_srcs, nir_def **srcs, struct vtn_type **src_types,
           const struct vtn_type *dest_type)
{
   nir_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, (enum OpenCLstd_Entrypoints)opcode),
                                    srcs[0], srcs[1], srcs[2], NULL);
   if (opcode == OpenCLstd_Popcount)
      ret = nir_u2uN(&b->nb, ret, glsl_get_bit_size(dest_type->type));
   return ret;
}
285 
/* Sparse table indexed by OpenCLstd_Entrypoints value, mapping each
 * opcode to the libclc function name implementing it.  Unlisted opcodes
 * get a NULL fn and fall through to other handlers. */
#define REMAP(op, str) [OpenCLstd_##op] = { str }
static const struct {
   const char *fn;
} remap_table[] = {
   REMAP(Distance, "distance"),
   REMAP(Fast_distance, "fast_distance"),
   REMAP(Fast_length, "fast_length"),
   REMAP(Fast_normalize, "fast_normalize"),
   REMAP(Half_rsqrt, "half_rsqrt"),
   REMAP(Half_sqrt, "half_sqrt"),
   REMAP(Length, "length"),
   REMAP(Normalize, "normalize"),
   REMAP(Degrees, "degrees"),
   REMAP(Radians, "radians"),
   REMAP(Rotate, "rotate"),
   REMAP(Smoothstep, "smoothstep"),
   REMAP(Step, "step"),

   REMAP(Pow, "pow"),
   REMAP(Pown, "pown"),
   REMAP(Powr, "powr"),
   REMAP(Rootn, "rootn"),
   REMAP(Modf, "modf"),

   REMAP(Acos, "acos"),
   REMAP(Acosh, "acosh"),
   REMAP(Acospi, "acospi"),
   REMAP(Asin, "asin"),
   REMAP(Asinh, "asinh"),
   REMAP(Asinpi, "asinpi"),
   REMAP(Atan, "atan"),
   REMAP(Atan2, "atan2"),
   REMAP(Atanh, "atanh"),
   REMAP(Atanpi, "atanpi"),
   REMAP(Atan2pi, "atan2pi"),
   REMAP(Cos, "cos"),
   REMAP(Cosh, "cosh"),
   REMAP(Cospi, "cospi"),
   REMAP(Sin, "sin"),
   REMAP(Sinh, "sinh"),
   REMAP(Sinpi, "sinpi"),
   REMAP(Tan, "tan"),
   REMAP(Tanh, "tanh"),
   REMAP(Tanpi, "tanpi"),
   REMAP(Sincos, "sincos"),
   REMAP(Fract, "fract"),
   REMAP(Frexp, "frexp"),
   REMAP(Fma, "fma"),
   REMAP(Fmod, "fmod"),

   /* The half_* functions allow reduced precision, so the full-precision
    * libclc implementations are a valid (conservative) choice. */
   REMAP(Half_cos, "cos"),
   REMAP(Half_exp, "exp"),
   REMAP(Half_exp2, "exp2"),
   REMAP(Half_exp10, "exp10"),
   REMAP(Half_log, "log"),
   REMAP(Half_log2, "log2"),
   REMAP(Half_log10, "log10"),
   REMAP(Half_powr, "powr"),
   REMAP(Half_sin, "sin"),
   REMAP(Half_tan, "tan"),

   REMAP(Remainder, "remainder"),
   REMAP(Remquo, "remquo"),
   REMAP(Hypot, "hypot"),
   REMAP(Exp, "exp"),
   REMAP(Exp2, "exp2"),
   REMAP(Exp10, "exp10"),
   REMAP(Expm1, "expm1"),
   REMAP(Ldexp, "ldexp"),

   REMAP(Ilogb, "ilogb"),
   REMAP(Log, "log"),
   REMAP(Log2, "log2"),
   REMAP(Log10, "log10"),
   REMAP(Log1p, "log1p"),
   REMAP(Logb, "logb"),

   REMAP(Cbrt, "cbrt"),
   REMAP(Erfc, "erfc"),
   REMAP(Erf, "erf"),

   REMAP(Lgamma, "lgamma"),
   REMAP(Lgamma_r, "lgamma_r"),
   REMAP(Tgamma, "tgamma"),

   REMAP(UMad_sat, "mad_sat"),
   REMAP(SMad_sat, "mad_sat"),

   REMAP(Shuffle, "shuffle"),
   REMAP(Shuffle2, "shuffle2"),
};
#undef REMAP
378 
remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)379 static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
380 {
381    if (opcode >= (sizeof(remap_table) / sizeof(const char *)))
382       return NULL;
383    return remap_table[opcode].fn;
384 }
385 
386 static struct vtn_type *
get_vtn_type_for_glsl_type(struct vtn_builder * b,const struct glsl_type * type)387 get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
388 {
389    struct vtn_type *ret = vtn_zalloc(b, struct vtn_type);
390    assert(glsl_type_is_vector_or_scalar(type));
391    ret->type = type;
392    ret->length = glsl_get_vector_elements(type);
393    ret->base_type = glsl_type_is_vector(type) ? vtn_base_type_vector : vtn_base_type_scalar;
394    return ret;
395 }
396 
397 static struct vtn_type *
get_pointer_type(struct vtn_builder * b,struct vtn_type * t,SpvStorageClass storage_class)398 get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
399 {
400    struct vtn_type *ret = vtn_zalloc(b, struct vtn_type);
401    ret->type = nir_address_format_to_glsl_type(
402             vtn_mode_to_address_format(
403                b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));
404    ret->base_type = vtn_base_type_pointer;
405    ret->storage_class = storage_class;
406    ret->pointed = t;
407    return ret;
408 }
409 
410 static struct vtn_type *
get_signed_type(struct vtn_builder * b,struct vtn_type * t)411 get_signed_type(struct vtn_builder *b, struct vtn_type *t)
412 {
413    if (t->base_type == vtn_base_type_pointer) {
414       return get_pointer_type(b, get_signed_type(b, t->pointed), t->storage_class);
415    }
416    return get_vtn_type_for_glsl_type(
417       b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
418                           glsl_get_vector_elements(t->type)));
419 }
420 
/* Fall back to a libclc implementation for @opcode.  Returns NULL if
 * there is no libclc remap entry; otherwise emits the call and returns
 * the loaded return value (or NULL for void functions). */
static nir_def *
handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              int num_srcs,
              nir_def **srcs,
              struct vtn_type **src_types,
              const struct vtn_type *dest_type)
{
   const char *name = remap_clc_opcode(opcode);
   if (!name)
       return NULL;

   /* Some functions which take params end up with uint (or pointer-to-uint) being passed,
    * which doesn't mangle correctly when the function expects int or pointer-to-int.
    * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
    */
   int signed_param = -1;
   switch (opcode) {
   case OpenCLstd_Frexp:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Ldexp:
      /* Second parameter is a (pointer-to-)int in the CLC prototype. */
      signed_param = 1;
      break;
   case OpenCLstd_Remquo:
      signed_param = 2;
      break;
   case OpenCLstd_SMad_sat: {
      /* All parameters need to be converted to signed */
      src_types[0] = src_types[1] = src_types[2] = get_signed_type(b, src_types[0]);
      break;
   }
   default: break;
   }

   if (signed_param >= 0) {
      src_types[signed_param] = get_signed_type(b, src_types[signed_param]);
   }

   nir_deref_instr *ret_deref = NULL;

   if (!call_mangled_function(b, name, 0, num_srcs, src_types,
                              dest_type, srcs, &ret_deref))
      return NULL;

   return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
}
468 
/* Handle OpenCL.std opcodes that expand to a short NIR sequence rather
 * than a single ALU op.  Opcodes not matched here (or that break out of
 * the switch for lowering) fall through to the libclc path via
 * handle_clc_fn(); vtn_fail()s if neither can handle the opcode. */
static nir_def *
handle_special(struct vtn_builder *b, uint32_t opcode,
               unsigned num_srcs, nir_def **srcs, struct vtn_type **src_types,
               const struct vtn_type *dest_type)
{
   nir_builder *nb = &b->nb;
   enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints)opcode;

   switch (cl_opcode) {
   case OpenCLstd_SAbs_diff:
     /* these work easier in direct NIR */
      return nir_iabs_diff(nb, srcs[0], srcs[1]);
   case OpenCLstd_UAbs_diff:
      return nir_uabs_diff(nb, srcs[0], srcs[1]);
   case OpenCLstd_Bitselect:
      return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_SMad_hi:
      return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_UMad_hi:
      return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_SMul24:
      return nir_imul24_relaxed(nb, srcs[0], srcs[1]);
   case OpenCLstd_UMul24:
      return nir_umul24_relaxed(nb, srcs[0], srcs[1]);
   case OpenCLstd_SMad24:
      return nir_iadd(nb, nir_imul24_relaxed(nb, srcs[0], srcs[1]), srcs[2]);
   case OpenCLstd_UMad24:
      return nir_umad24_relaxed(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_FClamp:
      return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_SClamp:
      return nir_iclamp(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_UClamp:
      return nir_uclamp(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_Copysign:
      return nir_copysign(nb, srcs[0], srcs[1]);
   case OpenCLstd_Cross:
      /* 4-component cross zeroes the w channel. */
      if (dest_type->length == 4)
         return nir_cross4(nb, srcs[0], srcs[1]);
      return nir_cross3(nb, srcs[0], srcs[1]);
   case OpenCLstd_Fdim:
      return nir_fdim(nb, srcs[0], srcs[1]);
   case OpenCLstd_Mad: {
      /* The spec says mad is
       *
       *    Implemented either as a correctly rounded fma or as a multiply
       *    followed by an add both of which are correctly rounded
       *
       * So lower to fmul+fadd if we have to, but fuse to an ffma if the backend
       * supports that. This can be significantly faster.
       */
      bool lower =
         ((nb->shader->options->lower_ffma16 && srcs[0]->bit_size == 16) ||
          (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32) ||
          (nb->shader->options->lower_ffma64 && srcs[0]->bit_size == 64));

      if (lower)
         return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
      else
         return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
   }
   case OpenCLstd_Maxmag:
      return nir_maxmag(nb, srcs[0], srcs[1]);
   case OpenCLstd_Minmag:
      return nir_minmag(nb, srcs[0], srcs[1]);
   case OpenCLstd_Nan:
      return nir_nan(nb, srcs[0]);
   case OpenCLstd_Nextafter:
      return nir_nextafter(nb, srcs[0], srcs[1]);
   case OpenCLstd_Normalize:
      return nir_normalize(nb, srcs[0]);
   case OpenCLstd_Clz:
      return nir_clz_u(nb, srcs[0]);
   case OpenCLstd_Ctz:
      return nir_ctz_u(nb, srcs[0]);
   case OpenCLstd_Select:
      return nir_select(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
      /* SPIR-V and CL have different defs for upsample, just implement in nir */
      return nir_upsample(nb, srcs[0], srcs[1]);
   case OpenCLstd_Native_exp:
      return nir_fexp(nb, srcs[0]);
   case OpenCLstd_Native_exp10:
      /* exp10(x) = exp2(x * log2(10)) */
      return nir_fexp2(nb, nir_fmul_imm(nb, srcs[0], log(10) / log(2)));
   case OpenCLstd_Native_log:
      return nir_flog(nb, srcs[0]);
   case OpenCLstd_Native_log10:
      /* log10(x) = log2(x) * log10(2) */
      return nir_fmul_imm(nb, nir_flog2(nb, srcs[0]), log(2) / log(10));
   case OpenCLstd_Native_tan:
      return nir_ftan(nb, srcs[0]);
   case OpenCLstd_Ldexp:
      if (nb->shader->options->lower_ldexp)
         break; /* fall through to the libclc implementation */
      return nir_ldexp(nb, srcs[0], srcs[1]);
   case OpenCLstd_Fma:
      /* FIXME: the software implementation only supports fp32 for now. */
      if (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32)
         break; /* fall through to the libclc implementation */
      return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_Rotate:
      return nir_urol(nb, srcs[0], nir_u2u32(nb, srcs[1]));
   default:
      break;
   }

   nir_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
   if (!ret)
      vtn_fail("No NIR equivalent");

   return ret;
}
581 
/* Handle the core SPIR-V opcodes routed through the OpenCL path (group
 * async copies and event waits).  Returns the loaded return value, or
 * NULL when the op produces none or is not handled here. */
static nir_def *
handle_core(struct vtn_builder *b, uint32_t opcode,
            unsigned num_srcs, nir_def **srcs, struct vtn_type **src_types,
            const struct vtn_type *dest_type)
{
   nir_deref_instr *ret_deref = NULL;

   switch ((SpvOp)opcode) {
   case SpvOpGroupAsyncCopy: {
      /* Libclc doesn't include 3-component overloads of the async copy functions.
       * However, the CLC spec says:
       * async_work_group_copy and async_work_group_strided_copy for 3-component vector types
       * behave as async_work_group_copy and async_work_group_strided_copy respectively for 4-component
       * vector types
       */
      for (unsigned i = 0; i < num_srcs; ++i) {
         if (src_types[i]->base_type == vtn_base_type_pointer &&
             src_types[i]->pointed->base_type == vtn_base_type_vector &&
             src_types[i]->pointed->length == 3) {
            /* Retype vec3 pointers as vec4 pointers for mangling. */
            src_types[i] =
               get_pointer_type(b,
                                get_vtn_type_for_glsl_type(b, glsl_replace_vector_type(src_types[i]->pointed->type, 4)),
                                src_types[i]->storage_class);
         }
      }
      /* const_mask bit 1 marks the source pointer as const. */
      if (!call_mangled_function(b, "async_work_group_strided_copy", (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref))
         return NULL;
      break;
   }
   case SpvOpGroupWaitEvents: {
      /* libclc and clang don't agree on the mangling of this function.
       * The libclc we have uses a __local pointer but clang gives us generic
       * pointers.  Fortunately, the whole function is just a barrier.
       */
      nir_barrier(&b->nb, .execution_scope = SCOPE_WORKGROUP,
                          .memory_scope = SCOPE_WORKGROUP,
                          .memory_semantics = NIR_MEMORY_ACQUIRE |
                                              NIR_MEMORY_RELEASE,
                          .memory_modes = nir_var_mem_shared |
                                          nir_var_mem_global);
      break;
   }
   default:
      return NULL;
   }

   return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
}
630 
631 
632 static void
_handle_v_load_store(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count,bool load,bool vec_aligned,nir_rounding_mode rounding)633 _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
634                      const uint32_t *w, unsigned count, bool load,
635                      bool vec_aligned, nir_rounding_mode rounding)
636 {
637    struct vtn_type *type;
638    if (load)
639       type = vtn_get_type(b, w[1]);
640    else
641       type = vtn_get_value_type(b, w[5]);
642    unsigned a = load ? 0 : 1;
643 
644    enum glsl_base_type base_type = glsl_get_base_type(type->type);
645    unsigned components = glsl_get_vector_elements(type->type);
646 
647    nir_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
648    struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);
649 
650    struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
651    nir_def *ncomps[NIR_MAX_VEC_COMPONENTS];
652 
653    nir_def *moffset = nir_imul_imm(&b->nb, offset,
654       (vec_aligned && components == 3) ? 4 : components);
655    nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);
656 
657    unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) :
658                                       glsl_get_bit_size(type->type) / 8;
659    enum glsl_base_type ptr_base_type =
660       glsl_get_base_type(p->pointer->type->pointed->type);
661    if (base_type != ptr_base_type) {
662       vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 ||
663                   (base_type != GLSL_TYPE_FLOAT &&
664                    base_type != GLSL_TYPE_DOUBLE),
665                   "vload/vstore cannot do type conversion. "
666                   "vload/vstore_half can only convert from half to other "
667                   "floating-point types.");
668 
669       /* Above-computed alignment was for floats/doubles, not halves */
670       alignment /= glsl_get_bit_size(type->type) / glsl_base_type_get_bit_size(ptr_base_type);
671    }
672 
673    deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0);
674 
675    for (int i = 0; i < components; i++) {
676       nir_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
677       nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);
678 
679       if (load) {
680          comps[i] = vtn_local_load(b, arr_deref, p->type->access);
681          ncomps[i] = comps[i]->def;
682          if (base_type != ptr_base_type) {
683             assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
684                    (base_type == GLSL_TYPE_FLOAT ||
685                     base_type == GLSL_TYPE_DOUBLE));
686             ncomps[i] = nir_f2fN(&b->nb, ncomps[i],
687                                  glsl_base_type_get_bit_size(base_type));
688          }
689       } else {
690          struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
691          struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
692          ssa->def = nir_channel(&b->nb, val->def, i);
693          if (base_type != ptr_base_type) {
694             assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
695                    (base_type == GLSL_TYPE_FLOAT ||
696                     base_type == GLSL_TYPE_DOUBLE));
697             if (rounding == nir_rounding_mode_undef) {
698                ssa->def = nir_f2f16(&b->nb, ssa->def);
699             } else {
700                ssa->def = nir_convert_alu_types(&b->nb, 16, ssa->def,
701                                                 nir_type_float | ssa->def->bit_size,
702                                                 nir_type_float16,
703                                                 rounding, false);
704             }
705          }
706          vtn_local_store(b, ssa, arr_deref, p->type->access);
707       }
708    }
709    if (load) {
710       vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components));
711    }
712 }
713 
/* Entry point for the vloadn / vload_halfn / vloada_halfn opcodes. */
static void
vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                        const uint32_t *w, unsigned count)
{
   _handle_v_load_store(b, opcode, w, count, true,
                        opcode == OpenCLstd_Vloada_halfn,
                        nir_rounding_mode_undef);
}
722 
/* Entry point for the vstoren / vstore_halfn / vstorea_halfn opcodes
 * (default rounding). */
static void
vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                         const uint32_t *w, unsigned count)
{
   _handle_v_load_store(b, opcode, w, count, false,
                        opcode == OpenCLstd_Vstorea_halfn,
                        nir_rounding_mode_undef);
}
731 
/* Entry point for the vstore_half*_r opcodes, which carry an explicit
 * rounding-mode operand in w[8]. */
static void
vtn_handle_opencl_vstore_half_r(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                                const uint32_t *w, unsigned count)
{
   _handle_v_load_store(b, opcode, w, count, false,
                        opcode == OpenCLstd_Vstorea_halfn_r,
                        vtn_rounding_mode_to_nir(b, w[8]));
}
740 
/* Resolve the printf format-string argument @id -- which must be a
 * pointer into a constant-initialized char array -- append its bytes to
 * info->strings, and return the byte offset at which it was stored. */
static unsigned
vtn_add_printf_string(struct vtn_builder *b, uint32_t id, u_printf_info *info)
{
   nir_deref_instr *deref = vtn_nir_deref(b, id);

   /* Walk the deref chain back up to the underlying variable. */
   while (deref->deref_type != nir_deref_type_var) {
      nir_scalar parent = nir_scalar_resolved(deref->parent.ssa, 0);
      if (parent.def->parent_instr->type != nir_instr_type_deref) {
         deref = NULL;
         break;
      }
      vtn_assert(parent.comp == 0);
      deref = nir_instr_as_deref(parent.def->parent_instr);
   }

   vtn_fail_if(deref == NULL || !nir_deref_mode_is(deref, nir_var_mem_constant),
               "Printf string argument must be a pointer to a constant variable");
   vtn_fail_if(deref->var->constant_initializer == NULL,
               "Printf string argument must have an initializer");
   vtn_fail_if(!glsl_type_is_array(deref->var->type),
               "Printf string must be an char array");
   const struct glsl_type *char_type = glsl_get_array_element(deref->var->type);
   vtn_fail_if(char_type != glsl_uint8_t_type() &&
               char_type != glsl_int8_t_type(),
               "Printf string must be an char array");

   nir_constant *c = deref->var->constant_initializer;
   assert(c->num_elements == glsl_get_length(deref->var->type));

   /* Append the string's bytes to the shader-wide string blob. */
   unsigned idx = info->string_size;
   info->strings = reralloc_size(b->shader, info->strings,
                                 idx + c->num_elements);
   info->string_size += c->num_elements;

   char *str = &info->strings[idx];
   bool found_null = false;
   for (unsigned i = 0; i < c->num_elements; i++) {
      /* Each array element holds one char; copy its low byte. */
      memcpy((char *)str + i, c->elements[i]->values, 1);
      if (str[i] == '\0')
         found_null = true;
   }
   vtn_fail_if(!found_null, "Printf string must be null terminated");
   return idx;
}
785 
/* printf is special because there are no limits on args */
static void
handle_printf(struct vtn_builder *b, uint32_t opcode,
              const uint32_t *w_src, unsigned num_srcs, const uint32_t *w_dest)
{
   /* Without driver support for printf, return -1 (printf's error value)
    * and drop the call entirely. */
   if (!b->options->printf) {
      vtn_push_nir_ssa(b, w_dest[1], nir_imm_int(&b->nb, -1));
      return;
   }

   /* Step 1. extract the format string */

   /*
    * info_idx is 1-based to match clover/llvm
    * the backend indexes the info table at info_idx - 1.
    */
   b->shader->printf_info_count++;
   unsigned info_idx = b->shader->printf_info_count;

   b->shader->printf_info = reralloc(b->shader, b->shader->printf_info,
                                     u_printf_info, info_idx);
   u_printf_info *info = &b->shader->printf_info[info_idx - 1];

   info->strings = NULL;
   info->string_size = 0;

   /* The first source operand is the format string itself. */
   vtn_add_printf_string(b, w_src[0], info);

   info->num_args = num_srcs - 1;
   info->arg_sizes = ralloc_array(b->shader, unsigned, info->num_args);

   /* Step 2, build an ad-hoc struct type out of the args */
   unsigned field_offset = 0;
   struct glsl_struct_field *fields =
      rzalloc_array(b, struct glsl_struct_field, num_srcs - 1);
   for (unsigned i = 1; i < num_srcs; ++i) {
      struct vtn_value *val = vtn_untyped_value(b, w_src[i]);
      struct vtn_type *src_type = val->type;
      fields[i - 1].type = src_type->type;
      fields[i - 1].name = ralloc_asprintf(b->shader, "arg_%u", i);
      /* Each argument is placed at the next 4-byte-aligned offset. */
      field_offset = align(field_offset, 4);
      fields[i - 1].offset = field_offset;
      info->arg_sizes[i - 1] = glsl_get_cl_size(src_type->type);
      field_offset += glsl_get_cl_size(src_type->type);
   }
   const struct glsl_type *struct_type =
      glsl_struct_type(fields, num_srcs - 1, "printf", true);

   /* Step 3, create a variable of that type and populate its fields */
   nir_variable *var = nir_local_variable_create(b->nb.impl, struct_type, NULL);
   nir_deref_instr *deref_var = nir_build_deref_var(&b->nb, var);
   size_t fmt_pos = 0;
   for (unsigned i = 1; i < num_srcs; ++i) {
      nir_deref_instr *field_deref =
         nir_build_deref_struct(&b->nb, deref_var, i - 1);
      nir_def *field_src = vtn_ssa_value(b, w_src[i])->def;
      /* extract strings */
      /* NOTE: fmt_pos is size_t; the != -1 comparison works because -1
       * converts to (size_t)-1, the "no more specifiers" return value. */
      fmt_pos = util_printf_next_spec_pos(info->strings, fmt_pos);
      if (fmt_pos != -1 && info->strings[fmt_pos] == 's') {
         /* %s arguments are stored in the string table too; the struct
          * field then holds the table offset instead of a pointer. */
         unsigned idx = vtn_add_printf_string(b, w_src[i], info);
         nir_store_deref(&b->nb, field_deref,
                         nir_imm_intN_t(&b->nb, idx, field_src->bit_size),
                         ~0 /* write_mask */);
      } else
         nir_store_deref(&b->nb, field_deref, field_src, ~0);
   }

   /* Lastly, the actual intrinsic */
   nir_def *fmt_idx = nir_imm_int(&b->nb, info_idx);
   nir_def *ret = nir_printf(&b->nb, fmt_idx, &deref_var->def);
   vtn_push_nir_ssa(b, w_dest[1], ret);

   b->nb.shader->info.uses_printf = true;
}
860 
861 static nir_def *
handle_round(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)862 handle_round(struct vtn_builder *b, uint32_t opcode,
863              unsigned num_srcs, nir_def **srcs, struct vtn_type **src_types,
864              const struct vtn_type *dest_type)
865 {
866    nir_def *src = srcs[0];
867    nir_builder *nb = &b->nb;
868    nir_def *half = nir_imm_floatN_t(nb, 0.5, src->bit_size);
869    nir_def *truncated = nir_ftrunc(nb, src);
870    nir_def *remainder = nir_fsub(nb, src, truncated);
871 
872    return nir_bcsel(nb, nir_fge(nb, nir_fabs(nb, remainder), half),
873                     nir_fadd(nb, truncated, nir_fsign(nb, src)), truncated);
874 }
875 
876 static nir_def *
handle_shuffle(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)877 handle_shuffle(struct vtn_builder *b, uint32_t opcode,
878                unsigned num_srcs, nir_def **srcs, struct vtn_type **src_types,
879                const struct vtn_type *dest_type)
880 {
881    struct nir_def *input = srcs[0];
882    struct nir_def *mask = srcs[1];
883 
884    unsigned out_elems = dest_type->length;
885    nir_def *outres[NIR_MAX_VEC_COMPONENTS];
886    unsigned in_elems = input->num_components;
887    if (mask->bit_size != 32)
888       mask = nir_u2u32(&b->nb, mask);
889    mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, in_elems - 1, mask->bit_size));
890    for (unsigned i = 0; i < out_elems; i++)
891       outres[i] = nir_vector_extract(&b->nb, input, nir_channel(&b->nb, mask, i));
892 
893    return nir_vec(&b->nb, outres, out_elems);
894 }
895 
896 static nir_def *
handle_shuffle2(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)897 handle_shuffle2(struct vtn_builder *b, uint32_t opcode,
898                 unsigned num_srcs, nir_def **srcs, struct vtn_type **src_types,
899                 const struct vtn_type *dest_type)
900 {
901    struct nir_def *input0 = srcs[0];
902    struct nir_def *input1 = srcs[1];
903    struct nir_def *mask = srcs[2];
904 
905    unsigned out_elems = dest_type->length;
906    nir_def *outres[NIR_MAX_VEC_COMPONENTS];
907    unsigned in_elems = input0->num_components;
908    unsigned total_mask = 2 * in_elems - 1;
909    unsigned half_mask = in_elems - 1;
910    if (mask->bit_size != 32)
911       mask = nir_u2u32(&b->nb, mask);
912    mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, total_mask, mask->bit_size));
913    for (unsigned i = 0; i < out_elems; i++) {
914       nir_def *this_mask = nir_channel(&b->nb, mask, i);
915       nir_def *vmask = nir_iand(&b->nb, this_mask, nir_imm_intN_t(&b->nb, half_mask, mask->bit_size));
916       nir_def *val0 = nir_vector_extract(&b->nb, input0, vmask);
917       nir_def *val1 = nir_vector_extract(&b->nb, input1, vmask);
918       nir_def *sel = nir_ilt_imm(&b->nb, this_mask, in_elems);
919       outres[i] = nir_bcsel(&b->nb, sel, val0, val1);
920    }
921    return nir_vec(&b->nb, outres, out_elems);
922 }
923 
/* Dispatch an OpenCL.std extended instruction to the appropriate handler.
 * 'w' points at the whole OpExtInst; operands start at w + 5 (w[1] is the
 * result type id, w[2] the result id).  Returns true if handled, fails
 * translation on unknown opcodes.
 */
bool
vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                              const uint32_t *w, unsigned count)
{
   enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints) ext_opcode;

   switch (cl_opcode) {
   /* Opcodes with a direct NIR ALU equivalent. */
   case OpenCLstd_Fabs:
   case OpenCLstd_SAbs:
   case OpenCLstd_UAbs:
   case OpenCLstd_SAdd_sat:
   case OpenCLstd_UAdd_sat:
   case OpenCLstd_Ceil:
   case OpenCLstd_Floor:
   case OpenCLstd_Fmax:
   case OpenCLstd_SHadd:
   case OpenCLstd_UHadd:
   case OpenCLstd_SMax:
   case OpenCLstd_UMax:
   case OpenCLstd_Fmin:
   case OpenCLstd_SMin:
   case OpenCLstd_UMin:
   case OpenCLstd_Mix:
   case OpenCLstd_Native_cos:
   case OpenCLstd_Native_divide:
   case OpenCLstd_Native_exp2:
   case OpenCLstd_Native_log2:
   case OpenCLstd_Native_powr:
   case OpenCLstd_Native_recip:
   case OpenCLstd_Native_rsqrt:
   case OpenCLstd_Native_sin:
   case OpenCLstd_Native_sqrt:
   case OpenCLstd_SMul_hi:
   case OpenCLstd_UMul_hi:
   case OpenCLstd_Popcount:
   case OpenCLstd_SRhadd:
   case OpenCLstd_URhadd:
   case OpenCLstd_Rsqrt:
   case OpenCLstd_Sign:
   case OpenCLstd_Sqrt:
   case OpenCLstd_SSub_sat:
   case OpenCLstd_USub_sat:
   case OpenCLstd_Trunc:
   case OpenCLstd_Rint:
   case OpenCLstd_Half_divide:
   case OpenCLstd_Half_recip:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_alu);
      return true;
   /* Opcodes lowered to NIR instruction sequences or builtin builders. */
   case OpenCLstd_SAbs_diff:
   case OpenCLstd_UAbs_diff:
   case OpenCLstd_SMad_hi:
   case OpenCLstd_UMad_hi:
   case OpenCLstd_SMad24:
   case OpenCLstd_UMad24:
   case OpenCLstd_SMul24:
   case OpenCLstd_UMul24:
   case OpenCLstd_Bitselect:
   case OpenCLstd_FClamp:
   case OpenCLstd_SClamp:
   case OpenCLstd_UClamp:
   case OpenCLstd_Copysign:
   case OpenCLstd_Cross:
   case OpenCLstd_Degrees:
   case OpenCLstd_Fdim:
   case OpenCLstd_Fma:
   case OpenCLstd_Distance:
   case OpenCLstd_Fast_distance:
   case OpenCLstd_Fast_length:
   case OpenCLstd_Fast_normalize:
   case OpenCLstd_Half_rsqrt:
   case OpenCLstd_Half_sqrt:
   case OpenCLstd_Length:
   case OpenCLstd_Mad:
   case OpenCLstd_Maxmag:
   case OpenCLstd_Minmag:
   case OpenCLstd_Nan:
   case OpenCLstd_Nextafter:
   case OpenCLstd_Normalize:
   case OpenCLstd_Radians:
   case OpenCLstd_Rotate:
   case OpenCLstd_Select:
   case OpenCLstd_Step:
   case OpenCLstd_Smoothstep:
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
   case OpenCLstd_Clz:
   case OpenCLstd_Ctz:
   case OpenCLstd_Native_exp:
   case OpenCLstd_Native_exp10:
   case OpenCLstd_Native_log:
   case OpenCLstd_Native_log10:
   case OpenCLstd_Acos:
   case OpenCLstd_Acosh:
   case OpenCLstd_Acospi:
   case OpenCLstd_Asin:
   case OpenCLstd_Asinh:
   case OpenCLstd_Asinpi:
   case OpenCLstd_Atan:
   case OpenCLstd_Atan2:
   case OpenCLstd_Atanh:
   case OpenCLstd_Atanpi:
   case OpenCLstd_Atan2pi:
   case OpenCLstd_Fract:
   case OpenCLstd_Frexp:
   case OpenCLstd_Exp:
   case OpenCLstd_Exp2:
   case OpenCLstd_Expm1:
   case OpenCLstd_Exp10:
   case OpenCLstd_Fmod:
   case OpenCLstd_Ilogb:
   case OpenCLstd_Log:
   case OpenCLstd_Log2:
   case OpenCLstd_Log10:
   case OpenCLstd_Log1p:
   case OpenCLstd_Logb:
   case OpenCLstd_Ldexp:
   case OpenCLstd_Cos:
   case OpenCLstd_Cosh:
   case OpenCLstd_Cospi:
   case OpenCLstd_Sin:
   case OpenCLstd_Sinh:
   case OpenCLstd_Sinpi:
   case OpenCLstd_Tan:
   case OpenCLstd_Tanh:
   case OpenCLstd_Tanpi:
   case OpenCLstd_Cbrt:
   case OpenCLstd_Erfc:
   case OpenCLstd_Erf:
   case OpenCLstd_Lgamma:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Tgamma:
   case OpenCLstd_Pow:
   case OpenCLstd_Powr:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Remainder:
   case OpenCLstd_Remquo:
   case OpenCLstd_Hypot:
   case OpenCLstd_Sincos:
   case OpenCLstd_Modf:
   case OpenCLstd_UMad_sat:
   case OpenCLstd_SMad_sat:
   case OpenCLstd_Native_tan:
   case OpenCLstd_Half_cos:
   case OpenCLstd_Half_exp:
   case OpenCLstd_Half_exp2:
   case OpenCLstd_Half_exp10:
   case OpenCLstd_Half_log:
   case OpenCLstd_Half_log2:
   case OpenCLstd_Half_log10:
   case OpenCLstd_Half_powr:
   case OpenCLstd_Half_sin:
   case OpenCLstd_Half_tan:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special);
      return true;
   /* Vector load/store builtins take the whole instruction word stream. */
   case OpenCLstd_Vloadn:
   case OpenCLstd_Vload_half:
   case OpenCLstd_Vload_halfn:
   case OpenCLstd_Vloada_halfn:
      vtn_handle_opencl_vload(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Vstoren:
   case OpenCLstd_Vstore_half:
   case OpenCLstd_Vstore_halfn:
   case OpenCLstd_Vstorea_halfn:
      vtn_handle_opencl_vstore(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Vstore_half_r:
   case OpenCLstd_Vstore_halfn_r:
   case OpenCLstd_Vstorea_halfn_r:
      vtn_handle_opencl_vstore_half_r(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Shuffle:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle);
      return true;
   case OpenCLstd_Shuffle2:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle2);
      return true;
   case OpenCLstd_Round:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_round);
      return true;
   case OpenCLstd_Printf:
      handle_printf(b, ext_opcode, w + 5, count - 5, w + 1);
      return true;
   case OpenCLstd_Prefetch:
      /* TODO maybe add a nir instruction for this? */
      return true;
   default:
      vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
      return false;
   }
}
1116 
1117 bool
vtn_handle_opencl_core_instruction(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)1118 vtn_handle_opencl_core_instruction(struct vtn_builder *b, SpvOp opcode,
1119                                    const uint32_t *w, unsigned count)
1120 {
1121    switch (opcode) {
1122    case SpvOpGroupAsyncCopy:
1123       handle_instr(b, opcode, w + 4, count - 4, w + 1, handle_core);
1124       return true;
1125    case SpvOpGroupWaitEvents:
1126       handle_instr(b, opcode, w + 2, count - 2, NULL, handle_core);
1127       return true;
1128    default:
1129       return false;
1130    }
1131    return true;
1132 }
1133