1 //
2 // Copyright 2012-2016 Francisco Jerez
3 // Copyright 2012-2016 Advanced Micro Devices, Inc.
4 // Copyright 2014-2016 Jan Vesely
5 // Copyright 2014-2015 Serge Martin
6 // Copyright 2015 Zoltan Gilian
7 //
8 // Permission is hereby granted, free of charge, to any person obtaining a
9 // copy of this software and associated documentation files (the "Software"),
10 // to deal in the Software without restriction, including without limitation
11 // the rights to use, copy, modify, merge, publish, distribute, sublicense,
12 // and/or sell copies of the Software, and to permit persons to whom the
13 // Software is furnished to do so, subject to the following conditions:
14 //
15 // The above copyright notice and this permission notice shall be included in
16 // all copies or substantial portions of the Software.
17 //
18 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
21 // THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR
22 // OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
24 // OTHER DEALINGS IN THE SOFTWARE.
25
26 #include <cstdlib>
27 #include <filesystem>
28 #include <sstream>
29 #include <mutex>
30
31 #include <llvm/ADT/ArrayRef.h>
32 #include <llvm/IR/DiagnosticPrinter.h>
33 #include <llvm/IR/DiagnosticInfo.h>
34 #include <llvm/IR/LegacyPassManager.h>
35 #include <llvm/IR/LLVMContext.h>
36 #include <llvm/IR/Type.h>
37 #include <llvm/MC/TargetRegistry.h>
38 #include <llvm/Target/TargetMachine.h>
39 #include <llvm/Support/raw_ostream.h>
40 #include <llvm/Bitcode/BitcodeWriter.h>
41 #include <llvm/Bitcode/BitcodeReader.h>
42 #include <llvm-c/Core.h>
43 #include <llvm-c/Target.h>
44 #include <LLVMSPIRVLib/LLVMSPIRVLib.h>
45
46 #include <clang/Config/config.h>
47 #include <clang/Driver/Driver.h>
48 #include <clang/CodeGen/CodeGenAction.h>
49 #include <clang/Lex/PreprocessorOptions.h>
50 #include <clang/Frontend/CompilerInstance.h>
51 #include <clang/Frontend/TextDiagnosticBuffer.h>
52 #include <clang/Frontend/TextDiagnosticPrinter.h>
53 #include <clang/Basic/TargetInfo.h>
54
55 #include <spirv-tools/libspirv.hpp>
56 #include <spirv-tools/linker.hpp>
57 #include <spirv-tools/optimizer.hpp>
58
59 #include "util/macros.h"
60 #include "glsl_types.h"
61
62 #include "spirv.h"
63
64 #if DETECT_OS_POSIX
65 #include <dlfcn.h>
66 #endif
67
68 #ifdef USE_STATIC_OPENCL_C_H
69 #include "opencl-c-base.h.h"
70 #include "opencl-c.h.h"
71 #endif
72
73 #include "clc_helpers.h"
74
75 namespace fs = std::filesystem;
76
77 /* Use the highest version of SPIRV supported by SPIRV-Tools. */
78 constexpr spv_target_env spirv_target = SPV_ENV_UNIVERSAL_1_5;
79
80 constexpr SPIRV::VersionNumber invalid_spirv_trans_version = static_cast<SPIRV::VersionNumber>(0);
81
82 using ::llvm::Function;
83 using ::llvm::legacy::PassManager;
84 using ::llvm::LLVMContext;
85 using ::llvm::Module;
86 using ::llvm::raw_string_ostream;
87 using ::llvm::TargetRegistry;
88 using ::clang::driver::Driver;
89
90 static void
91 clc_dump_llvm(const llvm::Module *mod, FILE *f);
92
93 static void
94 #if LLVM_VERSION_MAJOR >= 19
llvm_log_handler(const::llvm::DiagnosticInfo * di,void * data)95 llvm_log_handler(const ::llvm::DiagnosticInfo *di, void *data) {
96 #else
97 llvm_log_handler(const ::llvm::DiagnosticInfo &di, void *data) {
98 #endif
99 const clc_logger *logger = static_cast<clc_logger *>(data);
100
101 std::string log;
102 raw_string_ostream os { log };
103 ::llvm::DiagnosticPrinterRawOStream printer { os };
104 #if LLVM_VERSION_MAJOR >= 19
105 di->print(printer);
106 #else
107 di.print(printer);
108 #endif
109
110 clc_error(logger, "%s", log.c_str());
111 }
112
113 class SPIRVKernelArg {
114 public:
115 SPIRVKernelArg(uint32_t id, uint32_t typeId) : id(id), typeId(typeId),
116 addrQualifier(CLC_KERNEL_ARG_ADDRESS_PRIVATE),
117 accessQualifier(0),
118 typeQualifier(0) { }
119 ~SPIRVKernelArg() { }
120
121 uint32_t id;
122 uint32_t typeId;
123 std::string name;
124 std::string typeName;
125 enum clc_kernel_arg_address_qualifier addrQualifier;
126 unsigned accessQualifier;
127 unsigned typeQualifier;
128 };
129
130 class SPIRVKernelInfo {
131 public:
132 SPIRVKernelInfo(uint32_t fid, const char *nm)
133 : funcId(fid), name(nm), vecHint(0), localSize(), localSizeHint() { }
134 ~SPIRVKernelInfo() { }
135
136 uint32_t funcId;
137 std::string name;
138 std::vector<SPIRVKernelArg> args;
139 unsigned vecHint;
140 unsigned localSize[3];
141 unsigned localSizeHint[3];
142 };
143
144 class SPIRVKernelParser {
145 public:
146 SPIRVKernelParser() : curKernel(NULL)
147 {
148 ctx = spvContextCreate(spirv_target);
149 }
150
151 ~SPIRVKernelParser()
152 {
153 spvContextDestroy(ctx);
154 }
155
156 void parseEntryPoint(const spv_parsed_instruction_t *ins)
157 {
158 assert(ins->num_operands >= 3);
159
160 const spv_parsed_operand_t *op = &ins->operands[1];
161
162 assert(op->type == SPV_OPERAND_TYPE_ID);
163
164 uint32_t funcId = ins->words[op->offset];
165
166 for (auto &iter : kernels) {
167 if (funcId == iter.funcId)
168 return;
169 }
170
171 op = &ins->operands[2];
172 assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
173 const char *name = reinterpret_cast<const char *>(ins->words + op->offset);
174
175 kernels.push_back(SPIRVKernelInfo(funcId, name));
176 }
177
178 void parseFunction(const spv_parsed_instruction_t *ins)
179 {
180 assert(ins->num_operands == 4);
181
182 const spv_parsed_operand_t *op = &ins->operands[1];
183
184 assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
185
186 uint32_t funcId = ins->words[op->offset];
187
188 for (auto &kernel : kernels) {
189 if (funcId == kernel.funcId && !kernel.args.size()) {
190 curKernel = &kernel;
191 return;
192 }
193 }
194 }
195
196 void parseFunctionParam(const spv_parsed_instruction_t *ins)
197 {
198 const spv_parsed_operand_t *op;
199 uint32_t id, typeId;
200
201 if (!curKernel)
202 return;
203
204 assert(ins->num_operands == 2);
205 op = &ins->operands[0];
206 assert(op->type == SPV_OPERAND_TYPE_TYPE_ID);
207 typeId = ins->words[op->offset];
208 op = &ins->operands[1];
209 assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
210 id = ins->words[op->offset];
211 curKernel->args.push_back(SPIRVKernelArg(id, typeId));
212 }
213
214 void parseName(const spv_parsed_instruction_t *ins)
215 {
216 const spv_parsed_operand_t *op;
217 const char *name;
218 uint32_t id;
219
220 assert(ins->num_operands == 2);
221
222 op = &ins->operands[0];
223 assert(op->type == SPV_OPERAND_TYPE_ID);
224 id = ins->words[op->offset];
225 op = &ins->operands[1];
226 assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
227 name = reinterpret_cast<const char *>(ins->words + op->offset);
228
229 for (auto &kernel : kernels) {
230 for (auto &arg : kernel.args) {
231 if (arg.id == id && arg.name.empty()) {
232 arg.name = name;
233 break;
234 }
235 }
236 }
237 }
238
239 void parseTypePointer(const spv_parsed_instruction_t *ins)
240 {
241 enum clc_kernel_arg_address_qualifier addrQualifier;
242 uint32_t typeId, storageClass;
243 const spv_parsed_operand_t *op;
244
245 assert(ins->num_operands == 3);
246
247 op = &ins->operands[0];
248 assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
249 typeId = ins->words[op->offset];
250
251 op = &ins->operands[1];
252 assert(op->type == SPV_OPERAND_TYPE_STORAGE_CLASS);
253 storageClass = ins->words[op->offset];
254 switch (storageClass) {
255 case SpvStorageClassCrossWorkgroup:
256 addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
257 break;
258 case SpvStorageClassWorkgroup:
259 addrQualifier = CLC_KERNEL_ARG_ADDRESS_LOCAL;
260 break;
261 case SpvStorageClassUniformConstant:
262 addrQualifier = CLC_KERNEL_ARG_ADDRESS_CONSTANT;
263 break;
264 default:
265 addrQualifier = CLC_KERNEL_ARG_ADDRESS_PRIVATE;
266 break;
267 }
268
269 for (auto &kernel : kernels) {
270 for (auto &arg : kernel.args) {
271 if (arg.typeId == typeId) {
272 arg.addrQualifier = addrQualifier;
273 if (addrQualifier == CLC_KERNEL_ARG_ADDRESS_CONSTANT)
274 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
275 }
276 }
277 }
278 }
279
280 void parseOpString(const spv_parsed_instruction_t *ins)
281 {
282 const spv_parsed_operand_t *op;
283 std::string str;
284
285 assert(ins->num_operands == 2);
286
287 op = &ins->operands[1];
288 assert(op->type == SPV_OPERAND_TYPE_LITERAL_STRING);
289 str = reinterpret_cast<const char *>(ins->words + op->offset);
290
291 size_t start = 0;
292 enum class string_type {
293 arg_type,
294 arg_type_qual,
295 } str_type;
296
297 if (str.find("kernel_arg_type.") == 0) {
298 start = sizeof("kernel_arg_type.") - 1;
299 str_type = string_type::arg_type;
300 } else if (str.find("kernel_arg_type_qual.") == 0) {
301 start = sizeof("kernel_arg_type_qual.") - 1;
302 str_type = string_type::arg_type_qual;
303 } else {
304 return;
305 }
306
307 for (auto &kernel : kernels) {
308 size_t pos;
309
310 pos = str.find(kernel.name, start);
311 if (pos == std::string::npos ||
312 pos != start || str[start + kernel.name.size()] != '.')
313 continue;
314
315 pos = start + kernel.name.size();
316 if (str[pos++] != '.')
317 continue;
318
319 for (auto &arg : kernel.args) {
320 if (arg.name.empty())
321 break;
322
323 size_t entryEnd = str.find(',', pos);
324 if (entryEnd == std::string::npos)
325 break;
326
327 std::string entryVal = str.substr(pos, entryEnd - pos);
328 pos = entryEnd + 1;
329
330 if (str_type == string_type::arg_type) {
331 arg.typeName = std::move(entryVal);
332 } else if (str_type == string_type::arg_type_qual) {
333 if (entryVal.find("const") != std::string::npos)
334 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
335 }
336 }
337 }
338 }
339
340 void applyDecoration(uint32_t id, const spv_parsed_instruction_t *ins)
341 {
342 auto iter = decorationGroups.find(id);
343 if (iter != decorationGroups.end()) {
344 for (uint32_t entry : iter->second)
345 applyDecoration(entry, ins);
346 return;
347 }
348
349 const spv_parsed_operand_t *op;
350 uint32_t decoration;
351
352 assert(ins->num_operands >= 2);
353
354 op = &ins->operands[1];
355 assert(op->type == SPV_OPERAND_TYPE_DECORATION);
356 decoration = ins->words[op->offset];
357
358 if (decoration == SpvDecorationSpecId) {
359 uint32_t spec_id = ins->words[ins->operands[2].offset];
360 for (auto &c : specConstants) {
361 if (c.second.id == spec_id) {
362 return;
363 }
364 }
365 specConstants.emplace_back(id, clc_parsed_spec_constant{ spec_id });
366 return;
367 }
368
369 for (auto &kernel : kernels) {
370 for (auto &arg : kernel.args) {
371 if (arg.id == id) {
372 switch (decoration) {
373 case SpvDecorationVolatile:
374 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_VOLATILE;
375 break;
376 case SpvDecorationConstant:
377 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
378 break;
379 case SpvDecorationRestrict:
380 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
381 break;
382 case SpvDecorationFuncParamAttr:
383 op = &ins->operands[2];
384 assert(op->type == SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE);
385 switch (ins->words[op->offset]) {
386 case SpvFunctionParameterAttributeNoAlias:
387 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_RESTRICT;
388 break;
389 case SpvFunctionParameterAttributeNoWrite:
390 arg.typeQualifier |= CLC_KERNEL_ARG_TYPE_CONST;
391 break;
392 }
393 break;
394 }
395 }
396
397 }
398 }
399 }
400
401 void parseOpDecorate(const spv_parsed_instruction_t *ins)
402 {
403 const spv_parsed_operand_t *op;
404 uint32_t id;
405
406 assert(ins->num_operands >= 2);
407
408 op = &ins->operands[0];
409 assert(op->type == SPV_OPERAND_TYPE_ID);
410 id = ins->words[op->offset];
411
412 applyDecoration(id, ins);
413 }
414
415 void parseOpGroupDecorate(const spv_parsed_instruction_t *ins)
416 {
417 assert(ins->num_operands >= 2);
418
419 const spv_parsed_operand_t *op = &ins->operands[0];
420 assert(op->type == SPV_OPERAND_TYPE_ID);
421 uint32_t groupId = ins->words[op->offset];
422
423 auto lowerBound = decorationGroups.lower_bound(groupId);
424 if (lowerBound != decorationGroups.end() &&
425 lowerBound->first == groupId)
426 // Group already filled out
427 return;
428
429 auto iter = decorationGroups.emplace_hint(lowerBound, groupId, std::vector<uint32_t>{});
430 auto& vec = iter->second;
431 vec.reserve(ins->num_operands - 1);
432 for (uint32_t i = 1; i < ins->num_operands; ++i) {
433 op = &ins->operands[i];
434 assert(op->type == SPV_OPERAND_TYPE_ID);
435 vec.push_back(ins->words[op->offset]);
436 }
437 }
438
439 void parseOpTypeImage(const spv_parsed_instruction_t *ins)
440 {
441 const spv_parsed_operand_t *op;
442 uint32_t typeId;
443 unsigned accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
444
445 op = &ins->operands[0];
446 assert(op->type == SPV_OPERAND_TYPE_RESULT_ID);
447 typeId = ins->words[op->offset];
448
449 if (ins->num_operands >= 9) {
450 op = &ins->operands[8];
451 assert(op->type == SPV_OPERAND_TYPE_ACCESS_QUALIFIER);
452 switch (ins->words[op->offset]) {
453 case SpvAccessQualifierReadOnly:
454 accessQualifier = CLC_KERNEL_ARG_ACCESS_READ;
455 break;
456 case SpvAccessQualifierWriteOnly:
457 accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE;
458 break;
459 case SpvAccessQualifierReadWrite:
460 accessQualifier = CLC_KERNEL_ARG_ACCESS_WRITE |
461 CLC_KERNEL_ARG_ACCESS_READ;
462 break;
463 }
464 }
465
466 for (auto &kernel : kernels) {
467 for (auto &arg : kernel.args) {
468 if (arg.typeId == typeId) {
469 arg.accessQualifier = accessQualifier;
470 arg.addrQualifier = CLC_KERNEL_ARG_ADDRESS_GLOBAL;
471 }
472 }
473 }
474 }
475
476 void parseExecutionMode(const spv_parsed_instruction_t *ins)
477 {
478 uint32_t executionMode = ins->words[ins->operands[1].offset];
479 uint32_t funcId = ins->words[ins->operands[0].offset];
480
481 for (auto& kernel : kernels) {
482 if (kernel.funcId == funcId) {
483 switch (executionMode) {
484 case SpvExecutionModeVecTypeHint:
485 kernel.vecHint = ins->words[ins->operands[2].offset];
486 break;
487 case SpvExecutionModeLocalSize:
488 kernel.localSize[0] = ins->words[ins->operands[2].offset];
489 kernel.localSize[1] = ins->words[ins->operands[3].offset];
490 kernel.localSize[2] = ins->words[ins->operands[4].offset];
491 break;
492 case SpvExecutionModeLocalSizeHint:
493 kernel.localSizeHint[0] = ins->words[ins->operands[2].offset];
494 kernel.localSizeHint[1] = ins->words[ins->operands[3].offset];
495 kernel.localSizeHint[2] = ins->words[ins->operands[4].offset];
496 break;
497 default:
498 return;
499 }
500 }
501 }
502 }
503
504 void parseLiteralType(const spv_parsed_instruction_t *ins)
505 {
506 uint32_t typeId = ins->words[ins->operands[0].offset];
507 auto& literalType = literalTypes[typeId];
508 switch (ins->opcode) {
509 case SpvOpTypeBool:
510 literalType = CLC_SPEC_CONSTANT_BOOL;
511 break;
512 case SpvOpTypeFloat: {
513 uint32_t sizeInBits = ins->words[ins->operands[1].offset];
514 switch (sizeInBits) {
515 case 32:
516 literalType = CLC_SPEC_CONSTANT_FLOAT;
517 break;
518 case 64:
519 literalType = CLC_SPEC_CONSTANT_DOUBLE;
520 break;
521 case 16:
522 /* Can't be used for a spec constant */
523 break;
524 default:
525 unreachable("Unexpected float bit size");
526 }
527 break;
528 }
529 case SpvOpTypeInt: {
530 uint32_t sizeInBits = ins->words[ins->operands[1].offset];
531 bool isSigned = ins->words[ins->operands[2].offset];
532 if (isSigned) {
533 switch (sizeInBits) {
534 case 8:
535 literalType = CLC_SPEC_CONSTANT_INT8;
536 break;
537 case 16:
538 literalType = CLC_SPEC_CONSTANT_INT16;
539 break;
540 case 32:
541 literalType = CLC_SPEC_CONSTANT_INT32;
542 break;
543 case 64:
544 literalType = CLC_SPEC_CONSTANT_INT64;
545 break;
546 default:
547 unreachable("Unexpected int bit size");
548 }
549 } else {
550 switch (sizeInBits) {
551 case 8:
552 literalType = CLC_SPEC_CONSTANT_UINT8;
553 break;
554 case 16:
555 literalType = CLC_SPEC_CONSTANT_UINT16;
556 break;
557 case 32:
558 literalType = CLC_SPEC_CONSTANT_UINT32;
559 break;
560 case 64:
561 literalType = CLC_SPEC_CONSTANT_UINT64;
562 break;
563 default:
564 unreachable("Unexpected uint bit size");
565 }
566 }
567 break;
568 }
569 default:
570 unreachable("Unexpected type opcode");
571 }
572 }
573
574 void parseSpecConstant(const spv_parsed_instruction_t *ins)
575 {
576 uint32_t id = ins->result_id;
577 for (auto& c : specConstants) {
578 if (c.first == id) {
579 auto& data = c.second;
580 switch (ins->opcode) {
581 case SpvOpSpecConstant: {
582 uint32_t typeId = ins->words[ins->operands[0].offset];
583
584 // This better be an integer or float type
585 auto typeIter = literalTypes.find(typeId);
586 assert(typeIter != literalTypes.end());
587
588 data.type = typeIter->second;
589 break;
590 }
591 case SpvOpSpecConstantFalse:
592 case SpvOpSpecConstantTrue:
593 data.type = CLC_SPEC_CONSTANT_BOOL;
594 break;
595 default:
596 unreachable("Composites and Ops are not directly specializable.");
597 }
598 }
599 }
600 }
601
602 static spv_result_t
603 parseInstruction(void *data, const spv_parsed_instruction_t *ins)
604 {
605 SPIRVKernelParser *parser = reinterpret_cast<SPIRVKernelParser *>(data);
606
607 switch (ins->opcode) {
608 case SpvOpName:
609 parser->parseName(ins);
610 break;
611 case SpvOpEntryPoint:
612 parser->parseEntryPoint(ins);
613 break;
614 case SpvOpFunction:
615 parser->parseFunction(ins);
616 break;
617 case SpvOpFunctionParameter:
618 parser->parseFunctionParam(ins);
619 break;
620 case SpvOpFunctionEnd:
621 case SpvOpLabel:
622 parser->curKernel = NULL;
623 break;
624 case SpvOpTypePointer:
625 parser->parseTypePointer(ins);
626 break;
627 case SpvOpTypeImage:
628 parser->parseOpTypeImage(ins);
629 break;
630 case SpvOpString:
631 parser->parseOpString(ins);
632 break;
633 case SpvOpDecorate:
634 parser->parseOpDecorate(ins);
635 break;
636 case SpvOpGroupDecorate:
637 parser->parseOpGroupDecorate(ins);
638 break;
639 case SpvOpExecutionMode:
640 parser->parseExecutionMode(ins);
641 break;
642 case SpvOpTypeBool:
643 case SpvOpTypeInt:
644 case SpvOpTypeFloat:
645 parser->parseLiteralType(ins);
646 break;
647 case SpvOpSpecConstant:
648 case SpvOpSpecConstantFalse:
649 case SpvOpSpecConstantTrue:
650 parser->parseSpecConstant(ins);
651 break;
652 default:
653 break;
654 }
655
656 return SPV_SUCCESS;
657 }
658
659 bool parseBinary(const struct clc_binary &spvbin, const struct clc_logger *logger)
660 {
661 /* 3 passes should be enough to retrieve all kernel information:
662 * 1st pass: all entry point name and number of args
663 * 2nd pass: argument names and type names
664 * 3rd pass: pointer type names
665 */
666 for (unsigned pass = 0; pass < 3; pass++) {
667 spv_diagnostic diagnostic = NULL;
668 auto result = spvBinaryParse(ctx, reinterpret_cast<void *>(this),
669 static_cast<uint32_t*>(spvbin.data), spvbin.size / 4,
670 NULL, parseInstruction, &diagnostic);
671
672 if (result != SPV_SUCCESS) {
673 if (diagnostic && logger)
674 logger->error(logger->priv, diagnostic->error);
675 return false;
676 }
677 }
678
679 return true;
680 }
681
682 std::vector<SPIRVKernelInfo> kernels;
683 std::vector<std::pair<uint32_t, clc_parsed_spec_constant>> specConstants;
684 std::map<uint32_t, enum clc_spec_constant_type> literalTypes;
685 std::map<uint32_t, std::vector<uint32_t>> decorationGroups;
686 SPIRVKernelInfo *curKernel;
687 spv_context ctx;
688 };
689
690 bool
691 clc_spirv_get_kernels_info(const struct clc_binary *spvbin,
692 const struct clc_kernel_info **out_kernels,
693 unsigned *num_kernels,
694 const struct clc_parsed_spec_constant **out_spec_constants,
695 unsigned *num_spec_constants,
696 const struct clc_logger *logger)
697 {
698 struct clc_kernel_info *kernels = NULL;
699 struct clc_parsed_spec_constant *spec_constants = NULL;
700
701 SPIRVKernelParser parser;
702
703 if (!parser.parseBinary(*spvbin, logger))
704 return false;
705
706 *num_kernels = parser.kernels.size();
707 *num_spec_constants = parser.specConstants.size();
708 if (*num_kernels) {
709 kernels = reinterpret_cast<struct clc_kernel_info *>(calloc(*num_kernels,
710 sizeof(*kernels)));
711 assert(kernels);
712 for (unsigned i = 0; i < parser.kernels.size(); i++) {
713 kernels[i].name = strdup(parser.kernels[i].name.c_str());
714 kernels[i].num_args = parser.kernels[i].args.size();
715 kernels[i].vec_hint_size = parser.kernels[i].vecHint >> 16;
716 kernels[i].vec_hint_type = (enum clc_vec_hint_type)(parser.kernels[i].vecHint & 0xFFFF);
717 memcpy(kernels[i].local_size, parser.kernels[i].localSize, sizeof(kernels[i].local_size));
718 memcpy(kernels[i].local_size_hint, parser.kernels[i].localSizeHint, sizeof(kernels[i].local_size_hint));
719 if (!kernels[i].num_args)
720 continue;
721
722 struct clc_kernel_arg *args;
723
724 args = reinterpret_cast<struct clc_kernel_arg *>(calloc(kernels[i].num_args,
725 sizeof(*kernels->args)));
726 kernels[i].args = args;
727 assert(args);
728 for (unsigned j = 0; j < kernels[i].num_args; j++) {
729 if (!parser.kernels[i].args[j].name.empty())
730 args[j].name = strdup(parser.kernels[i].args[j].name.c_str());
731 args[j].type_name = strdup(parser.kernels[i].args[j].typeName.c_str());
732 args[j].address_qualifier = parser.kernels[i].args[j].addrQualifier;
733 args[j].type_qualifier = parser.kernels[i].args[j].typeQualifier;
734 args[j].access_qualifier = parser.kernels[i].args[j].accessQualifier;
735 }
736 }
737 }
738
739 if (*num_spec_constants) {
740 spec_constants = reinterpret_cast<struct clc_parsed_spec_constant *>(calloc(*num_spec_constants,
741 sizeof(*spec_constants)));
742 assert(spec_constants);
743
744 for (unsigned i = 0; i < parser.specConstants.size(); ++i) {
745 spec_constants[i] = parser.specConstants[i].second;
746 }
747 }
748
749 *out_kernels = kernels;
750 *out_spec_constants = spec_constants;
751
752 return true;
753 }
754
755 void
756 clc_free_kernels_info(const struct clc_kernel_info *kernels,
757 unsigned num_kernels)
758 {
759 if (!kernels)
760 return;
761
762 for (unsigned i = 0; i < num_kernels; i++) {
763 if (kernels[i].args) {
764 for (unsigned j = 0; j < kernels[i].num_args; j++) {
765 free((void *)kernels[i].args[j].name);
766 free((void *)kernels[i].args[j].type_name);
767 }
768 free((void *)kernels[i].args);
769 }
770 free((void *)kernels[i].name);
771 }
772
773 free((void *)kernels);
774 }
775
776 static std::unique_ptr<::llvm::Module>
777 clc_compile_to_llvm_module(LLVMContext &llvm_ctx,
778 const struct clc_compile_args *args,
779 const struct clc_logger *logger)
780 {
781 static_assert(std::has_unique_object_representations<clc_optional_features>(),
782 "no padding allowed inside clc_optional_features");
783
784 std::string diag_log_str;
785 raw_string_ostream diag_log_stream { diag_log_str };
786
787 std::unique_ptr<clang::CompilerInstance> c { new clang::CompilerInstance };
788
789 clang::DiagnosticsEngine diag {
790 new clang::DiagnosticIDs,
791 new clang::DiagnosticOptions,
792 new clang::TextDiagnosticPrinter(diag_log_stream,
793 &c->getDiagnosticOpts())
794 };
795
796 #if LLVM_VERSION_MAJOR >= 17
797 const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spirv64-unknown-unknown";
798 #else
799 const char *triple = args->address_bits == 32 ? "spir-unknown-unknown" : "spir64-unknown-unknown";
800 #endif
801
802 std::vector<const char *> clang_opts = {
803 args->source.name,
804 "-triple", triple,
805 // By default, clang prefers to use modules to pull in the default headers,
806 // which doesn't work with our technique of embedding the headers in our binary
807 "-fdeclare-opencl-builtins",
808 #if LLVM_VERSION_MAJOR < 17
809 "-no-opaque-pointers",
810 #endif
811 // Add a default CL compiler version. Clang will pick the last one specified
812 // on the command line, so the app can override this one.
813 "-cl-std=cl1.2",
814 // The LLVM-SPIRV-Translator doesn't support memset with variable size
815 "-fno-builtin-memset",
816 // LLVM's optimizations can produce code that the translator can't translate
817 "-O0",
818 // Ensure inline functions are actually emitted
819 "-fgnu89-inline",
820 };
821
822 // We assume there's appropriate defines for __OPENCL_VERSION__ and __IMAGE_SUPPORT__
823 // being provided by the caller here.
824 clang_opts.insert(clang_opts.end(), args->args, args->args + args->num_args);
825
826 if (!clang::CompilerInvocation::CreateFromArgs(c->getInvocation(),
827 clang_opts,
828 diag)) {
829 clc_error(logger, "Couldn't create Clang invocation.\n");
830 return {};
831 }
832
833 if (diag.hasErrorOccurred()) {
834 clc_error(logger, "%sErrors occurred during Clang invocation.\n",
835 diag_log_str.c_str());
836 return {};
837 }
838
839 // This is a workaround for a Clang bug which causes the number
840 // of warnings and errors to be printed to stderr.
841 // http://www.llvm.org/bugs/show_bug.cgi?id=19735
842 c->getDiagnosticOpts().ShowCarets = false;
843
844 c->createDiagnostics(new clang::TextDiagnosticPrinter(
845 diag_log_stream,
846 &c->getDiagnosticOpts()));
847
848 c->setTarget(clang::TargetInfo::CreateTargetInfo(
849 c->getDiagnostics(), c->getInvocation().TargetOpts));
850
851 c->getFrontendOpts().ProgramAction = clang::frontend::EmitLLVMOnly;
852
853 #ifdef USE_STATIC_OPENCL_C_H
854 c->getHeaderSearchOpts().UseBuiltinIncludes = false;
855 c->getHeaderSearchOpts().UseStandardSystemIncludes = false;
856
857 // Add opencl-c generic search path
858 {
859 ::llvm::SmallString<128> system_header_path;
860 ::llvm::sys::path::system_temp_directory(true, system_header_path);
861 ::llvm::sys::path::append(system_header_path, "openclon12");
862 c->getHeaderSearchOpts().AddPath(system_header_path.str(),
863 clang::frontend::Angled,
864 false, false);
865
866 ::llvm::sys::path::append(system_header_path, "opencl-c-base.h");
867 c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
868 ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_base_source, ARRAY_SIZE(opencl_c_base_source) - 1)).release());
869 // this line is actually important to make it include `opencl-c.h`
870 ::llvm::sys::path::remove_filename(system_header_path);
871 ::llvm::sys::path::append(system_header_path, "opencl-c.h");
872 c->getPreprocessorOpts().addRemappedFile(system_header_path.str(),
873 ::llvm::MemoryBuffer::getMemBuffer(llvm::StringRef(opencl_c_source, ARRAY_SIZE(opencl_c_source) - 1)).release());
874 }
875 #else
876
877 Dl_info info;
878 if (dladdr((void *)clang::CompilerInvocation::CreateFromArgs, &info) == 0) {
879 clc_error(logger, "Couldn't find libclang path.\n");
880 return {};
881 }
882
883 char *clang_path = realpath(info.dli_fname, NULL);
884 if (clang_path == nullptr) {
885 clc_error(logger, "Couldn't find libclang path.\n");
886 return {};
887 }
888
889 // GetResourcePath is a way to retrieve the actual libclang resource dir based on a given binary
890 // or library.
891 auto tmp_res_path =
892 #if LLVM_VERSION_MAJOR >= 20
893 Driver::GetResourcesPath(std::string(clang_path));
894 #else
895 Driver::GetResourcesPath(std::string(clang_path), CLANG_RESOURCE_DIR);
896 #endif
897 auto clang_res_path = fs::path(tmp_res_path) / "include";
898
899 free(clang_path);
900
901 c->getHeaderSearchOpts().UseBuiltinIncludes = true;
902 c->getHeaderSearchOpts().UseStandardSystemIncludes = true;
903 c->getHeaderSearchOpts().ResourceDir = clang_res_path.string();
904
905 // Add opencl-c generic search path
906 c->getHeaderSearchOpts().AddPath(clang_res_path.string(),
907 clang::frontend::Angled,
908 false, false);
909
910 auto clang_install_res_path =
911 fs::path(LLVM_LIB_DIR) / "clang" / std::to_string(LLVM_VERSION_MAJOR) / "include";
912 c->getHeaderSearchOpts().AddPath(clang_install_res_path.string(),
913 clang::frontend::Angled,
914 false, false);
915 #endif
916
917 // Enable/Disable optional OpenCL C features. Some can be toggled via `OpenCLExtensionsAsWritten`
918 // others we have to (un)define via macros ourselves.
919
920 // Undefine clang added SPIR(V) defines so we don't magically enable extensions
921 c->getPreprocessorOpts().addMacroUndef("__SPIR__");
922 c->getPreprocessorOpts().addMacroUndef("__SPIRV__");
923
924 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("-all");
925 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_byte_addressable_store");
926 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_base_atomics");
927 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_global_int32_extended_atomics");
928 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_base_atomics");
929 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_local_int32_extended_atomics");
930 c->getPreprocessorOpts().addMacroDef("cl_khr_expect_assume=1");
931
932 bool needs_opencl_c_h = false;
933 if (args->features.fp16) {
934 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp16");
935 }
936 if (args->features.fp64) {
937 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_fp64");
938 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_fp64");
939 }
940 if (args->features.int64) {
941 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cles_khr_int64");
942 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_int64");
943 } else {
944 // clang defines this unconditionally, we need to fix that.
945 c->getPreprocessorOpts().addMacroUndef("__opencl_c_int64");
946 }
947 if (args->features.images) {
948 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_images");
949 } else {
950 // clang defines this unconditionally, we need to fix that.
951 c->getPreprocessorOpts().addMacroUndef("__IMAGE_SUPPORT__");
952 }
953 if (args->features.images_read_write) {
954 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_read_write_images");
955 }
956 if (args->features.images_write_3d) {
957 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_3d_image_writes");
958 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_3d_image_writes");
959 }
960 if (args->features.images_depth) {
961 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_depth_images");
962 }
963 if (args->features.images_gl_depth) {
964 c->getPreprocessorOpts().addMacroDef("cl_khr_gl_depth_images=1");
965 }
966 if (args->features.images_mipmap) {
967 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_mipmap_image");
968 }
969 if (args->features.images_mipmap_writes) {
970 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_mipmap_image_writes");
971 }
972 if (args->features.images_gl_msaa) {
973 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_gl_msaa_sharing");
974 }
975 if (args->features.intel_subgroups) {
976 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_intel_subgroups");
977 needs_opencl_c_h = true;
978 }
979 if (args->features.subgroups) {
980 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+__opencl_c_subgroups");
981 if (args->features.subgroups_shuffle) {
982 c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle=1");
983 }
984 if (args->features.subgroups_shuffle_relative) {
985 c->getPreprocessorOpts().addMacroDef("cl_khr_subgroup_shuffle_relative=1");
986 }
987 }
988 if (args->features.subgroups_ifp) {
989 assert(args->features.subgroups);
990 c->getTargetOpts().OpenCLExtensionsAsWritten.push_back("+cl_khr_subgroups");
991 }
992 if (args->features.integer_dot_product) {
993 c->getPreprocessorOpts().addMacroDef("cl_khr_integer_dot_product=1");
994 c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit_packed=1");
995 c->getPreprocessorOpts().addMacroDef("__opencl_c_integer_dot_product_input_4x8bit=1");
996 }
997
998 // Add opencl include
999 c->getPreprocessorOpts().Includes.push_back("opencl-c-base.h");
1000 if (needs_opencl_c_h) {
1001 c->getPreprocessorOpts().Includes.push_back("opencl-c.h");
1002 }
1003
1004 if (args->num_headers) {
1005 ::llvm::SmallString<128> tmp_header_path;
1006 ::llvm::sys::path::system_temp_directory(true, tmp_header_path);
1007 ::llvm::sys::path::append(tmp_header_path, "openclon12");
1008
1009 c->getHeaderSearchOpts().AddPath(tmp_header_path.str(),
1010 clang::frontend::Quoted,
1011 false, false);
1012
1013 for (size_t i = 0; i < args->num_headers; i++) {
1014 auto path_copy = tmp_header_path;
1015 ::llvm::sys::path::append(path_copy, ::llvm::sys::path::convert_to_slash(args->headers[i].name));
1016 c->getPreprocessorOpts().addRemappedFile(path_copy.str(),
1017 ::llvm::MemoryBuffer::getMemBufferCopy(args->headers[i].value).release());
1018 }
1019 }
1020
1021 c->getPreprocessorOpts().addRemappedFile(
1022 args->source.name,
1023 ::llvm::MemoryBuffer::getMemBufferCopy(std::string(args->source.value)).release());
1024
1025 // Compile the code
1026 clang::EmitLLVMOnlyAction act(&llvm_ctx);
1027 if (!c->ExecuteAction(act)) {
1028 clc_error(logger, "%sError executing LLVM compilation action.\n",
1029 diag_log_str.c_str());
1030 return {};
1031 }
1032
1033 auto mod = act.takeModule();
1034
1035 if (clc_debug_flags() & CLC_DEBUG_DUMP_LLVM)
1036 clc_dump_llvm(mod.get(), stdout);
1037
1038 return mod;
1039 }
1040
1041 static SPIRV::VersionNumber
1042 spirv_version_to_llvm_spirv_translator_version(enum clc_spirv_version version)
1043 {
1044 switch (version) {
1045 case CLC_SPIRV_VERSION_MAX: return SPIRV::VersionNumber::MaximumVersion;
1046 case CLC_SPIRV_VERSION_1_0: return SPIRV::VersionNumber::SPIRV_1_0;
1047 case CLC_SPIRV_VERSION_1_1: return SPIRV::VersionNumber::SPIRV_1_1;
1048 case CLC_SPIRV_VERSION_1_2: return SPIRV::VersionNumber::SPIRV_1_2;
1049 case CLC_SPIRV_VERSION_1_3: return SPIRV::VersionNumber::SPIRV_1_3;
1050 case CLC_SPIRV_VERSION_1_4: return SPIRV::VersionNumber::SPIRV_1_4;
1051 default: return invalid_spirv_trans_version;
1052 }
1053 }
1054
1055 static int
1056 llvm_mod_to_spirv(std::unique_ptr<::llvm::Module> mod,
1057 LLVMContext &context,
1058 const struct clc_compile_args *args,
1059 const struct clc_logger *logger,
1060 struct clc_binary *out_spirv)
1061 {
1062 std::string log;
1063
1064 SPIRV::VersionNumber version =
1065 spirv_version_to_llvm_spirv_translator_version(args->spirv_version);
1066 if (version == invalid_spirv_trans_version) {
1067 clc_error(logger, "Invalid/unsupported SPIRV specified.\n");
1068 return -1;
1069 }
1070
1071 const char *const *extensions = args->allowed_spirv_extensions;
1072 if (!extensions) {
1073 /* The SPIR-V parser doesn't handle all extensions */
1074 static const char *default_extensions[] = {
1075 "SPV_EXT_shader_atomic_float_add",
1076 "SPV_EXT_shader_atomic_float_min_max",
1077 "SPV_KHR_float_controls",
1078 NULL,
1079 };
1080 extensions = default_extensions;
1081 }
1082
1083 SPIRV::TranslatorOpts::ExtensionsStatusMap ext_map;
1084 for (int i = 0; extensions[i]; i++) {
1085 #define EXT(X) \
1086 if (strcmp(#X, extensions[i]) == 0) \
1087 ext_map.insert(std::make_pair(SPIRV::ExtensionID::X, true));
1088 #include "LLVMSPIRVLib/LLVMSPIRVExtensions.inc"
1089 #undef EXT
1090 }
1091 SPIRV::TranslatorOpts spirv_opts = SPIRV::TranslatorOpts(version, ext_map);
1092
1093 /* This was the default in 12.0 and older, but currently we'll fail to parse without this */
1094 spirv_opts.setPreserveOCLKernelArgTypeMetadataThroughString(true);
1095
1096 #if LLVM_VERSION_MAJOR >= 17
1097 if (args->use_llvm_spirv_target) {
1098 const char *triple = args->address_bits == 32 ? "spirv-unknown-unknown" : "spirv64-unknown-unknown";
1099 std::string error_msg("");
1100 auto target = TargetRegistry::lookupTarget(triple, error_msg);
1101 if (target) {
1102 auto TM = target->createTargetMachine(
1103 triple, "", "", {}, std::nullopt, std::nullopt,
1104 #if LLVM_VERSION_MAJOR >= 18
1105 ::llvm::CodeGenOptLevel::None
1106 #else
1107 ::llvm::CodeGenOpt::None
1108 #endif
1109 );
1110
1111 auto PM = PassManager();
1112 ::llvm::SmallVector<char> buf;
1113 auto OS = ::llvm::raw_svector_ostream(buf);
1114 TM->addPassesToEmitFile(
1115 PM, OS, nullptr,
1116 #if LLVM_VERSION_MAJOR >= 18
1117 ::llvm::CodeGenFileType::ObjectFile
1118 #else
1119 ::llvm::CGFT_ObjectFile
1120 #endif
1121 );
1122
1123 PM.run(*mod);
1124
1125 out_spirv->size = buf.size_in_bytes();
1126 out_spirv->data = malloc(out_spirv->size);
1127 memcpy(out_spirv->data, buf.data(), out_spirv->size);
1128 return 0;
1129 } else {
1130 clc_error(logger, "LLVM SPIR-V target not found.\n");
1131 return -1;
1132 }
1133 }
1134 #endif
1135
1136 std::ostringstream spv_stream;
1137 if (!::llvm::writeSpirv(mod.get(), spirv_opts, spv_stream, log)) {
1138 clc_error(logger, "%sTranslation from LLVM IR to SPIR-V failed.\n",
1139 log.c_str());
1140 return -1;
1141 }
1142
1143 const std::string spv_out = spv_stream.str();
1144 out_spirv->size = spv_out.size();
1145 out_spirv->data = malloc(out_spirv->size);
1146 memcpy(out_spirv->data, spv_out.data(), out_spirv->size);
1147
1148 return 0;
1149 }
1150
1151 int
1152 clc_c_to_spir(const struct clc_compile_args *args,
1153 const struct clc_logger *logger,
1154 struct clc_binary *out_spir)
1155 {
1156 clc_initialize_llvm();
1157
1158 LLVMContext llvm_ctx;
1159 llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1160 const_cast<clc_logger *>(logger));
1161
1162 auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1163 if (!mod)
1164 return -1;
1165
1166 ::llvm::SmallVector<char, 0> buffer;
1167 ::llvm::BitcodeWriter writer(buffer);
1168 writer.writeModule(*mod);
1169
1170 out_spir->size = buffer.size_in_bytes();
1171 out_spir->data = malloc(out_spir->size);
1172 memcpy(out_spir->data, buffer.data(), out_spir->size);
1173
1174 return 0;
1175 }
1176
1177 int
1178 clc_c_to_spirv(const struct clc_compile_args *args,
1179 const struct clc_logger *logger,
1180 struct clc_binary *out_spirv)
1181 {
1182 clc_initialize_llvm();
1183
1184 LLVMContext llvm_ctx;
1185 llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1186 const_cast<clc_logger *>(logger));
1187
1188 auto mod = clc_compile_to_llvm_module(llvm_ctx, args, logger);
1189 if (!mod)
1190 return -1;
1191 return llvm_mod_to_spirv(std::move(mod), llvm_ctx, args, logger, out_spirv);
1192 }
1193
1194 int
1195 clc_spir_to_spirv(const struct clc_binary *in_spir,
1196 const struct clc_logger *logger,
1197 struct clc_binary *out_spirv)
1198 {
1199 clc_initialize_llvm();
1200
1201 LLVMContext llvm_ctx;
1202 llvm_ctx.setDiagnosticHandlerCallBack(llvm_log_handler,
1203 const_cast<clc_logger *>(logger));
1204
1205 ::llvm::StringRef spir_ref(static_cast<const char*>(in_spir->data), in_spir->size);
1206 auto mod = ::llvm::parseBitcodeFile(::llvm::MemoryBufferRef(spir_ref, "<spir>"), llvm_ctx);
1207 if (!mod)
1208 return -1;
1209
1210 return llvm_mod_to_spirv(std::move(mod.get()), llvm_ctx, NULL, logger, out_spirv);
1211 }
1212
1213 class SPIRVMessageConsumer {
1214 public:
1215 SPIRVMessageConsumer(const struct clc_logger *logger): logger(logger) {}
1216
1217 void operator()(spv_message_level_t level, const char *src,
1218 const spv_position_t &pos, const char *msg)
1219 {
1220 if (level == SPV_MSG_INFO || level == SPV_MSG_DEBUG)
1221 return;
1222
1223 std::ostringstream message;
1224 message << "(file=" << (src ? src : "input")
1225 << ",line=" << pos.line
1226 << ",column=" << pos.column
1227 << ",index=" << pos.index
1228 << "): " << msg << "\n";
1229
1230 if (level == SPV_MSG_WARNING)
1231 clc_warning(logger, "%s", message.str().c_str());
1232 else
1233 clc_error(logger, "%s", message.str().c_str());
1234 }
1235
1236 private:
1237 const struct clc_logger *logger;
1238 };
1239
1240 int
1241 clc_link_spirv_binaries(const struct clc_linker_args *args,
1242 const struct clc_logger *logger,
1243 struct clc_binary *out_spirv)
1244 {
1245 std::vector<std::vector<uint32_t>> binaries;
1246
1247 for (unsigned i = 0; i < args->num_in_objs; i++) {
1248 const uint32_t *data = static_cast<const uint32_t *>(args->in_objs[i]->data);
1249 std::vector<uint32_t> bin(data, data + (args->in_objs[i]->size / 4));
1250 binaries.push_back(bin);
1251 }
1252
1253 SPIRVMessageConsumer msgconsumer(logger);
1254 spvtools::Context context(spirv_target);
1255 context.SetMessageConsumer(msgconsumer);
1256 spvtools::LinkerOptions options;
1257 options.SetAllowPartialLinkage(args->create_library);
1258 #if defined(HAS_SPIRV_LINK_LLVM_WORKAROUND) && LLVM_VERSION_MAJOR >= 17
1259 options.SetAllowPtrTypeMismatch(true);
1260 #endif
1261 options.SetCreateLibrary(args->create_library);
1262 std::vector<uint32_t> linkingResult;
1263 spv_result_t status = spvtools::Link(context, binaries, &linkingResult, options);
1264 if (status != SPV_SUCCESS) {
1265 #if !defined(HAS_SPIRV_LINK_LLVM_WORKAROUND) && LLVM_VERSION_MAJOR >= 17
1266 clc_warning(logger, "SPIRV-Tools doesn't contain https://github.com/KhronosGroup/SPIRV-Tools/pull/5534\n");
1267 clc_warning(logger, "Please update in order to prevent spurious linking failures\n");
1268 #endif
1269 return -1;
1270 }
1271
1272 out_spirv->size = linkingResult.size() * 4;
1273 out_spirv->data = static_cast<uint32_t *>(malloc(out_spirv->size));
1274 memcpy(out_spirv->data, linkingResult.data(), out_spirv->size);
1275
1276 return 0;
1277 }
1278
1279 bool
1280 clc_validate_spirv(const struct clc_binary *spirv,
1281 const struct clc_logger *logger,
1282 const struct clc_validator_options *options)
1283 {
1284 SPIRVMessageConsumer msgconsumer(logger);
1285 spvtools::SpirvTools tools(spirv_target);
1286 tools.SetMessageConsumer(msgconsumer);
1287 spvtools::ValidatorOptions spirv_options;
1288 const uint32_t *data = static_cast<const uint32_t *>(spirv->data);
1289
1290 if (options) {
1291 spirv_options.SetUniversalLimit(
1292 spv_validator_limit_max_function_args,
1293 options->limit_max_function_arg);
1294 }
1295
1296 return tools.Validate(data, spirv->size / 4, spirv_options);
1297 }
1298
1299 int
1300 clc_spirv_specialize(const struct clc_binary *in_spirv,
1301 const struct clc_parsed_spirv *parsed_data,
1302 const struct clc_spirv_specialization_consts *consts,
1303 struct clc_binary *out_spirv)
1304 {
1305 std::unordered_map<uint32_t, std::vector<uint32_t>> spec_const_map;
1306 for (unsigned i = 0; i < consts->num_specializations; ++i) {
1307 unsigned id = consts->specializations[i].id;
1308 auto parsed_spec_const = std::find_if(parsed_data->spec_constants,
1309 parsed_data->spec_constants + parsed_data->num_spec_constants,
1310 [id](const clc_parsed_spec_constant &c) { return c.id == id; });
1311 assert(parsed_spec_const != parsed_data->spec_constants + parsed_data->num_spec_constants);
1312
1313 std::vector<uint32_t> words;
1314 switch (parsed_spec_const->type) {
1315 case CLC_SPEC_CONSTANT_BOOL:
1316 words.push_back(consts->specializations[i].value.b);
1317 break;
1318 case CLC_SPEC_CONSTANT_INT32:
1319 case CLC_SPEC_CONSTANT_UINT32:
1320 case CLC_SPEC_CONSTANT_FLOAT:
1321 words.push_back(consts->specializations[i].value.u32);
1322 break;
1323 case CLC_SPEC_CONSTANT_INT16:
1324 words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i16);
1325 break;
1326 case CLC_SPEC_CONSTANT_INT8:
1327 words.push_back((uint32_t)(int32_t)consts->specializations[i].value.i8);
1328 break;
1329 case CLC_SPEC_CONSTANT_UINT16:
1330 words.push_back((uint32_t)consts->specializations[i].value.u16);
1331 break;
1332 case CLC_SPEC_CONSTANT_UINT8:
1333 words.push_back((uint32_t)consts->specializations[i].value.u8);
1334 break;
1335 case CLC_SPEC_CONSTANT_DOUBLE:
1336 case CLC_SPEC_CONSTANT_INT64:
1337 case CLC_SPEC_CONSTANT_UINT64:
1338 words.resize(2);
1339 memcpy(words.data(), &consts->specializations[i].value.u64, 8);
1340 break;
1341 case CLC_SPEC_CONSTANT_UNKNOWN:
1342 assert(0);
1343 break;
1344 }
1345
1346 ASSERTED auto ret = spec_const_map.emplace(id, std::move(words));
1347 assert(ret.second);
1348 }
1349
1350 spvtools::Optimizer opt(spirv_target);
1351 opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(std::move(spec_const_map)));
1352
1353 std::vector<uint32_t> result;
1354 if (!opt.Run(static_cast<const uint32_t*>(in_spirv->data), in_spirv->size / 4, &result))
1355 return false;
1356
1357 out_spirv->size = result.size() * 4;
1358 out_spirv->data = malloc(out_spirv->size);
1359 memcpy(out_spirv->data, result.data(), out_spirv->size);
1360 return true;
1361 }
1362
1363 static void
1364 clc_dump_llvm(const llvm::Module *mod, FILE *f)
1365 {
1366 std::string out;
1367 raw_string_ostream os(out);
1368
1369 mod->print(os, nullptr);
1370 os.flush();
1371
1372 fwrite(out.c_str(), out.size(), 1, f);
1373 }
1374
1375 void
1376 clc_dump_spirv(const struct clc_binary *spvbin, FILE *f)
1377 {
1378 spvtools::SpirvTools tools(spirv_target);
1379 const uint32_t *data = static_cast<const uint32_t *>(spvbin->data);
1380 std::vector<uint32_t> bin(data, data + (spvbin->size / 4));
1381 std::string out;
1382 tools.Disassemble(bin, &out,
1383 SPV_BINARY_TO_TEXT_OPTION_INDENT |
1384 SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
1385 fwrite(out.c_str(), out.size(), 1, f);
1386 }
1387
1388 void
1389 clc_free_spir_binary(struct clc_binary *spir)
1390 {
1391 free(spir->data);
1392 }
1393
1394 void
1395 clc_free_spirv_binary(struct clc_binary *spvbin)
1396 {
1397 free(spvbin->data);
1398 }
1399
1400 void
1401 initialize_llvm_once(void)
1402 {
1403 LLVMInitializeAllTargets();
1404 LLVMInitializeAllTargetInfos();
1405 LLVMInitializeAllTargetMCs();
1406 LLVMInitializeAllAsmParsers();
1407 LLVMInitializeAllAsmPrinters();
1408 }
1409
1410 std::once_flag initialize_llvm_once_flag;
1411
1412 void
1413 clc_initialize_llvm(void)
1414 {
1415 std::call_once(initialize_llvm_once_flag,
1416 []() { initialize_llvm_once(); });
1417 }
1418