1 //===- DXILResource.cpp - DXIL Resource helper objects --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects for working with DXIL Resources.
10 ///
11 //===----------------------------------------------------------------------===//
12
#include "DXILResource.h"
#include "CBufferDataLayout.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
21
22 using namespace llvm;
23 using namespace llvm::dxil;
24 using namespace llvm::hlsl;
25
collect(Module & M)26 template <typename T> void ResourceTable<T>::collect(Module &M) {
27 NamedMDNode *Entry = M.getNamedMetadata(MDName);
28 if (!Entry || Entry->getNumOperands() == 0)
29 return;
30
31 uint32_t Counter = 0;
32 for (auto *Res : Entry->operands()) {
33 Data.push_back(T(Counter++, FrontendResource(cast<MDNode>(Res))));
34 }
35 }
36
collect(Module & M)37 template <> void ResourceTable<ConstantBuffer>::collect(Module &M) {
38 NamedMDNode *Entry = M.getNamedMetadata(MDName);
39 if (!Entry || Entry->getNumOperands() == 0)
40 return;
41
42 uint32_t Counter = 0;
43 for (auto *Res : Entry->operands()) {
44 Data.push_back(
45 ConstantBuffer(Counter++, FrontendResource(cast<MDNode>(Res))));
46 }
47 // FIXME: share CBufferDataLayout with CBuffer load lowering.
48 // See https://github.com/llvm/llvm-project/issues/58381
49 CBufferDataLayout CBDL(M.getDataLayout(), /*IsLegacy*/ true);
50 for (auto &CB : Data)
51 CB.setSize(CBDL);
52 }
53
collect(Module & M)54 void Resources::collect(Module &M) {
55 UAVs.collect(M);
56 CBuffers.collect(M);
57 }
58
ResourceBase(uint32_t I,FrontendResource R)59 ResourceBase::ResourceBase(uint32_t I, FrontendResource R)
60 : ID(I), GV(R.getGlobalVariable()), Name(""), Space(R.getSpace()),
61 LowerBound(R.getResourceIndex()), RangeSize(1) {
62 if (auto *ArrTy = dyn_cast<ArrayType>(GV->getValueType()))
63 RangeSize = ArrTy->getNumElements();
64 }
65
getComponentTypeName(ComponentType CompType)66 StringRef ResourceBase::getComponentTypeName(ComponentType CompType) {
67 switch (CompType) {
68 case ComponentType::LastEntry:
69 case ComponentType::Invalid:
70 return "invalid";
71 case ComponentType::I1:
72 return "i1";
73 case ComponentType::I16:
74 return "i16";
75 case ComponentType::U16:
76 return "u16";
77 case ComponentType::I32:
78 return "i32";
79 case ComponentType::U32:
80 return "u32";
81 case ComponentType::I64:
82 return "i64";
83 case ComponentType::U64:
84 return "u64";
85 case ComponentType::F16:
86 return "f16";
87 case ComponentType::F32:
88 return "f32";
89 case ComponentType::F64:
90 return "f64";
91 case ComponentType::SNormF16:
92 return "snorm_f16";
93 case ComponentType::UNormF16:
94 return "unorm_f16";
95 case ComponentType::SNormF32:
96 return "snorm_f32";
97 case ComponentType::UNormF32:
98 return "unorm_f32";
99 case ComponentType::SNormF64:
100 return "snorm_f64";
101 case ComponentType::UNormF64:
102 return "unorm_f64";
103 case ComponentType::PackedS8x32:
104 return "p32i8";
105 case ComponentType::PackedU8x32:
106 return "p32u8";
107 }
108 }
109
printComponentType(Kinds Kind,ComponentType CompType,unsigned Alignment,raw_ostream & OS)110 void ResourceBase::printComponentType(Kinds Kind, ComponentType CompType,
111 unsigned Alignment, raw_ostream &OS) {
112 switch (Kind) {
113 default:
114 // TODO: add vector size.
115 OS << right_justify(getComponentTypeName(CompType), Alignment);
116 break;
117 case Kinds::RawBuffer:
118 OS << right_justify("byte", Alignment);
119 break;
120 case Kinds::StructuredBuffer:
121 OS << right_justify("struct", Alignment);
122 break;
123 case Kinds::CBuffer:
124 case Kinds::Sampler:
125 OS << right_justify("NA", Alignment);
126 break;
127 case Kinds::Invalid:
128 case Kinds::NumEntries:
129 break;
130 }
131 }
132
getKindName(Kinds Kind)133 StringRef ResourceBase::getKindName(Kinds Kind) {
134 switch (Kind) {
135 case Kinds::NumEntries:
136 case Kinds::Invalid:
137 return "invalid";
138 case Kinds::Texture1D:
139 return "1d";
140 case Kinds::Texture2D:
141 return "2d";
142 case Kinds::Texture2DMS:
143 return "2dMS";
144 case Kinds::Texture3D:
145 return "3d";
146 case Kinds::TextureCube:
147 return "cube";
148 case Kinds::Texture1DArray:
149 return "1darray";
150 case Kinds::Texture2DArray:
151 return "2darray";
152 case Kinds::Texture2DMSArray:
153 return "2darrayMS";
154 case Kinds::TextureCubeArray:
155 return "cubearray";
156 case Kinds::TypedBuffer:
157 return "buf";
158 case Kinds::RawBuffer:
159 return "rawbuf";
160 case Kinds::StructuredBuffer:
161 return "structbuf";
162 case Kinds::CBuffer:
163 return "cbuffer";
164 case Kinds::Sampler:
165 return "sampler";
166 case Kinds::TBuffer:
167 return "tbuffer";
168 case Kinds::RTAccelerationStructure:
169 return "ras";
170 case Kinds::FeedbackTexture2D:
171 return "fbtex2d";
172 case Kinds::FeedbackTexture2DArray:
173 return "fbtex2darray";
174 }
175 }
176
printKind(Kinds Kind,unsigned Alignment,raw_ostream & OS,bool SRV,bool HasCounter,uint32_t SampleCount)177 void ResourceBase::printKind(Kinds Kind, unsigned Alignment, raw_ostream &OS,
178 bool SRV, bool HasCounter, uint32_t SampleCount) {
179 switch (Kind) {
180 default:
181 OS << right_justify(getKindName(Kind), Alignment);
182 break;
183
184 case Kinds::RawBuffer:
185 case Kinds::StructuredBuffer:
186 if (SRV)
187 OS << right_justify("r/o", Alignment);
188 else {
189 if (!HasCounter)
190 OS << right_justify("r/w", Alignment);
191 else
192 OS << right_justify("r/w+cnt", Alignment);
193 }
194 break;
195 case Kinds::TypedBuffer:
196 OS << right_justify("buf", Alignment);
197 break;
198 case Kinds::Texture2DMS:
199 case Kinds::Texture2DMSArray: {
200 std::string DimName = getKindName(Kind).str();
201 if (SampleCount)
202 DimName += std::to_string(SampleCount);
203 OS << right_justify(DimName, Alignment);
204 } break;
205 case Kinds::CBuffer:
206 case Kinds::Sampler:
207 OS << right_justify("NA", Alignment);
208 break;
209 case Kinds::Invalid:
210 case Kinds::NumEntries:
211 break;
212 }
213 }
214
print(raw_ostream & OS,StringRef IDPrefix,StringRef BindingPrefix) const215 void ResourceBase::print(raw_ostream &OS, StringRef IDPrefix,
216 StringRef BindingPrefix) const {
217 std::string ResID = IDPrefix.str();
218 ResID += std::to_string(ID);
219 OS << right_justify(ResID, 8);
220
221 std::string Bind = BindingPrefix.str();
222 Bind += std::to_string(LowerBound);
223 if (Space)
224 Bind += ",space" + std::to_string(Space);
225
226 OS << right_justify(Bind, 15);
227 if (RangeSize != UINT_MAX)
228 OS << right_justify(std::to_string(RangeSize), 6) << "\n";
229 else
230 OS << right_justify("unbounded", 6) << "\n";
231 }
232
UAVResource(uint32_t I,FrontendResource R)233 UAVResource::UAVResource(uint32_t I, FrontendResource R)
234 : ResourceBase(I, R),
235 Shape(static_cast<ResourceBase::Kinds>(R.getResourceKind())),
236 GloballyCoherent(false), HasCounter(false), IsROV(false), ExtProps() {
237 parseSourceType(R.getSourceType());
238 }
239
print(raw_ostream & OS) const240 void UAVResource::print(raw_ostream &OS) const {
241 OS << "; " << left_justify(Name, 31);
242
243 OS << right_justify("UAV", 10);
244
245 printComponentType(
246 Shape, ExtProps.ElementType.value_or(ComponentType::Invalid), 8, OS);
247
248 // FIXME: support SampleCount.
249 // See https://github.com/llvm/llvm-project/issues/58175
250 printKind(Shape, 12, OS, /*SRV*/ false, HasCounter);
251 // Print the binding part.
252 ResourceBase::print(OS, "U", "u");
253 }
254
255 // FIXME: Capture this in HLSL source. I would go do this right now, but I want
256 // to get this in first so that I can make sure to capture all the extra
257 // information we need to remove the source type string from here (See issue:
258 // https://github.com/llvm/llvm-project/issues/57991).
parseSourceType(StringRef S)259 void UAVResource::parseSourceType(StringRef S) {
260 IsROV = S.startswith("RasterizerOrdered");
261 if (IsROV)
262 S = S.substr(strlen("RasterizerOrdered"));
263 if (S.startswith("RW"))
264 S = S.substr(strlen("RW"));
265
266 // Note: I'm deliberately not handling any of the Texture buffer types at the
267 // moment. I want to resolve the issue above before adding Texture or Sampler
268 // support.
269 Shape = StringSwitch<ResourceBase::Kinds>(S)
270 .StartsWith("Buffer<", Kinds::TypedBuffer)
271 .StartsWith("ByteAddressBuffer<", Kinds::RawBuffer)
272 .StartsWith("StructuredBuffer<", Kinds::StructuredBuffer)
273 .Default(Kinds::Invalid);
274 assert(Shape != Kinds::Invalid && "Unsupported buffer type");
275
276 S = S.substr(S.find("<") + 1);
277
278 constexpr size_t PrefixLen = StringRef("vector<").size();
279 if (S.startswith("vector<"))
280 S = S.substr(PrefixLen, S.find(",") - PrefixLen);
281 else
282 S = S.substr(0, S.find(">"));
283
284 ComponentType ElTy = StringSwitch<ResourceBase::ComponentType>(S)
285 .Case("bool", ComponentType::I1)
286 .Case("int16_t", ComponentType::I16)
287 .Case("uint16_t", ComponentType::U16)
288 .Case("int32_t", ComponentType::I32)
289 .Case("uint32_t", ComponentType::U32)
290 .Case("int64_t", ComponentType::I64)
291 .Case("uint64_t", ComponentType::U64)
292 .Case("half", ComponentType::F16)
293 .Case("float", ComponentType::F32)
294 .Case("double", ComponentType::F64)
295 .Default(ComponentType::Invalid);
296 if (ElTy != ComponentType::Invalid)
297 ExtProps.ElementType = ElTy;
298 }
299
ConstantBuffer(uint32_t I,hlsl::FrontendResource R)300 ConstantBuffer::ConstantBuffer(uint32_t I, hlsl::FrontendResource R)
301 : ResourceBase(I, R) {}
302
setSize(CBufferDataLayout & DL)303 void ConstantBuffer::setSize(CBufferDataLayout &DL) {
304 CBufferSizeInBytes = DL.getTypeAllocSizeInBytes(GV->getValueType());
305 }
306
print(raw_ostream & OS) const307 void ConstantBuffer::print(raw_ostream &OS) const {
308 OS << "; " << left_justify(Name, 31);
309
310 OS << right_justify("cbuffer", 10);
311
312 printComponentType(Kinds::CBuffer, ComponentType::Invalid, 8, OS);
313
314 printKind(Kinds::CBuffer, 12, OS, /*SRV*/ false, /*HasCounter*/ false);
315 // Print the binding part.
316 ResourceBase::print(OS, "CB", "cb");
317 }
318
print(raw_ostream & OS) const319 template <typename T> void ResourceTable<T>::print(raw_ostream &OS) const {
320 for (auto &Res : Data)
321 Res.print(OS);
322 }
323
write(LLVMContext & Ctx) const324 MDNode *ResourceBase::ExtendedProperties::write(LLVMContext &Ctx) const {
325 IRBuilder<> B(Ctx);
326 SmallVector<Metadata *> Entries;
327 if (ElementType) {
328 Entries.emplace_back(
329 ConstantAsMetadata::get(B.getInt32(TypedBufferElementType)));
330 Entries.emplace_back(ConstantAsMetadata::get(
331 B.getInt32(static_cast<uint32_t>(*ElementType))));
332 }
333 if (Entries.empty())
334 return nullptr;
335 return MDNode::get(Ctx, Entries);
336 }
337
write(LLVMContext & Ctx,MutableArrayRef<Metadata * > Entries) const338 void ResourceBase::write(LLVMContext &Ctx,
339 MutableArrayRef<Metadata *> Entries) const {
340 IRBuilder<> B(Ctx);
341 Entries[0] = ConstantAsMetadata::get(B.getInt32(ID));
342 Entries[1] = ConstantAsMetadata::get(GV);
343 Entries[2] = MDString::get(Ctx, Name);
344 Entries[3] = ConstantAsMetadata::get(B.getInt32(Space));
345 Entries[4] = ConstantAsMetadata::get(B.getInt32(LowerBound));
346 Entries[5] = ConstantAsMetadata::get(B.getInt32(RangeSize));
347 }
348
write() const349 MDNode *UAVResource::write() const {
350 auto &Ctx = GV->getContext();
351 IRBuilder<> B(Ctx);
352 Metadata *Entries[11];
353 ResourceBase::write(Ctx, Entries);
354 Entries[6] =
355 ConstantAsMetadata::get(B.getInt32(static_cast<uint32_t>(Shape)));
356 Entries[7] = ConstantAsMetadata::get(B.getInt1(GloballyCoherent));
357 Entries[8] = ConstantAsMetadata::get(B.getInt1(HasCounter));
358 Entries[9] = ConstantAsMetadata::get(B.getInt1(IsROV));
359 Entries[10] = ExtProps.write(Ctx);
360 return MDNode::get(Ctx, Entries);
361 }
362
write() const363 MDNode *ConstantBuffer::write() const {
364 auto &Ctx = GV->getContext();
365 IRBuilder<> B(Ctx);
366 Metadata *Entries[7];
367 ResourceBase::write(Ctx, Entries);
368
369 Entries[6] = ConstantAsMetadata::get(B.getInt32(CBufferSizeInBytes));
370 return MDNode::get(Ctx, Entries);
371 }
372
write(Module & M) const373 template <typename T> MDNode *ResourceTable<T>::write(Module &M) const {
374 if (Data.empty())
375 return nullptr;
376 SmallVector<Metadata *> MDs;
377 for (auto &Res : Data)
378 MDs.emplace_back(Res.write());
379
380 NamedMDNode *Entry = M.getNamedMetadata(MDName);
381 if (Entry)
382 Entry->eraseFromParent();
383
384 return MDNode::get(M.getContext(), MDs);
385 }
386
write(Module & M) const387 void Resources::write(Module &M) const {
388 Metadata *ResourceMDs[4] = {nullptr, nullptr, nullptr, nullptr};
389
390 ResourceMDs[1] = UAVs.write(M);
391
392 ResourceMDs[2] = CBuffers.write(M);
393
394 bool HasResource = ResourceMDs[0] != nullptr || ResourceMDs[1] != nullptr ||
395 ResourceMDs[2] != nullptr || ResourceMDs[3] != nullptr;
396
397 if (HasResource) {
398 NamedMDNode *DXResMD = M.getOrInsertNamedMetadata("dx.resources");
399 DXResMD->addOperand(MDNode::get(M.getContext(), ResourceMDs));
400 }
401
402 NamedMDNode *Entry = M.getNamedMetadata("hlsl.uavs");
403 if (Entry)
404 Entry->eraseFromParent();
405 }
406
/// Print the "Resource Bindings" comment table: a header followed by one row
/// per constant buffer and then one row per UAV.
/// NOTE(review): the header text is meant to line up with the column widths
/// used by the per-resource print helpers (31/10/8/12/8/15/6); confirm the
/// spacing when changing either side.
void Resources::print(raw_ostream &O) const {
  O << ";\n"
    << "; Resource Bindings:\n"
    << ";\n"
    << "; Name Type Format Dim "
       "ID HLSL Bind Count\n"
    << "; ------------------------------ ---------- ------- ----------- "
       "------- -------------- ------\n";

  // Rows: cbuffers first, then UAVs.
  CBuffers.print(O);
  UAVs.print(O);
}
419
dump() const420 void Resources::dump() const { print(dbgs()); }
421