xref: /aosp_15_r20/external/swiftshader/third_party/SPIRV-Tools/tools/objdump/extract_source.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright (c) 2023 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "extract_source.h"
16 
17 #include <cassert>
18 #include <string>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "source/opt/log.h"
23 #include "spirv-tools/libspirv.hpp"
24 #include "spirv/unified1/spirv.hpp"
25 #include "tools/util/cli_consumer.h"
26 
27 namespace {
28 
29 constexpr auto kDefaultEnvironment = SPV_ENV_UNIVERSAL_1_6;
30 
31 // Extract a string literal from a given range.
32 // Copies all the characters from `begin` to the first '\0' it encounters, while
33 // removing escape patterns.
34 // Not finding a '\0' before reaching `end` fails the extraction.
35 //
36 // Returns `true` if the extraction succeeded.
37 // `output` value is undefined if false is returned.
ExtractStringLiteral(const spv_position_t & loc,const char * begin,const char * end,std::string * output)38 spv_result_t ExtractStringLiteral(const spv_position_t& loc, const char* begin,
39                                   const char* end, std::string* output) {
40   size_t sourceLength = std::distance(begin, end);
41   std::string escapedString;
42   escapedString.resize(sourceLength);
43 
44   size_t writeIndex = 0;
45   size_t readIndex = 0;
46   for (; readIndex < sourceLength; writeIndex++, readIndex++) {
47     const char read = begin[readIndex];
48     if (read == '\0') {
49       escapedString.resize(writeIndex);
50       output->append(escapedString);
51       return SPV_SUCCESS;
52     }
53 
54     if (read == '\\') {
55       ++readIndex;
56     }
57     escapedString[writeIndex] = begin[readIndex];
58   }
59 
60   spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
61                   "Missing NULL terminator for literal string.");
62   return SPV_ERROR_INVALID_BINARY;
63 }
64 
extractOpString(const spv_position_t & loc,const spv_parsed_instruction_t & instruction,std::string * output)65 spv_result_t extractOpString(const spv_position_t& loc,
66                              const spv_parsed_instruction_t& instruction,
67                              std::string* output) {
68   assert(output != nullptr);
69   assert(instruction.opcode == spv::Op::OpString);
70   if (instruction.num_operands != 2) {
71     spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
72                     "Missing operands for OpString.");
73     return SPV_ERROR_INVALID_BINARY;
74   }
75 
76   const auto& operand = instruction.operands[1];
77   const char* stringBegin =
78       reinterpret_cast<const char*>(instruction.words + operand.offset);
79   const char* stringEnd = reinterpret_cast<const char*>(
80       instruction.words + operand.offset + operand.num_words);
81   return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
82 }
83 
extractOpSourceContinued(const spv_position_t & loc,const spv_parsed_instruction_t & instruction,std::string * output)84 spv_result_t extractOpSourceContinued(
85     const spv_position_t& loc, const spv_parsed_instruction_t& instruction,
86     std::string* output) {
87   assert(output != nullptr);
88   assert(instruction.opcode == spv::Op::OpSourceContinued);
89   if (instruction.num_operands != 1) {
90     spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
91                     "Missing operands for OpSourceContinued.");
92     return SPV_ERROR_INVALID_BINARY;
93   }
94 
95   const auto& operand = instruction.operands[0];
96   const char* stringBegin =
97       reinterpret_cast<const char*>(instruction.words + operand.offset);
98   const char* stringEnd = reinterpret_cast<const char*>(
99       instruction.words + operand.offset + operand.num_words);
100   return ExtractStringLiteral(loc, stringBegin, stringEnd, output);
101 }
102 
extractOpSource(const spv_position_t & loc,const spv_parsed_instruction_t & instruction,spv::Id * filename,std::string * code)103 spv_result_t extractOpSource(const spv_position_t& loc,
104                              const spv_parsed_instruction_t& instruction,
105                              spv::Id* filename, std::string* code) {
106   assert(filename != nullptr && code != nullptr);
107   assert(instruction.opcode == spv::Op::OpSource);
108   // OpCode [ Source Language | Version | File (optional) | Source (optional) ]
109   if (instruction.num_words < 3) {
110     spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
111                     "Missing operands for OpSource.");
112     return SPV_ERROR_INVALID_BINARY;
113   }
114 
115   *filename = 0;
116   *code = "";
117   if (instruction.num_words < 4) {
118     return SPV_SUCCESS;
119   }
120   *filename = instruction.words[3];
121 
122   if (instruction.num_words < 5) {
123     return SPV_SUCCESS;
124   }
125 
126   const char* stringBegin =
127       reinterpret_cast<const char*>(instruction.words + 4);
128   const char* stringEnd =
129       reinterpret_cast<const char*>(instruction.words + instruction.num_words);
130   return ExtractStringLiteral(loc, stringBegin, stringEnd, code);
131 }
132 
133 }  // namespace
134 
ExtractSourceFromModule(const std::vector<uint32_t> & binary,std::unordered_map<std::string,std::string> * output)135 bool ExtractSourceFromModule(
136     const std::vector<uint32_t>& binary,
137     std::unordered_map<std::string, std::string>* output) {
138   auto context = spvtools::SpirvTools(kDefaultEnvironment);
139   context.SetMessageConsumer(spvtools::utils::CLIMessageConsumer);
140 
141   // There is nothing valuable in the header.
142   spvtools::HeaderParser headerParser = [](const spv_endianness_t,
143                                            const spv_parsed_header_t&) {
144     return SPV_SUCCESS;
145   };
146 
147   std::unordered_map<uint32_t, std::string> stringMap;
148   std::vector<std::pair<spv::Id, std::string>> sources;
149   spv::Op lastOpcode = spv::Op::OpMax;
150   size_t instructionIndex = 0;
151 
152   spvtools::InstructionParser instructionParser =
153       [&stringMap, &sources, &lastOpcode,
154        &instructionIndex](const spv_parsed_instruction_t& instruction) {
155         const spv_position_t loc = {0, 0, instructionIndex + 1};
156         spv_result_t result = SPV_SUCCESS;
157 
158         if (instruction.opcode == spv::Op::OpString) {
159           std::string content;
160           result = extractOpString(loc, instruction, &content);
161           if (result == SPV_SUCCESS) {
162             stringMap.emplace(instruction.result_id, std::move(content));
163           }
164         } else if (instruction.opcode == spv::Op::OpSource) {
165           spv::Id filenameId;
166           std::string code;
167           result = extractOpSource(loc, instruction, &filenameId, &code);
168           if (result == SPV_SUCCESS) {
169             sources.emplace_back(std::make_pair(filenameId, std::move(code)));
170           }
171         } else if (instruction.opcode == spv::Op::OpSourceContinued) {
172           if (lastOpcode != spv::Op::OpSource) {
173             spvtools::Error(spvtools::utils::CLIMessageConsumer, "", loc,
174                             "OpSourceContinued MUST follow an OpSource.");
175             return SPV_ERROR_INVALID_BINARY;
176           }
177 
178           assert(sources.size() > 0);
179           result = extractOpSourceContinued(loc, instruction,
180                                             &sources.back().second);
181         }
182 
183         ++instructionIndex;
184         lastOpcode = static_cast<spv::Op>(instruction.opcode);
185         return result;
186       };
187 
188   if (!context.Parse(binary, headerParser, instructionParser)) {
189     return false;
190   }
191 
192   std::string defaultName = "unnamed-";
193   size_t unnamedCount = 0;
194   for (auto & [ id, code ] : sources) {
195     std::string filename;
196     const auto it = stringMap.find(id);
197     if (it == stringMap.cend() || it->second.empty()) {
198       filename = "unnamed-" + std::to_string(unnamedCount) + ".hlsl";
199       ++unnamedCount;
200     } else {
201       filename = it->second;
202     }
203 
204     if (output->count(filename) != 0) {
205       spvtools::Error(spvtools::utils::CLIMessageConsumer, "", {},
206                       "Source file name conflict.");
207       return false;
208     }
209     output->insert({filename, code});
210   }
211 
212   return true;
213 }
214