xref: /aosp_15_r20/external/mesa3d/src/compiler/clc/clc_helpers.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 //
2 // Copyright 2012-2016 Francisco Jerez
3 // Copyright 2012-2016 Advanced Micro Devices, Inc.
4 // Copyright 2014-2016 Jan Vesely
5 // Copyright 2014-2015 Serge Martin
6 // Copyright 2015 Zoltan Gilian
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a
9 // copy of this software and associated documentation files (the "Software"),
10 // to deal in the Software without restriction, including without limitation
11 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 // and/or sell copies of the Software, and to permit persons to whom the
13 // Software is furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
21 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 // OTHER DEALINGS IN THE SOFTWARE.
25 
26 #include <cstdlib>
27 #include <filesystem>
28 #include <sstream>
29 #include <mutex>
30 
31 #include <llvm/ADT/ArrayRef.h>
32 #include <llvm/IR/DiagnosticPrinter.h>
33 #include <llvm/IR/DiagnosticInfo.h>
34 #include <llvm/IR/LegacyPassManager.h>
35 #include <llvm/IR/LLVMContext.h>
36 #include <llvm/IR/Type.h>
37 #include <llvm/MC/TargetRegistry.h>
38 #include <llvm/Target/TargetMachine.h>
39 #include <llvm/Support/raw_ostream.h>
40 #include <llvm/Bitcode/BitcodeWriter.h>
41 #include <llvm/Bitcode/BitcodeReader.h>
42 #include <llvm-c/Core.h>
43 #include <llvm-c/Target.h>
44 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
45 
46 #include <clang/Config/config.h>
47 #include <clang/Driver/Driver.h>
48 #include <clang/CodeGen/CodeGenAction.h>
49 #include <clang/Lex/PreprocessorOptions.h>
50 #include <clang/Frontend/CompilerInstance.h>
51 #include <clang/Frontend/TextDiagnosticBuffer.h>
52 #include <clang/Frontend/TextDiagnosticPrinter.h>
53 #include <clang/Basic/TargetInfo.h>
54 
55 #include <spirv-tools/libspirv.hpp>
56 #include <spirv-tools/linker.hpp>
57 #include <spirv-tools/optimizer.hpp>
58 
59 #include "util/macros.h"
60 #include "glsl_types.h"
61 
62 #include "spirv.h"
63 
64 #if DETECT_OS_POSIX
65 #include <dlfcn.h>
66 #endif
67 
68 #ifdef USE_STATIC_OPENCL_C_H
69 #include "opencl-c-base.h.h"
70 #include "opencl-c.h.h"
71 #endif
72 
73 #include "clc_helpers.h"
74 
75 namespace fs = std::filesystem;
76 
77 /* Use the highest version of SPIRV supported by SPIRV-Tools. */
78 constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
79 
80 constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
81 
82 using ::llvm::Function;
83 using ::llvm::legacy::PassManager;
84 using ::llvm::LLVMContext;
85 using ::llvm::Module;
86 using ::llvm::raw_string_ostream;
87 using ::llvm::TargetRegistry;
88 using ::clang::driver::Driver;
89 
90 static void
91 clc_dump_llvm(const llvm::Module *mod, FILE *f);
92 
93 static void
94 #if LLVM_VERSION_MAJOR >= 19
llvm_log_handler(const::llvm::DiagnosticInfo * di,void * data)95 llvm_log_handler(const ::llvm::DiagnosticInfo *di, void *data) {
96 #else
97 llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
98 #endif
99    const clc_logger *logger = static_cast<clc_logger *>(data);
100 
101    std::string log;
102    raw_string_ostream os { log };
103    ::llvm::DiagnosticPrinterRawOStream printer { os };
104 #if LLVM_VERSION_MAJOR >= 19
105    di->print(printer);
106 #else
107    di.print(printer);
108 #endif
109 
110    clc_error(logger, "%s", log.c_str());
111 }
112 
113 class SPIRVKernelArg {
114 public:
115    SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
116                                                   addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
117                                                   accessQualifier(0),
118                                                   typeQualifier(0) { }
119    ~SPIRVKernelArg() { }
120 
121    uint32_t id;
122    uint32_t typeId;
123    std::string name;
124    std::string typeName;
125    enum clc_kernel_arg_address_qualifier addrQualifier;
126    unsigned accessQualifier;
127    unsigned typeQualifier;
128 };
129 
130 class SPIRVKernelInfo {
131 public:
132    SPIRVKernelInfo(uint32_t fid, const char *nm)
133       : funcId(fid), name(nm), vecHint(0), localSize(), localSizeHint() { }
134    ~SPIRVKernelInfo() { }
135 
136    uint32_t funcId;
137    std::string name;
138    std::vector<SPIRVKernelArg> args;
139    unsigned vecHint;
140    unsigned localSize[3];
141    unsigned localSizeHint[3];
142 };
143 
144 class SPIRVKernelParser {
145 public:
146    SPIRVKernelParser() : curKernel(NULL)
147    {
148       ctx = spvContextCreate(spirv_target);
149    }
150 
151    ~SPIRVKernelParser()
152    {
153      spvContextDestroy(ctx);
154    }
155 
156    void parseEntryPoint(const spv_parsed_instruction_t *ins)
157    {
158       assert(ins->num_operands >= 3);
159 
160       const spv_parsed_operand_t *op = &ins->operands[1];
161 
162       assert(op->type == SPV_OPERAND_TYPE_ID);
163 
164       uint32_t funcId = ins->words[op->offset];
165 
166       for (auto &iter : kernels) {
167          if (funcId == iter.funcId)
168             return;
169       }
170 
171       op = &ins->operands[2];
172       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
173       const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
174 
175       kernels.push_back(SPIRVKernelInfo(funcId, name));
176    }
177 
178    void parseFunction(const spv_parsed_instruction_t *ins)
179    {
180       assert(ins->num_operands == 4);
181 
182       const spv_parsed_operand_t *op = &ins->operands[1];
183 
184       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
185 
186       uint32_t funcId = ins->words[op->offset];
187 
188       for (auto &kernel : kernels) {
189          if (funcId == kernel.funcId && !kernel.args.size()) {
190             curKernel = &kernel;
191 	    return;
192          }
193       }
194    }
195 
196    void parseFunctionParam(const spv_parsed_instruction_t *ins)
197    {
198       const spv_parsed_operand_t *op;
199       uint32_t id, typeId;
200 
201       if (!curKernel)
202          return;
203 
204       assert(ins->num_operands == 2);
205       op = &ins->operands[0];
206       assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
207       typeId = ins->words[op->offset];
208       op = &ins->operands[1];
209       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
210       id = ins->words[op->offset];
211       curKernel->args.push_back(SPIRVKernelArg(id, typeId));
212    }
213 
214    void parseName(const spv_parsed_instruction_t *ins)
215    {
216       const spv_parsed_operand_t *op;
217       const char *name;
218       uint32_t id;
219 
220       assert(ins->num_operands == 2);
221 
222       op = &ins->operands[0];
223       assert(op->type == SPV_OPERAND_TYPE_ID);
224       id = ins->words[op->offset];
225       op = &ins->operands[1];
226       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
227       name = reinterpret_cast<const char *>(ins->words + op->offset);
228 
229       for (auto &kernel : kernels) {
230          for (auto &arg : kernel.args) {
231             if (arg.id == id && arg.name.empty()) {
232               arg.name = name;
233               break;
234 	    }
235          }
236       }
237    }
238 
239    void parseTypePointer(const spv_parsed_instruction_t *ins)
240    {
241       enum clc_kernel_arg_address_qualifier addrQualifier;
242       uint32_t typeId, storageClass;
243       const spv_parsed_operand_t *op;
244 
245       assert(ins->num_operands == 3);
246 
247       op = &ins->operands[0];
248       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
249       typeId = ins->words[op->offset];
250 
251       op = &ins->operands[1];
252       assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
253       storageClass = ins->words[op->offset];
254       switch (storageClass) {
255       case SpvStorageClassCrossWorkgroup:
256          addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
257          break;
258       case SpvStorageClassWorkgroup:
259          addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
260          break;
261       case SpvStorageClassUniformConstant:
262          addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
263          break;
264       default:
265          addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
266          break;
267       }
268 
269       for (auto &kernel : kernels) {
270 	 for (auto &arg : kernel.args) {
271             if (arg.typeId == typeId) {
272                arg.addrQualifier = addrQualifier;
273                if (addrQualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT)
274                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
275             }
276          }
277       }
278    }
279 
280    void parseOpString(const spv_parsed_instruction_t *ins)
281    {
282       const spv_parsed_operand_t *op;
283       std::string str;
284 
285       assert(ins->num_operands == 2);
286 
287       op = &ins->operands[1];
288       assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
289       str = reinterpret_cast<const char *>(ins->words + op->offset);
290 
291       size_t start = 0;
292       enum class string_type {
293          arg_type,
294          arg_type_qual,
295       } str_type;
296 
297       if (str.find("kernel_arg_type.") == 0) {
298          start = sizeof("kernel_arg_type.") - 1;
299          str_type = string_type::arg_type;
300       } else if (str.find("kernel_arg_type_qual.") == 0) {
301          start = sizeof("kernel_arg_type_qual.") - 1;
302          str_type = string_type::arg_type_qual;
303       } else {
304          return;
305       }
306 
307       for (auto &kernel : kernels) {
308          size_t pos;
309 
310 	 pos = str.find(kernel.name, start);
311          if (pos == std::string::npos ||
312              pos != start || str[start + kernel.name.size()] != '.')
313             continue;
314 
315 	 pos = start + kernel.name.size();
316          if (str[pos++] != '.')
317             continue;
318 
319          for (auto &arg : kernel.args) {
320             if (arg.name.empty())
321                break;
322 
323             size_t entryEnd = str.find(',', pos);
324 	    if (entryEnd == std::string::npos)
325                break;
326 
327             std::string entryVal = str.substr(pos, entryEnd - pos);
328             pos = entryEnd + 1;
329 
330             if (str_type == string_type::arg_type) {
331                arg.typeName = std::move(entryVal);
332             } else if (str_type == string_type::arg_type_qual) {
333                if (entryVal.find("const") != std::string::npos)
334                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
335             }
336          }
337       }
338    }
339 
340    void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
341    {
342       auto iter = decorationGroups.find(id);
343       if (iter != decorationGroups.end()) {
344          for (uint32_t entry : iter->second)
345             applyDecoration(entry, ins);
346          return;
347       }
348 
349       const spv_parsed_operand_t *op;
350       uint32_t decoration;
351 
352       assert(ins->num_operands >= 2);
353 
354       op = &ins->operands[1];
355       assert(op->type == SPV_OPERAND_TYPE_DECORATION);
356       decoration = ins->words[op->offset];
357 
358       if (decoration == SpvDecorationSpecId) {
359          uint32_t spec_id = ins->words[ins->operands[2].offset];
360          for (auto &c : specConstants) {
361             if (c.second.id == spec_id) {
362                return;
363             }
364          }
365          specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
366          return;
367       }
368 
369       for (auto &kernel : kernels) {
370          for (auto &arg : kernel.args) {
371             if (arg.id == id) {
372                switch (decoration) {
373                case SpvDecorationVolatile:
374                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
375                   break;
376                case SpvDecorationConstant:
377                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
378                   break;
379                case SpvDecorationRestrict:
380                   arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
381                   break;
382                case SpvDecorationFuncParamAttr:
383                   op = &ins->operands[2];
384                   assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
385                   switch (ins->words[op->offset]) {
386                   case SpvFunctionParameterAttributeNoAlias:
387                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
388                      break;
389                   case SpvFunctionParameterAttributeNoWrite:
390                      arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
391                      break;
392                   }
393                   break;
394                }
395             }
396 
397          }
398       }
399    }
400 
401    void parseOpDecorate(const spv_parsed_instruction_t *ins)
402    {
403       const spv_parsed_operand_t *op;
404       uint32_t id;
405 
406       assert(ins->num_operands >= 2);
407 
408       op = &ins->operands[0];
409       assert(op->type == SPV_OPERAND_TYPE_ID);
410       id = ins->words[op->offset];
411 
412       applyDecoration(id, ins);
413    }
414 
415    void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
416    {
417       assert(ins->num_operands >= 2);
418 
419       const spv_parsed_operand_t *op = &ins->operands[0];
420       assert(op->type == SPV_OPERAND_TYPE_ID);
421       uint32_t groupId = ins->words[op->offset];
422 
423       auto lowerBound = decorationGroups.lower_bound(groupId);
424       if (lowerBound != decorationGroups.end() &&
425           lowerBound->first == groupId)
426          // Group already filled out
427          return;
428 
429       auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
430       auto& vec = iter->second;
431       vec.reserve(ins->num_operands - 1);
432       for (uint32_t i = 1; i < ins->num_operands; ++i) {
433          op = &ins->operands[i];
434          assert(op->type == SPV_OPERAND_TYPE_ID);
435          vec.push_back(ins->words[op->offset]);
436       }
437    }
438 
439    void parseOpTypeImage(const spv_parsed_instruction_t *ins)
440    {
441       const spv_parsed_operand_t *op;
442       uint32_t typeId;
443       unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
444 
445       op = &ins->operands[0];
446       assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
447       typeId = ins->words[op->offset];
448 
449       if (ins->num_operands >= 9) {
450          op = &ins->operands[8];
451          assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
452          switch (ins->words[op->offset]) {
453          case SpvAccessQualifierReadOnly:
454             accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
455             break;
456          case SpvAccessQualifierWriteOnly:
457             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
458             break;
459          case SpvAccessQualifierReadWrite:
460             accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
461                CLC_KERNEL_ARG_ACCESS_READ;
462             break;
463          }
464       }
465 
466       for (auto &kernel : kernels) {
467 	 for (auto &arg : kernel.args) {
468             if (arg.typeId == typeId) {
469                arg.accessQualifier = accessQualifier;
470                arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
471             }
472          }
473       }
474    }
475 
476    void parseExecutionMode(const spv_parsed_instruction_t *ins)
477    {
478       uint32_t executionMode = ins->words[ins->operands[1].offset];
479       uint32_t funcId = ins->words[ins->operands[0].offset];
480 
481       for (auto& kernel : kernels) {
482          if (kernel.funcId == funcId) {
483             switch (executionMode) {
484             case SpvExecutionModeVecTypeHint:
485                kernel.vecHint = ins->words[ins->operands[2].offset];
486                break;
487             case SpvExecutionModeLocalSize:
488                kernel.localSize[0] = ins->words[ins->operands[2].offset];
489                kernel.localSize[1] = ins->words[ins->operands[3].offset];
490                kernel.localSize[2] = ins->words[ins->operands[4].offset];
491                break;
492             case SpvExecutionModeLocalSizeHint:
493                kernel.localSizeHint[0] = ins->words[ins->operands[2].offset];
494                kernel.localSizeHint[1] = ins->words[ins->operands[3].offset];
495                kernel.localSizeHint[2] = ins->words[ins->operands[4].offset];
496                break;
497             default:
498                return;
499             }
500          }
501       }
502    }
503 
504    void parseLiteralType(const spv_parsed_instruction_t *ins)
505    {
506       uint32_t typeId = ins->words[ins->operands[0].offset];
507       auto& literalType = literalTypes[typeId];
508       switch (ins->opcode) {
509       case SpvOpTypeBool:
510          literalType = CLC_SPEC_CONSTANT_BOOL;
511          break;
512       case SpvOpTypeFloat: {
513          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
514          switch (sizeInBits) {
515          case 32:
516             literalType = CLC_SPEC_CONSTANT_FLOAT;
517             break;
518          case 64:
519             literalType = CLC_SPEC_CONSTANT_DOUBLE;
520             break;
521          case 16:
522             /* Can't be used for a spec constant */
523             break;
524          default:
525             unreachable("Unexpected float bit size");
526          }
527          break;
528       }
529       case SpvOpTypeInt: {
530          uint32_t sizeInBits = ins->words[ins->operands[1].offset];
531          bool isSigned = ins->words[ins->operands[2].offset];
532          if (isSigned) {
533             switch (sizeInBits) {
534             case 8:
535                literalType = CLC_SPEC_CONSTANT_INT8;
536                break;
537             case 16:
538                literalType = CLC_SPEC_CONSTANT_INT16;
539                break;
540             case 32:
541                literalType = CLC_SPEC_CONSTANT_INT32;
542                break;
543             case 64:
544                literalType = CLC_SPEC_CONSTANT_INT64;
545                break;
546             default:
547                unreachable("Unexpected int bit size");
548             }
549          } else {
550             switch (sizeInBits) {
551             case 8:
552                literalType = CLC_SPEC_CONSTANT_UINT8;
553                break;
554             case 16:
555                literalType = CLC_SPEC_CONSTANT_UINT16;
556                break;
557             case 32:
558                literalType = CLC_SPEC_CONSTANT_UINT32;
559                break;
560             case 64:
561                literalType = CLC_SPEC_CONSTANT_UINT64;
562                break;
563             default:
564                unreachable("Unexpected uint bit size");
565             }
566          }
567          break;
568       }
569       default:
570          unreachable("Unexpected type opcode");
571       }
572    }
573 
574    void parseSpecConstant(const spv_parsed_instruction_t *ins)
575    {
576       uint32_t id = ins->result_id;
577       for (auto& c : specConstants) {
578          if (c.first == id) {
579             auto& data = c.second;
580             switch (ins->opcode) {
581             case SpvOpSpecConstant: {
582                uint32_t typeId = ins->words[ins->operands[0].offset];
583 
584                // This better be an integer or float type
585                auto typeIter = literalTypes.find(typeId);
586                assert(typeIter != literalTypes.end());
587 
588                data.type = typeIter->second;
589                break;
590             }
591             case SpvOpSpecConstantFalse:
592             case SpvOpSpecConstantTrue:
593                data.type = CLC_SPEC_CONSTANT_BOOL;
594                break;
595             default:
596                unreachable("Composites and Ops are not directly specializable.");
597             }
598          }
599       }
600    }
601 
602    static spv_result_t
603    parseInstruction(void *data, const spv_parsed_instruction_t *ins)
604    {
605       SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
606 
607       switch (ins->opcode) {
608       case SpvOpName:
609          parser->parseName(ins);
610          break;
611       case SpvOpEntryPoint:
612          parser->parseEntryPoint(ins);
613          break;
614       case SpvOpFunction:
615          parser->parseFunction(ins);
616          break;
617       case SpvOpFunctionParameter:
618          parser->parseFunctionParam(ins);
619          break;
620       case SpvOpFunctionEnd:
621       case SpvOpLabel:
622          parser->curKernel = NULL;
623          break;
624       case SpvOpTypePointer:
625          parser->parseTypePointer(ins);
626          break;
627       case SpvOpTypeImage:
628          parser->parseOpTypeImage(ins);
629          break;
630       case SpvOpString:
631          parser->parseOpString(ins);
632          break;
633       case SpvOpDecorate:
634          parser->parseOpDecorate(ins);
635          break;
636       case SpvOpGroupDecorate:
637          parser->parseOpGroupDecorate(ins);
638          break;
639       case SpvOpExecutionMode:
640          parser->parseExecutionMode(ins);
641          break;
642       case SpvOpTypeBool:
643       case SpvOpTypeInt:
644       case SpvOpTypeFloat:
645          parser->parseLiteralType(ins);
646          break;
647       case SpvOpSpecConstant:
648       case SpvOpSpecConstantFalse:
649       case SpvOpSpecConstantTrue:
650          parser->parseSpecConstant(ins);
651          break;
652       default:
653          break;
654       }
655 
656       return SPV_SUCCESS;
657    }
658 
659    bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
660    {
661       /* 3 passes should be enough to retrieve all kernel information:
662        * 1st pass: all entry point name and number of args
663        * 2nd pass: argument names and type names
664        * 3rd pass: pointer type names
665        */
666       for (unsigned pass = 0; pass < 3; pass++) {
667          spv_diagnostic diagnostic = NULL;
668          auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
669                                       static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
670                                       NULL, parseInstruction, &diagnostic);
671 
672          if (result != SPV_SUCCESS) {
673             if (diagnostic && logger)
674                logger->error(logger->priv, diagnostic->error);
675             return false;
676          }
677       }
678 
679       return true;
680    }
681 
682    std::vector<SPIRVKernelInfo> kernels;
683    std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
684    std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
685    std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
686    SPIRVKernelInfo *curKernel;
687    spv_context ctx;
688 };
689 
690 bool
691 clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
692                            const struct clc_kernel_info **out_kernels,
693                            unsigned *num_kernels,
694                            const struct clc_parsed_spec_constant **out_spec_constants,
695                            unsigned *num_spec_constants,
696                            const struct clc_logger *logger)
697 {
698    struct clc_kernel_info *kernels = NULL;
699    struct clc_parsed_spec_constant *spec_constants = NULL;
700 
701    SPIRVKernelParser parser;
702 
703    if (!parser.parseBinary(*spvbin, logger))
704       return false;
705 
706    *num_kernels = parser.kernels.size();
707    *num_spec_constants = parser.specConstants.size();
708    if (*num_kernels) {
709       kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
710                                                                   sizeof(*kernels)));
711       assert(kernels);
712       for (unsigned i = 0; i < parser.kernels.size(); i++) {
713          kernels[i].name = strdup(parser.kernels[i].name.c_str());
714          kernels[i].num_args = parser.kernels[i].args.size();
715          kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
716          kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
717          memcpy(kernels[i].local_size, parser.kernels[i].localSize, sizeof(kernels[i].local_size));
718          memcpy(kernels[i].local_size_hint, parser.kernels[i].localSizeHint, sizeof(kernels[i].local_size_hint));
719          if (!kernels[i].num_args)
720             continue;
721 
722          struct clc_kernel_arg *args;
723 
724          args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
725                                                                  sizeof(*kernels->args)));
726          kernels[i].args = args;
727          assert(args);
728          for (unsigned j = 0; j < kernels[i].num_args; j++) {
729             if (!parser.kernels[i].args[j].name.empty())
730                args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
731             args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
732             args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
733             args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
734             args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
735          }
736       }
737    }
738 
739    if (*num_spec_constants) {
740       spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
741                                                                                   sizeof(*spec_constants)));
742       assert(spec_constants);
743 
744       for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
745          spec_constants[i] = parser.specConstants[i].second;
746       }
747    }
748 
749    *out_kernels = kernels;
750    *out_spec_constants = spec_constants;
751 
752    return true;
753 }
754 
755 void
756 clc_free_kernels_info(const struct clc_kernel_info *kernels,
757                       unsigned num_kernels)
758 {
759    if (!kernels)
760       return;
761 
762    for (unsigned i = 0; i < num_kernels; i++) {
763       if (kernels[i].args) {
764          for (unsigned j = 0; j < kernels[i].num_args; j++) {
765             free((void *)kernels[i].args[j].name);
766             free((void *)kernels[i].args[j].type_name);
767          }
768          free((void *)kernels[i].args);
769       }
770       free((void *)kernels[i].name);
771    }
772 
773    free((void *)kernels);
774 }
775 
776 static std::unique_ptr<::llvm::Module>
777 clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
778                            const struct clc_compile_args *args,
779                            const struct clc_logger *logger)
780 {
781    static_assert(std::has_unique_object_representations<clc_optional_features>(),
782                  "no padding allowed inside clc_optional_features");
783 
784    std::string diag_log_str;
785    raw_string_ostream diag_log_stream { diag_log_str };
786 
787    std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
788 
789    clang::DiagnosticsEngine diag {
790       new clang::DiagnosticIDs,
791       new clang::DiagnosticOptions,
792       new clang::TextDiagnosticPrinter(diag_log_stream,
793                                        &c->getDiagnosticOpts())
794    };
795 
796 #if LLVM_VERSION_MAJOR >= 17
797    const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spirv64-unknown-unknown";
798 #else
799    const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spir64-unknown-unknown";
800 #endif
801 
802    std::vector<const char *> clang_opts = {
803       args->source.name,
804       "-triple", triple,
805       // By default, clang prefers to use modules to pull in the default headers,
806       // which doesn't work with our technique of embedding the headers in our binary
807       "-fdeclare-opencl-builtins",
808 #if LLVM_VERSION_MAJOR < 17
809       "-no-opaque-pointers",
810 #endif
811       // Add a default CL compiler version. Clang will pick the last one specified
812       // on the command line, so the app can override this one.
813       "-cl-std=cl1.2",
814       // The LLVM-SPIRV-Translator doesn't support memset with variable size
815       "-fno-builtin-memset",
816       // LLVM's optimizations can produce code that the translator can't translate
817       "-O0",
818       // Ensure inline functions are actually emitted
819       "-fgnu89-inline",
820    };
821 
822    // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
823    // being provided by the caller here.
824    clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
825 
826    if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
827                                                   clang_opts,
828                                                   diag)) {
829       clc_error(logger, "Couldn't create Clang invocation.\n");
830       return {};
831    }
832 
833    if (diag.hasErrorOccurred()) {
834       clc_error(logger, "%sErrors occurred during Clang invocation.\n",
835                 diag_log_str.c_str());
836       return {};
837    }
838 
839    // This is a workaround for a Clang bug which causes the number
840    // of warnings and errors to be printed to stderr.
841    // http://www.llvm.org/bugs/show_bug.cgi?id=19735
842    c->getDiagnosticOpts().ShowCarets = false;
843 
844    c->createDiagnostics(new clang::TextDiagnosticPrinter(
845                            diag_log_stream,
846                            &c->getDiagnosticOpts()));
847 
848    c->setTarget(clang::TargetInfo::CreateTargetInfo(
849                    c->getDiagnostics(), c->getInvocation().TargetOpts));
850 
851    c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
852 
853 #ifdef USE_STATIC_OPENCL_C_H
854    c->getHeaderSearchOpts().UseBuiltinIncludes = false;
855    c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
856 
857    // Add opencl-c generic search path
858    {
859       ::llvm::SmallString<128> system_header_path;
860       ::llvm::sys::path::system_temp_directory(true, system_header_path);
861       ::llvm::sys::path::append(system_header_path, "openclon12");
862       c->getHeaderSearchOpts().AddPath(system_header_path.str(),
863                                        clang::frontend::Angled,
864                                        false, false);
865 
866       ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
867       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
868          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
869       // this line is actually important to make it include `opencl-c.h`
870       ::llvm::sys::path::remove_filename(system_header_path);
871       ::llvm::sys::path::append(system_header_path, "opencl-c.h");
872       c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
873          ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
874    }
875 #else
876 
877    Dl_info info;
878    if (dladdr((void *)clang::CompilerInvocation::CreateFromArgs, &info) == 0) {
879       clc_error(logger, "Couldn't find libclang path.\n");
880       return {};
881    }
882 
883    char *clang_path = realpath(info.dli_fname, NULL);
884    if (clang_path == nullptr) {
885       clc_error(logger, "Couldn't find libclang path.\n");
886       return {};
887    }
888 
889    // GetResourcePath is a way to retrieve the actual libclang resource dir based on a given binary
890    // or library.
891    auto tmp_res_path =
892 #if LLVM_VERSION_MAJOR >= 20
893       Driver::GetResourcesPath(std::string(clang_path));
894 #else
895       Driver::GetResourcesPath(std::string(clang_path), CLANG_RESOURCE_DIR);
896 #endif
897    auto clang_res_path = fs::path(tmp_res_path) / "include";
898 
899    free(clang_path);
900 
901    c->getHeaderSearchOpts().UseBuiltinIncludes = true;
902    c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
903    c->getHeaderSearchOpts().ResourceDir = clang_res_path.string();
904 
905    // Add opencl-c generic search path
906    c->getHeaderSearchOpts().AddPath(clang_res_path.string(),
907                                     clang::frontend::Angled,
908                                     false, false);
909 
910    auto clang_install_res_path =
911       fs::path(LLVM_LIB_DIR) / "clang" / std::to_string(LLVM_VERSION_MAJOR) / "include";
912    c->getHeaderSearchOpts().AddPath(clang_install_res_path.string(),
913                                     clang::frontend::Angled,
914                                     false, false);
915 #endif
916 
917    // Enable/Disable optional OpenCL C features. Some can be toggled via `OpenCLExtensionsAsWritten`
918    // others we have to (un)define via macros ourselves.
919 
920    // Undefine clang added SPIR(V) defines so we don't magically enable extensions
921    c->getPreprocessorOpts().addMacroUndef("__SPIR__");
922    c->getPreprocessorOpts().addMacroUndef("__SPIRV__");
923 
924    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("-all");
925    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_byte_addressable_store");
926    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_base_atomics");
927    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_extended_atomics");
928    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_base_atomics");
929    c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_extended_atomics");
930    c->getPreprocessorOpts().addMacroDef("cl_khr_expect_assume=1");
931 
932    bool needs_opencl_c_h = false;
933    if (args->features.fp16) {
934       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp16");
935    }
936    if (args->features.fp64) {
937       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp64");
938       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_fp64");
939    }
940    if (args->features.int64) {
941       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cles_khr_int64");
942       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_int64");
943    } else {
944       // clang defines this unconditionally, we need to fix that.
945       c->getPreprocessorOpts().addMacroUndef("__opencl_c_int64");
946    }
947    if (args->features.images) {
948       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_images");
949    } else {
950       // clang defines this unconditionally, we need to fix that.
951       c->getPreprocessorOpts().addMacroUndef("__IMAGE_SUPPORT__");
952    }
953    if (args->features.images_read_write) {
954       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_read_write_images");
955    }
956    if (args->features.images_write_3d) {
957       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_3d_image_writes");
958       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_3d_image_writes");
959    }
960    if (args->features.images_depth) {
961       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_depth_images");
962    }
963    if (args->features.images_gl_depth) {
964       c->getPreprocessorOpts().addMacroDef("cl_khr_gl_depth_images=1");
965    }
966    if (args->features.images_mipmap) {
967       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_mipmap_image");
968    }
969    if (args->features.images_mipmap_writes) {
970       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_mipmap_image_writes");
971    }
972    if (args->features.images_gl_msaa) {
973       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_gl_msaa_sharing");
974    }
975    if (args->features.intel_subgroups) {
976       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_intel_subgroups");
977       needs_opencl_c_h = true;
978    }
979    if (args->features.subgroups) {
980       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_subgroups");
981       if (args->features.subgroups_shuffle) {
982          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle=1");
983       }
984       if (args->features.subgroups_shuffle_relative) {
985          c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle_relative=1");
986       }
987    }
988    if (args->features.subgroups_ifp) {
989       assert(args->features.subgroups);
990       c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_subgroups");
991    }
992    if (args->features.integer_dot_product) {
993       c->getPreprocessorOpts().addMacroDef("cl_khr_integer_dot_product=1");
994       c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit_packed=1");
995       c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit=1");
996    }
997 
998    // Add opencl include
999    c->getPreprocessorOpts().Includes.push_back("opencl-c-base.h");
1000    if (needs_opencl_c_h) {
1001       c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
1002    }
1003 
1004    if (args->num_headers) {
1005       ::llvm::SmallString<128> tmp_header_path;
1006       ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
1007       ::llvm::sys::path::append(tmp_header_path, "openclon12");
1008 
1009       c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
1010                                        clang::frontend::Quoted,
1011                                        false, false);
1012 
1013       for (size_t i = 0; i < args->num_headers; i++) {
1014          auto path_copy = tmp_header_path;
1015          ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
1016          c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
1017             ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
1018       }
1019    }
1020 
1021    c->getPreprocessorOpts().addRemappedFile(
1022            args->source.name,
1023            ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
1024 
1025    // Compile the code
1026    clang::EmitLLVMOnlyAction act(&llvm_ctx);
1027    if (!c->ExecuteAction(act)) {
1028       clc_error(logger, "%sError executing LLVM compilation action.\n",
1029                 diag_log_str.c_str());
1030       return {};
1031    }
1032 
1033    auto mod = act.takeModule();
1034 
1035    if (clc_debug_flags() & CLC_DEBUG_DUMP_LLVM)
1036       clc_dump_llvm(mod.get(), stdout);
1037 
1038    return mod;
1039 }
1040 
1041 static SPIRV::VersionNumber
1042 spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
1043 {
1044    switch (version) {
1045    case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
1046    case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
1047    case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
1048    case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
1049    case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
1050    case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
1051    default:      return invalid_spirv_trans_version;
1052    }
1053 }
1054 
1055 static int
1056 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
1057                   LLVMContext &context,
1058                   const struct clc_compile_args *args,
1059                   const struct clc_logger *logger,
1060                   struct clc_binary *out_spirv)
1061 {
1062    std::string log;
1063 
1064    SPIRV::VersionNumber version =
1065       spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
1066    if (version == invalid_spirv_trans_version) {
1067       clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
1068       return -1;
1069    }
1070 
1071    const char *const *extensions = args->allowed_spirv_extensions;
1072    if (!extensions) {
1073       /* The SPIR-V parser doesn't handle all extensions */
1074       static const char *default_extensions[] = {
1075          "SPV_EXT_shader_atomic_float_add",
1076          "SPV_EXT_shader_atomic_float_min_max",
1077          "SPV_KHR_float_controls",
1078          NULL,
1079       };
1080       extensions = default_extensions;
1081    }
1082 
1083    SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
1084    for (int i = 0; extensions[i]; i++) {
1085 #define EXT(X) \
1086       if (strcmp(#X, extensions[i]) == 0) \
1087          ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
1088 #include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
1089 #undef EXT
1090    }
1091    SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
1092 
1093    /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
1094    spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
1095 
1096 #if LLVM_VERSION_MAJOR >= 17
1097    if (args->use_llvm_spirv_target) {
1098       const char *triple = args->address_bits == 32 ? "spirv-unknown-unknown" : "spirv64-unknown-unknown";
1099       std::string error_msg("");
1100       auto target = TargetRegistry::lookupTarget(triple, error_msg);
1101       if (target) {
1102          auto TM = target->createTargetMachine(
1103             triple, "", "", {}, std::nullopt, std::nullopt,
1104 #if LLVM_VERSION_MAJOR >= 18
1105             ::llvm::CodeGenOptLevel::None
1106 #else
1107             ::llvm::CodeGenOpt::None
1108 #endif
1109          );
1110 
1111          auto PM = PassManager();
1112          ::llvm::SmallVector<char> buf;
1113          auto OS = ::llvm::raw_svector_ostream(buf);
1114          TM->addPassesToEmitFile(
1115             PM, OS, nullptr,
1116 #if LLVM_VERSION_MAJOR >= 18
1117             ::llvm::CodeGenFileType::ObjectFile
1118 #else
1119             ::llvm::CGFT_ObjectFile
1120 #endif
1121          );
1122 
1123          PM.run(*mod);
1124 
1125          out_spirv->size = buf.size_in_bytes();
1126          out_spirv->data = malloc(out_spirv->size);
1127          memcpy(out_spirv->data, buf.data(), out_spirv->size);
1128          return 0;
1129       } else {
1130          clc_error(logger, "LLVM SPIR-V target not found.\n");
1131          return -1;
1132       }
1133    }
1134 #endif
1135 
1136    std::ostringstream spv_stream;
1137    if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
1138       clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
1139                 log.c_str());
1140       return -1;
1141    }
1142 
1143    const std::string spv_out = spv_stream.str();
1144    out_spirv->size = spv_out.size();
1145    out_spirv->data = malloc(out_spirv->size);
1146    memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
1147 
1148    return 0;
1149 }
1150 
1151 int
1152 clc_c_to_spir(const struct clc_compile_args *args,
1153               const struct clc_logger *logger,
1154               struct clc_binary *out_spir)
1155 {
1156    clc_initialize_llvm();
1157 
1158    LLVMContext llvm_ctx;
1159    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1160                                          const_cast<clc_logger *>(logger));
1161 
1162    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1163    if (!mod)
1164       return -1;
1165 
1166    ::llvm::SmallVector<char, 0> buffer;
1167    ::llvm::BitcodeWriter writer(buffer);
1168    writer.writeModule(*mod);
1169 
1170    out_spir->size = buffer.size_in_bytes();
1171    out_spir->data = malloc(out_spir->size);
1172    memcpy(out_spir->data, buffer.data(), out_spir->size);
1173 
1174    return 0;
1175 }
1176 
1177 int
1178 clc_c_to_spirv(const struct clc_compile_args *args,
1179                const struct clc_logger *logger,
1180                struct clc_binary *out_spirv)
1181 {
1182    clc_initialize_llvm();
1183 
1184    LLVMContext llvm_ctx;
1185    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1186                                          const_cast<clc_logger *>(logger));
1187 
1188    auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1189    if (!mod)
1190       return -1;
1191    return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
1192 }
1193 
1194 int
1195 clc_spir_to_spirv(const struct clc_binary *in_spir,
1196                   const struct clc_logger *logger,
1197                   struct clc_binary *out_spirv)
1198 {
1199    clc_initialize_llvm();
1200 
1201    LLVMContext llvm_ctx;
1202    llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1203                                          const_cast<clc_logger *>(logger));
1204 
1205    ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
1206    auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
1207    if (!mod)
1208       return -1;
1209 
1210    return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
1211 }
1212 
1213 class SPIRVMessageConsumer {
1214 public:
1215    SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
1216 
1217    void operator()(spv_message_level_t level, const char *src,
1218                    const spv_position_t &pos, const char *msg)
1219    {
1220       if (level == SPV_MSG_INFO || level == SPV_MSG_DEBUG)
1221          return;
1222 
1223       std::ostringstream message;
1224       message << "(file=" << (src ? src : "input")
1225               << ",line=" << pos.line
1226               << ",column=" << pos.column
1227               << ",index=" << pos.index
1228               << "): " << msg << "\n";
1229 
1230       if (level == SPV_MSG_WARNING)
1231          clc_warning(logger, "%s", message.str().c_str());
1232       else
1233          clc_error(logger, "%s", message.str().c_str());
1234    }
1235 
1236 private:
1237    const struct clc_logger *logger;
1238 };
1239 
1240 int
1241 clc_link_spirv_binaries(const struct clc_linker_args *args,
1242                         const struct clc_logger *logger,
1243                         struct clc_binary *out_spirv)
1244 {
1245    std::vector<std::vector<uint32_t>> binaries;
1246 
1247    for (unsigned i = 0; i < args->num_in_objs; i++) {
1248       const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1249       std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1250       binaries.push_back(bin);
1251    }
1252 
1253    SPIRVMessageConsumer msgconsumer(logger);
1254    spvtools::Context context(spirv_target);
1255    context.SetMessageConsumer(msgconsumer);
1256    spvtools::LinkerOptions options;
1257    options.SetAllowPartialLinkage(args->create_library);
1258    #if defined(HAS_SPIRV_LINK_LLVM_WORKAROUND) && LLVM_VERSION_MAJOR >= 17
1259       options.SetAllowPtrTypeMismatch(true);
1260    #endif
1261    options.SetCreateLibrary(args->create_library);
1262    std::vector<uint32_t> linkingResult;
1263    spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1264    if (status != SPV_SUCCESS) {
1265       #if !defined(HAS_SPIRV_LINK_LLVM_WORKAROUND) && LLVM_VERSION_MAJOR >= 17
1266         clc_warning(logger, "SPIRV-Tools doesn't contain https://github.com/KhronosGroup/SPIRV-Tools/pull/5534\n");
1267         clc_warning(logger, "Please update in order to prevent spurious linking failures\n");
1268       #endif
1269       return -1;
1270    }
1271 
1272    out_spirv->size = linkingResult.size() * 4;
1273    out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1274    memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1275 
1276    return 0;
1277 }
1278 
1279 bool
1280 clc_validate_spirv(const struct clc_binary *spirv,
1281                    const struct clc_logger *logger,
1282                    const struct clc_validator_options *options)
1283 {
1284    SPIRVMessageConsumer msgconsumer(logger);
1285    spvtools::SpirvTools tools(spirv_target);
1286    tools.SetMessageConsumer(msgconsumer);
1287    spvtools::ValidatorOptions spirv_options;
1288    const uint32_t *data = static_cast<const uint32_t *>(spirv->data);
1289 
1290    if (options) {
1291       spirv_options.SetUniversalLimit(
1292          spv_validator_limit_max_function_args,
1293          options->limit_max_function_arg);
1294    }
1295 
1296    return tools.Validate(data, spirv->size / 4, spirv_options);
1297 }
1298 
1299 int
1300 clc_spirv_specialize(const struct clc_binary *in_spirv,
1301                      const struct clc_parsed_spirv *parsed_data,
1302                      const struct clc_spirv_specialization_consts *consts,
1303                      struct clc_binary *out_spirv)
1304 {
1305    std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1306    for (unsigned i = 0; i < consts->num_specializations; ++i) {
1307       unsigned id = consts->specializations[i].id;
1308       auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1309          parsed_data->spec_constants + parsed_data->num_spec_constants,
1310          [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1311       assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1312 
1313       std::vector<uint32_t> words;
1314       switch (parsed_spec_const->type) {
1315       case CLC_SPEC_CONSTANT_BOOL:
1316          words.push_back(consts->specializations[i].value.b);
1317          break;
1318       case CLC_SPEC_CONSTANT_INT32:
1319       case CLC_SPEC_CONSTANT_UINT32:
1320       case CLC_SPEC_CONSTANT_FLOAT:
1321          words.push_back(consts->specializations[i].value.u32);
1322          break;
1323       case CLC_SPEC_CONSTANT_INT16:
1324          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1325          break;
1326       case CLC_SPEC_CONSTANT_INT8:
1327          words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1328          break;
1329       case CLC_SPEC_CONSTANT_UINT16:
1330          words.push_back((uint32_t)consts->specializations[i].value.u16);
1331          break;
1332       case CLC_SPEC_CONSTANT_UINT8:
1333          words.push_back((uint32_t)consts->specializations[i].value.u8);
1334          break;
1335       case CLC_SPEC_CONSTANT_DOUBLE:
1336       case CLC_SPEC_CONSTANT_INT64:
1337       case CLC_SPEC_CONSTANT_UINT64:
1338          words.resize(2);
1339          memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1340          break;
1341       case CLC_SPEC_CONSTANT_UNKNOWN:
1342          assert(0);
1343          break;
1344       }
1345 
1346       ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1347       assert(ret.second);
1348    }
1349 
1350    spvtools::Optimizer opt(spirv_target);
1351    opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1352 
1353    std::vector<uint32_t> result;
1354    if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1355       return false;
1356 
1357    out_spirv->size = result.size() * 4;
1358    out_spirv->data = malloc(out_spirv->size);
1359    memcpy(out_spirv->data, result.data(), out_spirv->size);
1360    return true;
1361 }
1362 
1363 static void
1364 clc_dump_llvm(const llvm::Module *mod, FILE *f)
1365 {
1366    std::string out;
1367    raw_string_ostream os(out);
1368 
1369    mod->print(os, nullptr);
1370    os.flush();
1371 
1372    fwrite(out.c_str(), out.size(), 1, f);
1373 }
1374 
1375 void
1376 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1377 {
1378    spvtools::SpirvTools tools(spirv_target);
1379    const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1380    std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1381    std::string out;
1382    tools.Disassemble(bin, &out,
1383                      SPV_BINARY_TO_TEXT_OPTION_INDENT |
1384                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1385    fwrite(out.c_str(), out.size(), 1, f);
1386 }
1387 
1388 void
1389 clc_free_spir_binary(struct clc_binary *spir)
1390 {
1391    free(spir->data);
1392 }
1393 
1394 void
1395 clc_free_spirv_binary(struct clc_binary *spvbin)
1396 {
1397    free(spvbin->data);
1398 }
1399 
1400 void
1401 initialize_llvm_once(void)
1402 {
1403    LLVMInitializeAllTargets();
1404    LLVMInitializeAllTargetInfos();
1405    LLVMInitializeAllTargetMCs();
1406    LLVMInitializeAllAsmParsers();
1407    LLVMInitializeAllAsmPrinters();
1408 }
1409 
1410 std::once_flag initialize_llvm_once_flag;
1411 
1412 void
1413 clc_initialize_llvm(void)
1414 {
1415    std::call_once(initialize_llvm_once_flag,
1416                   []() { initialize_llvm_once(); });
1417 }
1418