1 // Copyright (c) 2023 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "extract_source.h"
16
17 #include <cassert>
18 #include <string>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "source/opt/log.h"
23 #include "spirv-tools/libspirv.hpp"
24 #include "spirv/unified1/spirv.hpp"
25 #include "tools/util/cli_consumer.h"
26
27 namespace {
28
29 constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
30
31 // Extract a string literal from a given range.
32 // Copies all the characters from `begin` to the first '\0' it encounters, while
33 // removing escape patterns.
34 // Not finding a '\0' before reaching `end` fails the extraction.
35 //
36 // Returns `true` if the extraction succeeded.
37 // `output` value is undefined if false is returned.
ExtractStringLiteral(const spv_position_t & loc,const char * begin,const char * end,std::string * output)38 spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
39 const char* end, std::string* output) {
40 size_t sourceLength = std::distance(begin, end);
41 std::string escapedString;
42 escapedString.resize(sourceLength);
43
44 size_t writeIndex = 0;
45 size_t readIndex = 0;
46 for (; readIndex < sourceLength; writeIndex++, readIndex++) {
47 const char read = begin[readIndex];
48 if (read == '\0') {
49 escapedString.resize(writeIndex);
50 output->append(escapedString);
51 return SPV_SUCCESS;
52 }
53
54 if (read == '\\') {
55 ++readIndex;
56 }
57 escapedString[writeIndex] = begin[readIndex];
58 }
59
60 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
61 "Missing NULL terminator for literal string.");
62 return SPV_ERROR_INVALID_BINARY;
63 }
64
extractOpString(const spv_position_t & loc,const spv_parsed_instruction_t & instruction,std::string * output)65 spv_result_t extractOpString(const spv_position_t& loc,
66 const spv_parsed_instruction_t& instruction,
67 std::string* output) {
68 assert(output != nullptr);
69 assert(instruction.opcode == spv::Op::OpString);
70 if (instruction.num_operands != 2) {
71 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
72 "Missing operands for OpString.");
73 return SPV_ERROR_INVALID_BINARY;
74 }
75
76 const auto& operand = instruction.operands[1];
77 const char* stringBegin =
78 reinterpret_cast<const char*>(instruction.words + operand.offset);
79 const char* stringEnd = reinterpret_cast<const char*>(
80 instruction.words + operand.offset + operand.num_words);
81 return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
82 }
83
extractOpSourceContinued(const spv_position_t & loc,const spv_parsed_instruction_t & instruction,std::string * output)84 spv_result_t extractOpSourceContinued(
85 const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
86 std::string* output) {
87 assert(output != nullptr);
88 assert(instruction.opcode == spv::Op::OpSourceContinued);
89 if (instruction.num_operands != 1) {
90 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
91 "Missing operands for OpSourceContinued.");
92 return SPV_ERROR_INVALID_BINARY;
93 }
94
95 const auto& operand = instruction.operands[0];
96 const char* stringBegin =
97 reinterpret_cast<const char*>(instruction.words + operand.offset);
98 const char* stringEnd = reinterpret_cast<const char*>(
99 instruction.words + operand.offset + operand.num_words);
100 return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
101 }
102
extractOpSource(const spv_position_t & loc,const spv_parsed_instruction_t & instruction,spv::Id * filename,std::string * code)103 spv_result_t extractOpSource(const spv_position_t& loc,
104 const spv_parsed_instruction_t& instruction,
105 spv::Id* filename, std::string* code) {
106 assert(filename != nullptr && code != nullptr);
107 assert(instruction.opcode == spv::Op::OpSource);
108 // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
109 if (instruction.num_words < 3) {
110 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
111 "Missing operands for OpSource.");
112 return SPV_ERROR_INVALID_BINARY;
113 }
114
115 *filename = 0;
116 *code = "";
117 if (instruction.num_words < 4) {
118 return SPV_SUCCESS;
119 }
120 *filename = instruction.words[3];
121
122 if (instruction.num_words < 5) {
123 return SPV_SUCCESS;
124 }
125
126 const char* stringBegin =
127 reinterpret_cast<const char*>(instruction.words + 4);
128 const char* stringEnd =
129 reinterpret_cast<const char*>(instruction.words + instruction.num_words);
130 return ExtractStringLiteral(loc, stringBegin, stringEnd, code);
131 }
132
133 } // namespace
134
ExtractSourceFromModule(const std::vector<uint32_t> & binary,std::unordered_map<std::string,std::string> * output)135 bool ExtractSourceFromModule(
136 const std::vector<uint32_t>& binary,
137 std::unordered_map<std::string, std::string>* output) {
138 auto context = spvtools::SpirvTools(kDefaultEnvironment);
139 context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
140
141 // There is nothing valuable in the header.
142 spvtools::HeaderParser headerParser = [](const spv_endianness_t,
143 const spv_parsed_header_t&) {
144 return SPV_SUCCESS;
145 };
146
147 std::unordered_map<uint32_t, std::string> stringMap;
148 std::vector<std::pair<spv::Id, std::string>> sources;
149 spv::Op lastOpcode = spv::Op::OpMax;
150 size_t instructionIndex = 0;
151
152 spvtools::InstructionParser instructionParser =
153 [&stringMap, &sources, &lastOpcode,
154 &instructionIndex](const spv_parsed_instruction_t& instruction) {
155 const spv_position_t loc = {0, 0, instructionIndex + 1};
156 spv_result_t result = SPV_SUCCESS;
157
158 if (instruction.opcode == spv::Op::OpString) {
159 std::string content;
160 result = extractOpString(loc, instruction, &content);
161 if (result == SPV_SUCCESS) {
162 stringMap.emplace(instruction.result_id, std::move(content));
163 }
164 } else if (instruction.opcode == spv::Op::OpSource) {
165 spv::Id filenameId;
166 std::string code;
167 result = extractOpSource(loc, instruction, &filenameId, &code);
168 if (result == SPV_SUCCESS) {
169 sources.emplace_back(std::make_pair(filenameId, std::move(code)));
170 }
171 } else if (instruction.opcode == spv::Op::OpSourceContinued) {
172 if (lastOpcode != spv::Op::OpSource) {
173 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
174 "OpSourceContinued MUST follow an OpSource.");
175 return SPV_ERROR_INVALID_BINARY;
176 }
177
178 assert(sources.size() > 0);
179 result = extractOpSourceContinued(loc, instruction,
180 &sources.back().second);
181 }
182
183 ++instructionIndex;
184 lastOpcode = static_cast<spv::Op>(instruction.opcode);
185 return result;
186 };
187
188 if (!context.Parse(binary, headerParser, instructionParser)) {
189 return false;
190 }
191
192 std::string defaultName = "unnamed-";
193 size_t unnamedCount = 0;
194 for (auto & [ id, code ] : sources) {
195 std::string filename;
196 const auto it = stringMap.find(id);
197 if (it == stringMap.cend() || it->second.empty()) {
198 filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
199 ++unnamedCount;
200 } else {
201 filename = it->second;
202 }
203
204 if (output->count(filename) != 0) {
205 spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
206 "Source file name conflict.");
207 return false;
208 }
209 output->insert({filename, code});
210 }
211
212 return true;
213 }
214