// Copyright (c) 2016 The vulkano developers
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
// at your option. All files in the project carrying such
// notice may not be copied, modified, or distributed except
// according to those terms.

//! A program that is run on the device.
//!
//! In Vulkan, shaders are grouped in *shader modules*. Each shader module is built from SPIR-V
//! code and can contain one or more entry points. Note that for the moment the official
//! GLSL-to-SPIR-V compiler does not support multiple entry points.
//!
//! The vulkano library can parse and introspect SPIR-V code, but it does not fully validate the
//! code. You are encouraged to use the `vulkano-shaders` crate, which generates Rust code that
//! wraps vulkano's shader API.
//!
//! # Shader interface
//!
//! Vulkan has specific rules for interfacing shaders with each other, and with other parts
//! of a program.
//!
//! ## Endianness
//!
//! The Vulkan specification requires that a Vulkan implementation has runtime support on the
//! host for the types [`u8`], [`u16`], [`u32`] and [`u64`], their signed counterparts, and
//! [`f32`] and [`f64`], and that the representation and endianness of these types match those on
//! the device. This means that if you have, for example, a `Subbuffer<u32>`, you can be sure that
//! it is represented the same way on the host as on the device, and you don't need to worry
//! about converting the endianness.
//!
//! ## Layout of data
//!
//! When buffers, push constants or other user-provided data are accessed in shaders,
//! the shader expects the values inside to be laid out in a specific way. For every uniform buffer,
//! storage buffer or push constant block, the SPIR-V specification requires the SPIR-V code to
//! provide the `Offset` decoration for every member of a struct, indicating where it is placed
//! relative to the start of the struct. If there are arrays or matrices among the variables, the
//! SPIR-V code must also provide an `ArrayStride` or `MatrixStride` decoration for them,
//! indicating the number of bytes between the starts of consecutive array elements or matrix
//! columns. When providing data to shaders, you must make sure that your data is placed at the
//! locations indicated within the SPIR-V code, or the shader will read the wrong data and produce
//! nonsense.
//!
//! GLSL does not require you to give explicit offsets and/or strides to your variables (although
//! it has the option to provide them if you wish). Instead, the shader compiler automatically
//! assigns every variable an offset, increasing in the order in which you declare them.
//! To know the exact offsets that will be used, so that you can lay out your data appropriately,
//! you must know the alignment rules that the shader compiler uses. The shader compiler will
//! always give a variable the smallest offset that fits the alignment rules and doesn't overlap
//! with the previous variable. The shader compiler uses default alignment rules depending on the
//! type of block, but you can specify another layout by using the `layout` qualifier.
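//!
//! The offset-assignment rule above can be sketched as follows. This is a simplified model for
//! illustration, not vulkano API: the `align_up` and `assign_offsets` helpers are hypothetical,
//! and each member is described only by a `(size, alignment)` pair that you would derive from the
//! alignment rules below.
//!
//! ```rust
//! /// Rounds `offset` up to the next multiple of `alignment` (which must be a power of two).
//! fn align_up(offset: u64, alignment: u64) -> u64 {
//!     assert!(alignment.is_power_of_two());
//!     (offset + alignment - 1) & !(alignment - 1)
//! }
//!
//! /// Gives each member, in declaration order, the smallest offset that satisfies its alignment
//! /// and does not overlap the previous member.
//! fn assign_offsets(members: &[(u64, u64)]) -> Vec<u64> {
//!     let mut next_free = 0;
//!     members
//!         .iter()
//!         .map(|&(size, alignment)| {
//!             let offset = align_up(next_free, alignment);
//!             next_free = offset + size;
//!             offset
//!         })
//!         .collect()
//! }
//! ```
//!
//! For instance, a `float` (size 4, alignment 4) followed by a `vec4` (size 16, alignment 16 in
//! the base and extended alignments) gets the offsets 0 and 16.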
//!
//! ## Alignment rules
//!
//! The offset of each variable from the start of a block, matrix or array must be a
//! multiple of a certain number, which is called its *alignment*. The stride of an array or matrix
//! must likewise be a multiple of this number. An alignment is always a power-of-two value.
//! Regardless of whether the offset/stride is provided manually in the compiled SPIR-V code,
//! or assigned automatically by the shader compiler, all variable offsets/strides in a shader must
//! follow these alignment rules.
//!
//! Three sets of [alignment rules] are supported by Vulkan. Each one has a GLSL qualifier that
//! you can place in front of a block to make the shader compiler use that layout for the block.
//! If you don't provide this qualifier, the shader compiler uses a default alignment.
//!
//! - **Scalar alignment** (GLSL qualifier: `layout(scalar)`, requires the
//!   [`GL_EXT_scalar_block_layout`] GLSL extension). This is the same as the C alignment,
//!   expressed in Rust with the
//!   [`#[repr(C)]`](https://doc.rust-lang.org/nomicon/other-reprs.html#reprc) attribute.
//!   The shader compiler does not use this alignment by default, so you must use the GLSL
//!   qualifier. You must also enable the [`scalar_block_layout`] feature in Vulkan.
//! - **Base alignment**, also known as **std430** (GLSL qualifier: `layout(std430)`).
//!   The shader compiler uses this alignment by default for all shader data except uniform buffers.
//!   If you use the base alignment for a uniform buffer, you must also enable the
//!   [`uniform_buffer_standard_layout`] feature in Vulkan.
//! - **Extended alignment**, also known as **std140** (GLSL qualifier: `layout(std140)`).
//!   The shader compiler uses this alignment by default for uniform buffers.
//!
//! The three rule sets are listed from least to most strict: a layout that follows the extended
//! alignment rules also follows the base and scalar alignment rules, and a layout that follows
//! the base alignment rules also follows the scalar alignment rules.
//!
//! In all three of these alignment rules, a primitive/scalar value with a size of N bytes has an
//! alignment of N, meaning that it must have an offset that is a multiple of its size,
//! like in C or Rust. For example, a `float` (like a Rust `f32`) has a size of 4 bytes
//! and an alignment of 4.
//!
//! The differences between the alignment rules are in how compound types (vectors, matrices,
//! arrays and structs) are expected to be laid out. For a compound type with an element whose
//! alignment is N, the scalar alignment considers the alignment of the compound type to also be N.
//! However, the base and extended alignments are stricter:
//!
//! | GLSL type | Scalar          | Base            | Extended                 |
//! |-----------|-----------------|-----------------|--------------------------|
//! | primitive | N               | N               | N                        |
//! | `vec2`    | N               | N * 2           | N * 2                    |
//! | `vec3`    | N               | N * 4           | N * 4                    |
//! | `vec4`    | N               | N * 4           | N * 4                    |
//! | array     | N               | N               | max(N, 16)               |
//! | `struct`  | N<sub>max</sub> | N<sub>max</sub> | max(N<sub>max</sub>, 16) |
//!
//! In the base and extended alignments, the alignment of a vector is the size of the whole vector,
//! rather than the size of its individual elements as is the case in the scalar alignment.
//! Note, however, that because alignment must be a power of two, the alignment of `vec3` cannot be
//! N * 3; it must be N * 4, the same alignment as `vec4`. This means that it is not possible to
//! tightly pack multiple `vec3` values (e.g. in an array); there will always be empty padding
//! between them.
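//!
//! The vector rows of the table can be expressed as a small function. This is a hypothetical
//! sketch for illustration only (the `Layout` enum and `vector_alignment` function are not part
//! of vulkano); `n` is the alignment of a single component and `count` is the number of
//! components (2, 3 or 4).
//!
//! ```rust
//! #[derive(Clone, Copy, Debug, PartialEq, Eq)]
//! enum Layout {
//!     Scalar,
//!     Base,     // std430
//!     Extended, // std140
//! }
//!
//! fn vector_alignment(n: u64, count: u64, layout: Layout) -> u64 {
//!     assert!((2..=4).contains(&count));
//!     match layout {
//!         // Same alignment as a single component.
//!         Layout::Scalar => n,
//!         // The whole vector, rounded up to a power of two, so `vec3` is aligned like `vec4`.
//!         Layout::Base | Layout::Extended => n * count.next_power_of_two(),
//!     }
//! }
//! ```
//!
//! With 4-byte `float` components, `vec3` has an alignment of 16 in the base and extended
//! alignments, so consecutive `vec3` array elements are 16 bytes apart even though each one
//! only occupies 12 bytes.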
//!
//! In both the scalar and base alignments, the alignment of an array and of its elements is equal
//! to the alignment of the contained type. In the extended alignment, however, the alignment is
//! always at least 16 (the size of a `vec4`). Therefore, the minimum stride of the array can be
//! much greater than the element size. For example, in an array of `float`, the stride must be at
//! least 16, even though a `float` itself is only 4 bytes in size. Every `float` element will be
//! followed by at least 12 bytes of unused space.
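//!
//! A minimal sketch of the smallest valid array stride, assuming it is the element size rounded
//! up to the array's alignment (the `array_stride` helper is hypothetical, not vulkano API):
//!
//! ```rust
//! /// Returns the smallest valid stride for an array whose elements have the given size and
//! /// alignment. Set `extended` to `true` for the extended (std140) alignment.
//! fn array_stride(element_size: u64, element_alignment: u64, extended: bool) -> u64 {
//!     let alignment = if extended {
//!         element_alignment.max(16)
//!     } else {
//!         element_alignment
//!     };
//!     // Round the element size up to a multiple of the (power-of-two) alignment.
//!     (element_size + alignment - 1) / alignment * alignment
//! }
//! ```
//!
//! For `float[]`, this gives a stride of 4 in the scalar and base alignments but 16 in the
//! extended alignment.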
//!
//! A matrix `matCxR` is considered equivalent to an array of column vectors `vecR[C]`.
//! In the base and extended alignments, that means that if the matrix has 3 rows, there will be
//! one element's worth of padding between the column vectors. In the extended alignment,
//! the alignment is also at least 16, further increasing the amount of padding between the
//! column vectors.
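//!
//! The following hypothetical helper (not part of vulkano) sketches this for column-major
//! matrices of 4-byte `float`s in the base and extended alignments, treating `matCxR` as the
//! array `vecR[C]` and returning the column stride and the total size in bytes.
//!
//! ```rust
//! fn float_matrix_layout(columns: u64, rows: u64, extended: bool) -> (u64, u64) {
//!     assert!((2..=4).contains(&columns) && (2..=4).contains(&rows));
//!     // A column vector is aligned to its whole size, rounded up to a power of two...
//!     let mut alignment = 4 * rows.next_power_of_two();
//!     // ...and, as for any array in the extended alignment, to at least 16.
//!     if extended {
//!         alignment = alignment.max(16);
//!     }
//!     // Each column occupies `4 * rows` bytes, padded up to the alignment.
//!     let stride = (4 * rows + alignment - 1) / alignment * alignment;
//!     (stride, stride * columns)
//! }
//! ```
//!
//! For example, `mat3` has a column stride of 16 and a size of 48 bytes in both alignments,
//! while `mat2` has a column stride of 8 in the base alignment but 16 in the extended alignment.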
//!
//! The rules for `struct`s are similar to those of arrays. When the members of the struct have
//! different alignment requirements, the alignment of the struct as a whole is the maximum
//! of the alignments of its members. As with arrays, in the extended alignment, the alignment
//! of a struct is at least 16.
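//!
//! As a hypothetical example (the GLSL block and the `Data` struct are illustrative, not part of
//! vulkano), a uniform block using the extended alignment can be mirrored on the host with a
//! `#[repr(C)]` Rust struct whose padding is written out explicitly. The offsets in the comments
//! are derived from the rules above; `align(16)` reflects that a struct's alignment is at least
//! 16 in the extended alignment.
//!
//! ```rust
//! // GLSL:
//! //
//! // layout(set = 0, binding = 0, std140) uniform Data {
//! //     vec3 position;   // offset 0 (alignment 16)
//! //     float intensity; // offset 12 (alignment 4)
//! //     vec2 uv;         // offset 16 (alignment 8)
//! //     float values[2]; // offset 32 (alignment and stride 16)
//! // };
//! #[repr(C, align(16))]
//! struct Data {
//!     position: [f32; 3],
//!     intensity: f32,
//!     uv: [f32; 2],
//!     // Padding so that `values` starts at offset 32.
//!     _pad0: [f32; 2],
//!     // Each `float` element occupies 16 bytes; only the first component of each is used.
//!     values: [[f32; 4]; 2],
//! }
//! ```
//!
//! With this layout, `std::mem::size_of::<Data>()` is 64, matching the size of the GLSL block.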
//!
//! [alignment rules]: <https://registry.khronos.org/vulkan/specs/1.3-extensions/html/chap15.html#interfaces-resources-layout>
//! [`GL_EXT_scalar_block_layout`]: <https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GL_EXT_scalar_block_layout.txt>
//! [`scalar_block_layout`]: crate::device::Features::scalar_block_layout
//! [`uniform_buffer_standard_layout`]: crate::device::Features::uniform_buffer_standard_layout

use crate::{
    descriptor_set::layout::DescriptorType,
    device::{Device, DeviceOwned},
    format::{Format, NumericType},
    image::view::ImageViewType,
    macros::{impl_id_counter, vulkan_bitflags_enum},
    pipeline::{graphics::input_assembly::PrimitiveTopology, layout::PushConstantRange},
    shader::spirv::{Capability, Spirv, SpirvError},
    sync::PipelineStages,
    DeviceSize, OomError, Version, VulkanError, VulkanObject,
};
use ahash::{HashMap, HashSet};
use std::{
    borrow::Cow,
    collections::hash_map::Entry,
    error::Error,
    ffi::{CStr, CString},
    fmt::{Display, Error as FmtError, Formatter},
    mem,
    mem::MaybeUninit,
    num::NonZeroU64,
    ptr,
    sync::Arc,
};

pub mod reflect;
pub mod spirv;

use spirv::ExecutionModel;

// Generated by build.rs
include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));

/// Contains SPIR-V code with one or more entry points.
#[derive(Debug)]
pub struct ShaderModule {
    handle: ash::vk::ShaderModule,
    device: Arc<Device>,
    id: NonZeroU64,
    entry_points: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>>,
}

impl ShaderModule {
    /// Builds a new shader module from SPIR-V 32-bit words. The shader code is parsed and the
    /// necessary information is extracted from it.
    ///
    /// # Safety
    ///
    /// - The SPIR-V code is not validated beyond the minimum needed to extract the information.
    #[inline]
    pub unsafe fn from_words(
        device: Arc<Device>,
        words: &[u32],
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        let spirv = Spirv::new(words)?;

        Self::from_words_with_data(
            device,
            words,
            spirv.version(),
            reflect::spirv_capabilities(&spirv),
            reflect::spirv_extensions(&spirv),
            reflect::entry_points(&spirv),
        )
    }

    /// As `from_words`, but takes a slice of bytes, which are interpreted as 32-bit words in the
    /// native byte order.
    ///
    /// # Safety
    ///
    /// - The same requirements as for `from_words` apply.
    ///
    /// # Panics
    ///
    /// - Panics if the length of `bytes` is not a multiple of 4.
    #[inline]
    pub unsafe fn from_bytes(
        device: Arc<Device>,
        bytes: &[u8],
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        assert!((bytes.len() % 4) == 0);

        // `bytes` is not guaranteed to be 4-byte aligned, so reinterpreting it in place as a
        // `&[u32]` would be undefined behavior for unaligned input. Copy the words into an
        // aligned buffer instead, keeping the native byte order.
        let words: Vec<u32> = bytes
            .chunks_exact(mem::size_of::<u32>())
            .map(|chunk| u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();

        Self::from_words(device, &words)
    }

    /// As `from_words`, but does not parse the code. Instead, you must provide the needed
    /// information yourself. This can be useful if you have already parsed the code yourself and
    /// want to prevent vulkano from parsing it a second time.
    ///
    /// # Safety
    ///
    /// - The SPIR-V code is not validated at all.
    /// - The provided information must match what the SPIR-V code contains.
    pub unsafe fn from_words_with_data<'a>(
        device: Arc<Device>,
        words: &[u32],
        spirv_version: Version,
        spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
        spirv_extensions: impl IntoIterator<Item = &'a str>,
        entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        if let Err(reason) = check_spirv_version(&device, spirv_version) {
            return Err(ShaderCreationError::SpirvVersionNotSupported {
                version: spirv_version,
                reason,
            });
        }

        for &capability in spirv_capabilities {
            if let Err(reason) = check_spirv_capability(&device, capability) {
                return Err(ShaderCreationError::SpirvCapabilityNotSupported {
                    capability,
                    reason,
                });
            }
        }

        for extension in spirv_extensions {
            if let Err(reason) = check_spirv_extension(&device, extension) {
                return Err(ShaderCreationError::SpirvExtensionNotSupported {
                    extension: extension.to_owned(),
                    reason,
                });
            }
        }

        let handle = {
            let infos = ash::vk::ShaderModuleCreateInfo {
                flags: ash::vk::ShaderModuleCreateFlags::empty(),
                code_size: words.len() * mem::size_of::<u32>(),
                p_code: words.as_ptr(),
                ..Default::default()
            };

            let fns = device.fns();
            let mut output = MaybeUninit::uninit();
            (fns.v1_0.create_shader_module)(
                device.handle(),
                &infos,
                ptr::null(),
                output.as_mut_ptr(),
            )
            .result()
            .map_err(VulkanError::from)?;
            output.assume_init()
        };

        // Group the entry points first by name, then by execution model.
        let mut entry_points_by_name: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>> =
            HashMap::default();

        for (name, execution_model, info) in entry_points {
            entry_points_by_name
                .entry(name)
                .or_default()
                .insert(execution_model, info);
        }

        Ok(Arc::new(ShaderModule {
            handle,
            device,
            id: Self::next_id(),
            entry_points: entry_points_by_name,
        }))
    }

    /// As `from_words_with_data`, but takes a slice of bytes, which are interpreted as 32-bit
    /// words in the native byte order.
    ///
    /// # Safety
    ///
    /// - The same requirements as for `from_words_with_data` apply.
    ///
    /// # Panics
    ///
    /// - Panics if the length of `bytes` is not a multiple of 4.
    pub unsafe fn from_bytes_with_data<'a>(
        device: Arc<Device>,
        bytes: &[u8],
        spirv_version: Version,
        spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
        spirv_extensions: impl IntoIterator<Item = &'a str>,
        entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
    ) -> Result<Arc<ShaderModule>, ShaderCreationError> {
        assert!((bytes.len() % 4) == 0);

        // Copy into an aligned buffer; reinterpreting `bytes` in place as `&[u32]` would be
        // undefined behavior if `bytes` is not 4-byte aligned.
        let words: Vec<u32> = bytes
            .chunks_exact(mem::size_of::<u32>())
            .map(|chunk| u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();

        Self::from_words_with_data(
            device,
            &words,
            spirv_version,
            spirv_capabilities,
            spirv_extensions,
            entry_points,
        )
    }

    /// Returns information about the entry point with the provided name. Returns `None` if no entry
    /// point with that name exists in the shader module, or if multiple entry points with that
    /// name exist (with different execution models).
    #[inline]
    pub fn entry_point<'a>(&'a self, name: &str) -> Option<EntryPoint<'a>> {
        self.entry_points.get(name).and_then(|infos| {
            if infos.len() == 1 {
                infos.iter().next().map(|(_, info)| EntryPoint {
                    module: self,
                    name: CString::new(name).unwrap(),
                    info,
                })
            } else {
                None
            }
        })
    }

    /// Returns information about the entry point with the provided name and execution model.
    /// Returns `None` if no entry point with that name and execution model exists in the shader
    /// module.
    #[inline]
    pub fn entry_point_with_execution<'a>(
        &'a self,
        name: &str,
        execution: ExecutionModel,
    ) -> Option<EntryPoint<'a>> {
        self.entry_points.get(name).and_then(|infos| {
            infos.get(&execution).map(|info| EntryPoint {
                module: self,
                name: CString::new(name).unwrap(),
                info,
            })
        })
    }
}

impl Drop for ShaderModule {
    #[inline]
    fn drop(&mut self) {
        unsafe {
            let fns = self.device.fns();
            (fns.v1_0.destroy_shader_module)(self.device.handle(), self.handle, ptr::null());
        }
    }
}

unsafe impl VulkanObject for ShaderModule {
    type Handle = ash::vk::ShaderModule;

    #[inline]
    fn handle(&self) -> Self::Handle {
        self.handle
    }
}

unsafe impl DeviceOwned for ShaderModule {
    #[inline]
    fn device(&self) -> &Arc<Device> {
        &self.device
    }
}

impl_id_counter!(ShaderModule);

/// Error that can happen when creating a new shader module.
#[derive(Clone, Debug)]
pub enum ShaderCreationError {
    OomError(OomError),
    SpirvCapabilityNotSupported {
        capability: Capability,
        reason: ShaderSupportError,
    },
    SpirvError(SpirvError),
    SpirvExtensionNotSupported {
        extension: String,
        reason: ShaderSupportError,
    },
    SpirvVersionNotSupported {
        version: Version,
        reason: ShaderSupportError,
    },
}

impl Error for ShaderCreationError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::OomError(err) => Some(err),
            Self::SpirvCapabilityNotSupported { reason, .. } => Some(reason),
            Self::SpirvError(err) => Some(err),
            Self::SpirvExtensionNotSupported { reason, .. } => Some(reason),
            Self::SpirvVersionNotSupported { reason, .. } => Some(reason),
        }
    }
}

impl Display for ShaderCreationError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        match self {
            Self::OomError(_) => write!(f, "not enough memory available"),
            Self::SpirvCapabilityNotSupported { capability, .. } => write!(
                f,
                "the SPIR-V capability {:?} enabled by the shader is not supported by the device",
                capability,
            ),
            Self::SpirvError(_) => write!(f, "the SPIR-V module could not be read"),
            Self::SpirvExtensionNotSupported { extension, .. } => write!(
                f,
                "the SPIR-V extension {} enabled by the shader is not supported by the device",
                extension,
            ),
            Self::SpirvVersionNotSupported { version, .. } => write!(
                f,
                "the shader uses SPIR-V version {}.{}, which is not supported by the device",
                version.major, version.minor,
            ),
        }
    }
}

impl From<VulkanError> for ShaderCreationError {
    fn from(err: VulkanError) -> Self {
        Self::OomError(err.into())
    }
}

impl From<SpirvError> for ShaderCreationError {
    fn from(err: SpirvError) -> Self {
        Self::SpirvError(err)
    }
}

/// Error that can happen when checking whether a shader is supported by a device.
#[derive(Clone, Copy, Debug)]
pub enum ShaderSupportError {
    NotSupportedByVulkan,
    RequirementsNotMet(&'static [&'static str]),
}

impl Error for ShaderSupportError {}

impl Display for ShaderSupportError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        match self {
            Self::NotSupportedByVulkan => write!(f, "not supported by Vulkan"),
            Self::RequirementsNotMet(requirements) => write!(
                f,
                "at least one of the following must be available/enabled on the device: {}",
                requirements.join(", "),
            ),
        }
    }
}

/// The information associated with a single entry point in a shader.
#[derive(Clone, Debug)]
pub struct EntryPointInfo {
    pub execution: ShaderExecution,
    pub descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>,
    pub push_constant_requirements: Option<PushConstantRange>,
    pub specialization_constant_requirements: HashMap<u32, SpecializationConstantRequirements>,
    pub input_interface: ShaderInterface,
    pub output_interface: ShaderInterface,
}

/// Represents a shader entry point in a shader module.
///
/// Can be obtained by calling [`entry_point`](ShaderModule::entry_point) or
/// [`entry_point_with_execution`](ShaderModule::entry_point_with_execution) on the shader module.
#[derive(Clone, Debug)]
pub struct EntryPoint<'a> {
    module: &'a ShaderModule,
    name: CString,
    info: &'a EntryPointInfo,
}

impl<'a> EntryPoint<'a> {
    /// Returns the module this entry point comes from.
    #[inline]
    pub fn module(&self) -> &'a ShaderModule {
        self.module
    }

    /// Returns the name of the entry point.
    #[inline]
    pub fn name(&self) -> &CStr {
        &self.name
    }

    /// Returns the execution model of the shader.
    #[inline]
    pub fn execution(&self) -> &ShaderExecution {
        &self.info.execution
    }

    /// Returns the descriptor binding requirements.
    #[inline]
    pub fn descriptor_binding_requirements(
        &self,
    ) -> impl ExactSizeIterator<Item = ((u32, u32), &DescriptorBindingRequirements)> {
        self.info
            .descriptor_binding_requirements
            .iter()
            .map(|(k, v)| (*k, v))
    }

    /// Returns the push constant requirements.
    #[inline]
    pub fn push_constant_requirements(&self) -> Option<&PushConstantRange> {
        self.info.push_constant_requirements.as_ref()
    }

    /// Returns the specialization constant requirements.
    #[inline]
    pub fn specialization_constant_requirements(
        &self,
    ) -> impl ExactSizeIterator<Item = (u32, &SpecializationConstantRequirements)> {
        self.info
            .specialization_constant_requirements
            .iter()
            .map(|(k, v)| (*k, v))
    }

    /// Returns the input attributes used by the shader stage.
    #[inline]
    pub fn input_interface(&self) -> &ShaderInterface {
        &self.info.input_interface
    }

    /// Returns the output attributes used by the shader stage.
    #[inline]
    pub fn output_interface(&self) -> &ShaderInterface {
        &self.info.output_interface
    }
}

/// The mode in which a shader executes. This includes both information about the shader type/stage
/// and additional data relevant to particular shader types.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderExecution {
    Vertex,
    TessellationControl,
    TessellationEvaluation,
    Geometry(GeometryShaderExecution),
    Fragment(FragmentShaderExecution),
    Compute,
    RayGeneration,
    AnyHit,
    ClosestHit,
    Miss,
    Intersection,
    Callable,
    Task,
    Mesh,
    SubpassShading,
}

/*#[derive(Clone, Copy, Debug)]
pub struct TessellationShaderExecution {
    pub num_output_vertices: u32,
    pub point_mode: bool,
    pub subdivision: TessellationShaderSubdivision,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TessellationShaderSubdivision {
    Triangles,
    Quads,
    Isolines,
}*/

/// The mode in which a geometry shader executes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GeometryShaderExecution {
    pub input: GeometryShaderInput,
    /*pub max_output_vertices: u32,
    pub num_invocations: u32,
    pub output: GeometryShaderOutput,*/
}

/// The input primitive type that is expected by a geometry shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderInput {
    Points,
    Lines,
    LinesWithAdjacency,
    Triangles,
    TrianglesWithAdjacency,
}

impl GeometryShaderInput {
    /// Returns `true` if the given primitive topology can be used as input for this geometry
    /// shader.
    #[inline]
    pub fn is_compatible_with(self, topology: PrimitiveTopology) -> bool {
        match self {
            Self::Points => matches!(topology, PrimitiveTopology::PointList),
            Self::Lines => matches!(
                topology,
                PrimitiveTopology::LineList | PrimitiveTopology::LineStrip
            ),
            Self::LinesWithAdjacency => matches!(
                topology,
                PrimitiveTopology::LineListWithAdjacency
                    | PrimitiveTopology::LineStripWithAdjacency
            ),
            Self::Triangles => matches!(
                topology,
                PrimitiveTopology::TriangleList
                    | PrimitiveTopology::TriangleStrip
                    | PrimitiveTopology::TriangleFan,
            ),
            Self::TrianglesWithAdjacency => matches!(
                topology,
                PrimitiveTopology::TriangleListWithAdjacency
                    | PrimitiveTopology::TriangleStripWithAdjacency,
            ),
        }
    }
}

/*#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GeometryShaderOutput {
    Points,
    LineStrip,
    TriangleStrip,
}*/

/// The mode in which a fragment shader executes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FragmentShaderExecution {
    pub fragment_tests_stages: FragmentTestsStages,
}

/// The fragment tests stages that will be executed in a fragment shader.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FragmentTestsStages {
    Early,
    Late,
    EarlyAndLate,
}

/// The requirements imposed by a shader on a binding within a descriptor set layout, and on any
/// resource that is bound to that binding.
#[derive(Clone, Debug, Default)]
pub struct DescriptorBindingRequirements {
    /// The descriptor types that are allowed.
    pub descriptor_types: Vec<DescriptorType>,

    /// The number of descriptors (array elements) that the shader requires. The descriptor set
    /// layout can declare more than this, but never less.
    ///
    /// `None` means that the shader declares this as a runtime-sized array, and could potentially
    /// access every array element provided in the descriptor set.
    pub descriptor_count: Option<u32>,

    /// The image format that is required for image views bound to this binding. If this is
    /// `None`, then any image format is allowed.
    pub image_format: Option<Format>,

    /// Whether image views bound to this binding must have multisampling enabled or disabled.
    pub image_multisampled: bool,

    /// The base scalar type required for the format of image views bound to this binding.
    /// This is `None` for non-image bindings.
    pub image_scalar_type: Option<ShaderScalarType>,

    /// The view type that is required for image views bound to this binding.
    /// This is `None` for non-image bindings.
    pub image_view_type: Option<ImageViewType>,

    /// The shader stages that the binding must be declared for.
    pub stages: ShaderStages,

    /// The requirements for individual descriptors within a binding.
    ///
    /// Keys with `Some` hold requirements for a specific descriptor index, if it is statically
    /// known in the shader (a constant). The key `None` holds requirements for indices that are
    /// not statically known, but determined only at runtime (calculated from an input variable).
    pub descriptors: HashMap<Option<u32>, DescriptorRequirements>,
}

/// The requirements imposed by a shader on resources bound to a descriptor.
#[derive(Clone, Debug, Default)]
pub struct DescriptorRequirements {
    /// For buffers and images, which shader stages perform read operations.
    pub memory_read: ShaderStages,

    /// For buffers and images, which shader stages perform write operations.
    pub memory_write: ShaderStages,

    /// For sampler bindings, whether the shader performs depth comparison operations.
    pub sampler_compare: bool,

    /// For sampler bindings, whether the shader performs sampling operations that are not
    /// permitted with unnormalized coordinates. This includes sampling with `ImplicitLod`,
    /// `Dref` or `Proj` SPIR-V instructions or with an LOD bias or offset.
    pub sampler_no_unnormalized_coordinates: bool,

    /// For sampler bindings, whether the shader performs sampling operations that are not
    /// permitted with a sampler YCbCr conversion. This includes sampling with `Gather` SPIR-V
    /// instructions or with an offset.
    pub sampler_no_ycbcr_conversion: bool,

    /// For sampler bindings, the sampled image descriptors that are used in combination with this
    /// sampler.
    pub sampler_with_images: HashSet<DescriptorIdentifier>,

    /// For storage image bindings, whether the shader performs atomic operations.
    pub storage_image_atomic: bool,
}

/// Identifies an individual descriptor by its descriptor set number, binding number and array
/// index.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DescriptorIdentifier {
    pub set: u32,
    pub binding: u32,
    pub index: u32,
}

impl DescriptorBindingRequirements {
    /// Merges `other` into `self`, so that `self` satisfies the requirements of both.
    /// An error is returned if the requirements conflict.
    #[inline]
    pub fn merge(&mut self, other: &Self) -> Result<(), DescriptorBindingRequirementsIncompatible> {
        let Self {
            descriptor_types,
            descriptor_count,
            image_format,
            image_multisampled,
            image_scalar_type,
            image_view_type,
            stages,
            descriptors,
        } = self;

        /* Checks */

        if !descriptor_types
            .iter()
            .any(|ty| other.descriptor_types.contains(ty))
        {
            return Err(DescriptorBindingRequirementsIncompatible::DescriptorType);
        }

        if let (Some(first), Some(second)) = (*image_format, other.image_format) {
            if first != second {
                return Err(DescriptorBindingRequirementsIncompatible::ImageFormat);
            }
        }

        if let (Some(first), Some(second)) = (*image_scalar_type, other.image_scalar_type) {
            if first != second {
                return Err(DescriptorBindingRequirementsIncompatible::ImageScalarType);
            }
        }

        if let (Some(first), Some(second)) = (*image_view_type, other.image_view_type) {
            if first != second {
                return Err(DescriptorBindingRequirementsIncompatible::ImageViewType);
            }
        }

        if *image_multisampled != other.image_multisampled {
            return Err(DescriptorBindingRequirementsIncompatible::ImageMultisampled);
        }

        /* Merge */

        descriptor_types.retain(|ty| other.descriptor_types.contains(ty));

        *descriptor_count = (*descriptor_count).max(other.descriptor_count);
        *image_format = image_format.or(other.image_format);
        *image_scalar_type = image_scalar_type.or(other.image_scalar_type);
        *image_view_type = image_view_type.or(other.image_view_type);
        *stages |= other.stages;

        for (&index, other) in &other.descriptors {
            match descriptors.entry(index) {
                Entry::Vacant(entry) => {
                    entry.insert(other.clone());
                }
                Entry::Occupied(entry) => {
                    entry.into_mut().merge(other);
                }
            }
        }

        Ok(())
    }
}
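
/*
Illustrative sketch, not part of vulkano. It applies the merge rules of
`DescriptorBindingRequirements::merge` to a cut-down, hypothetical `Reqs` type, with descriptor
types and formats modelled as plain `u8` codes. Descriptor types are intersected. `Option` fields
conflict only when both sides are `Some` and differ. The descriptor count relies on `Option`'s
ordering, in which `None` sorts below every `Some`, so a statically known count takes precedence
over a runtime-sized array.

```rust
// Hypothetical, simplified stand-in for `DescriptorBindingRequirements`.
// Descriptor types and image formats are modelled as plain `u8` codes.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Reqs {
    pub descriptor_types: Vec<u8>,
    pub descriptor_count: Option<u32>,
    pub image_format: Option<u8>,
    pub image_multisampled: bool,
}

#[derive(Debug, PartialEq)]
pub enum Incompatible {
    DescriptorType,
    ImageFormat,
    ImageMultisampled,
}

impl Reqs {
    pub fn merge(&mut self, other: &Self) -> Result<(), Incompatible> {
        // Run every check first, so that `self` is untouched if one fails.
        if !self
            .descriptor_types
            .iter()
            .any(|ty| other.descriptor_types.contains(ty))
        {
            return Err(Incompatible::DescriptorType);
        }
        if let (Some(a), Some(b)) = (self.image_format, other.image_format) {
            if a != b {
                return Err(Incompatible::ImageFormat);
            }
        }
        if self.image_multisampled != other.image_multisampled {
            return Err(Incompatible::ImageMultisampled);
        }

        // Keep only the descriptor types that both sides allow.
        self.descriptor_types
            .retain(|ty| other.descriptor_types.contains(ty));
        // `Option` orders `None` below every `Some`, so a known count wins.
        self.descriptor_count = self.descriptor_count.max(other.descriptor_count);
        // Take whichever side specifies a format.
        self.image_format = self.image_format.or(other.image_format);
        Ok(())
    }
}
```

As in the original, all checks run before any field is modified, so a failed merge leaves `self`
unchanged.
*/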

impl DescriptorRequirements {
    /// Merges `other` into `self`, so that `self` satisfies the requirements of both.
    #[inline]
    pub fn merge(&mut self, other: &Self) {
        let Self {
            memory_read,
            memory_write,
            sampler_compare,
            sampler_no_unnormalized_coordinates,
            sampler_no_ycbcr_conversion,
            sampler_with_images,
            storage_image_atomic,
        } = self;

        *memory_read |= other.memory_read;
        *memory_write |= other.memory_write;
        *sampler_compare |= other.sampler_compare;
        *sampler_no_unnormalized_coordinates |= other.sampler_no_unnormalized_coordinates;
        *sampler_no_ycbcr_conversion |= other.sampler_no_ycbcr_conversion;
        sampler_with_images.extend(&other.sampler_with_images);
        *storage_image_atomic |= other.storage_image_atomic;
    }
}
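
/*
A minimal sketch of the union semantics of `DescriptorRequirements::merge`, using a hypothetical
`Access` type. Stage masks are combined with bitwise OR, flags with `|=`, and sets are extended.
Merging therefore cannot fail, and the result is at least as demanding as either input. Stage
masks are plain `u32` bits here rather than vulkano's `ShaderStages`, and descriptors are
identified by a `(set, binding, index)` tuple instead of `DescriptorIdentifier`.

```rust
use std::collections::HashSet;

// Hypothetical stand-in for `DescriptorRequirements`; stage masks are raw `u32` bits.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Access {
    pub memory_read: u32,
    pub memory_write: u32,
    pub sampler_compare: bool,
    pub sampler_with_images: HashSet<(u32, u32, u32)>, // (set, binding, index)
}

impl Access {
    // Union of both requirement sets; unlike the binding-level merge, this never conflicts.
    pub fn merge(&mut self, other: &Self) {
        self.memory_read |= other.memory_read;
        self.memory_write |= other.memory_write;
        self.sampler_compare |= other.sampler_compare;
        // `HashSet<T>` implements `Extend<&T>` for `T: Copy`, so references can be passed.
        self.sampler_with_images.extend(&other.sampler_with_images);
    }
}
```
*/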

/// An error that can be returned when trying to create the intersection of two
/// `DescriptorBindingRequirements` values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DescriptorBindingRequirementsIncompatible {
    /// The allowed descriptor types of the descriptors do not overlap.
    DescriptorType,
    /// The descriptors require different formats.
    ImageFormat,
    /// The descriptors require different scalar types.
    ImageScalarType,
    /// The multisampling requirements of the descriptors differ.
    ImageMultisampled,
    /// The descriptors require different image view types.
    ImageViewType,
}

impl Error for DescriptorBindingRequirementsIncompatible {}

impl Display for DescriptorBindingRequirementsIncompatible {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        match self {
            DescriptorBindingRequirementsIncompatible::DescriptorType => write!(
                f,
                "the allowed descriptor types of the two descriptors do not overlap",
            ),
            DescriptorBindingRequirementsIncompatible::ImageFormat => {
                write!(f, "the descriptors require different formats",)
            }
            DescriptorBindingRequirementsIncompatible::ImageMultisampled => write!(
                f,
                "the multisampling requirements of the descriptors differ",
            ),
            DescriptorBindingRequirementsIncompatible::ImageScalarType => {
                write!(f, "the descriptors require different scalar types",)
            }
            DescriptorBindingRequirementsIncompatible::ImageViewType => {
                write!(f, "the descriptors require different image view types",)
            }
        }
    }
}

/// The requirements imposed by a shader on a specialization constant.
#[derive(Clone, Copy, Debug)]
pub struct SpecializationConstantRequirements {
    /// The size of the constant in bytes.
    pub size: DeviceSize,
}

/// Trait for types that contain specialization data for shaders.
///
/// Shader modules can contain what are called *specialization constants*. They behave like
/// constants, except that their values can be set when you create a compute pipeline or a
/// graphics pipeline. This is done by passing a type that implements the
/// `SpecializationConstants` trait and stores the values in question. The `descriptors()`
/// method of this trait describes where each value is located within that type.
///
/// Boolean specialization constants must be stored as 32-bit integers, where `0` means `false`
/// and any non-zero value means `true`. Integer and floating-point specialization constants are
/// stored as their Rust equivalents.
///
/// This trait is implemented on `()` for shaders that don't have any specialization constants.
///
/// # Examples
///
/// ```rust
/// use vulkano::shader::SpecializationConstants;
/// use vulkano::shader::SpecializationMapEntry;
///
/// #[repr(C)]      // `#[repr(C)]` guarantees that the struct has a specific layout
/// struct MySpecConstants {
///     my_integer_constant: i32,
///     a_boolean: u32, // booleans are stored as 32-bit integers
///     floating_point: f32,
/// }
///
/// unsafe impl SpecializationConstants for MySpecConstants {
///     fn descriptors() -> &'static [SpecializationMapEntry] {
///         static DESCRIPTORS: [SpecializationMapEntry; 3] = [
///             SpecializationMapEntry {
///                 constant_id: 0,
///                 offset: 0,
///                 size: 4,
///             },
///             SpecializationMapEntry {
///                 constant_id: 1,
///                 offset: 4,
///                 size: 4,
///             },
///             SpecializationMapEntry {
///                 constant_id: 2,
///                 offset: 8,
///                 size: 4,
///             },
///         ];
///
///         &DESCRIPTORS
///     }
/// }
/// ```
///
/// # Safety
///
/// - Each returned `SpecializationMapEntry` must have an offset and size that lie within the
///   implementing type.
/// - The size of each `SpecializationMapEntry` must match the size of the corresponding constant
///   (`4` for booleans).
pub unsafe trait SpecializationConstants {
    /// Returns the map entries that describe where each constant is stored within `Self`.
    fn descriptors() -> &'static [SpecializationMapEntry];
}

unsafe impl SpecializationConstants for () {
    #[inline]
    fn descriptors() -> &'static [SpecializationMapEntry] {
        &[]
    }
}
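
/*
The safety requirements above can be checked mechanically. This sketch uses a hypothetical
`MapEntry` type that mirrors the fields of `SpecializationMapEntry`. A hypothetical
`entries_fit` helper checks that each entry's byte range lies within `size_of::<T>()`. It also
checks that no `constant_id` is repeated, which is a Vulkan requirement on
`VkSpecializationInfo`. The example reuses the layout of `MySpecConstants` from the trait
documentation. None of these helpers are part of vulkano's API.

```rust
use std::collections::HashSet;
use std::mem::size_of;

// Hypothetical mirror of `SpecializationMapEntry`.
#[derive(Clone, Copy, Debug)]
pub struct MapEntry {
    pub constant_id: u32,
    pub offset: u32,
    pub size: usize,
}

// Same layout as `MySpecConstants` in the trait documentation: three 4-byte fields.
#[repr(C)]
pub struct MySpecConstants {
    pub my_integer_constant: i32,
    pub a_boolean: u32, // 0 = false, any non-zero value = true
    pub floating_point: f32,
}

pub const ENTRIES: [MapEntry; 3] = [
    MapEntry { constant_id: 0, offset: 0, size: 4 },
    MapEntry { constant_id: 1, offset: 4, size: 4 },
    MapEntry { constant_id: 2, offset: 8, size: 4 },
];

// Returns true if every entry lies inside `T` and no constant id appears twice.
pub fn entries_fit<T>(entries: &[MapEntry]) -> bool {
    let mut seen_ids = HashSet::new();
    entries.iter().all(|e| {
        // `checked_add` guards against overflow for absurd sizes.
        let in_bounds = (e.offset as usize)
            .checked_add(e.size)
            .map_or(false, |end| end <= size_of::<T>());
        // `insert` returns false if the id was already present.
        in_bounds && seen_ids.insert(e.constant_id)
    })
}
```
*/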

/// Describes an individual constant to set in the shader, which corresponds to a field of the
/// struct that holds the specialization data.
// Implementation note: has the same memory representation as a `VkSpecializationMapEntry`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct SpecializationMapEntry {
    /// Identifier of the constant in the shader that corresponds to this field.
    ///
    /// For SPIR-V, this must be the value of the `SpecId` decoration applied to the specialization
    /// constant.
    /// For GLSL, this must be the value of `N` in the `layout(constant_id = N)` attribute applied
    /// to a constant.
    pub constant_id: u32,

    /// Offset within the struct where the data can be found.
    pub offset: u32,

    /// Size of the data in bytes. Must match the size of the constant (`4` for booleans).
    pub size: usize,
}

impl From<SpecializationMapEntry> for ash::vk::SpecializationMapEntry {
    #[inline]
    fn from(val: SpecializationMapEntry) -> Self {
        Self {
            constant_id: val.constant_id,
            offset: val.offset,
            size: val.size,
        }
    }
}

/// Type that contains the definition of an interface between two shader stages, or between
/// the outside and a shader stage.
#[derive(Clone, Debug)]
pub struct ShaderInterface {
    elements: Vec<ShaderInterfaceEntry>,
}

impl ShaderInterface {
    /// Constructs a new `ShaderInterface`.
    ///
    /// # Safety
    ///
    /// - Must only provide one entry per location.
    /// - The format of each element must not be larger than 128 bits.
    // TODO: 4x64 bit formats are possible, but they require special handling.
    // TODO: could this be made safe?
    #[inline]
    pub unsafe fn new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface {
        ShaderInterface { elements }
    }

    /// Creates a description of an empty shader interface.
    #[inline]
    pub const fn empty() -> ShaderInterface {
        ShaderInterface {
            elements: Vec::new(),
        }
    }

    /// Returns a slice containing the elements of the interface.
    #[inline]
    pub fn elements(&self) -> &[ShaderInterfaceEntry] {
        self.elements.as_ref()
    }

    /// Checks whether the interface is potentially compatible with another one.
    ///
    /// Returns `Ok` if the two interfaces are compatible.
    #[inline]
    pub fn matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError> {
        if self.elements().len() != other.elements().len() {
            return Err(ShaderInterfaceMismatchError::ElementsCountMismatch {
                self_elements: self.elements().len() as u32,
                other_elements: other.elements().len() as u32,
            });
        }

        for a in self.elements() {
            let location_range = a.location..a.location + a.ty.num_locations();
            for loc in location_range {
                let b = match other
                    .elements()
                    .iter()
                    .find(|e| loc >= e.location && loc < e.location + e.ty.num_locations())
                {
                    None => {
                        return Err(ShaderInterfaceMismatchError::MissingElement { location: loc })
                    }
                    Some(b) => b,
                };

                if a.ty != b.ty {
                    return Err(ShaderInterfaceMismatchError::TypeMismatch {
                        location: loc,
                        self_ty: a.ty,
                        other_ty: b.ty,
                    });
                }

                // TODO: enforce this?
                /*match (a.name, b.name) {
                    (Some(ref an), Some(ref bn)) => if an != bn { return false },
                    _ => ()
                };*/
            }
        }

        // Note: since we check that the number of elements is the same, we don't need to iterate
        // over b's elements.

        Ok(())
    }
}
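
/*
A simplified, self-contained model of the matching algorithm in `ShaderInterface::matches`.
After the element counts are compared, every location covered by an element of the first
interface is looked up in the second, and the types at that location must be equal. `Entry` and
`Mismatch` are hypothetical stand-ins. Types are reduced to a `(base, components)` pair, and each
element is assumed to occupy `num_elements` consecutive locations, which is what `num_locations`
returns for 32-bit types.

```rust
// Hypothetical, reduced model of `ShaderInterfaceEntry`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Entry {
    pub location: u32,
    pub ty: (char, u32),   // (base type, component count), e.g. ('f', 4) for a vec4
    pub num_elements: u32, // locations consumed (32-bit types only)
}

#[derive(Debug, PartialEq)]
pub enum Mismatch {
    Count { self_elements: usize, other_elements: usize },
    Missing { location: u32 },
    Type { location: u32 },
}

pub fn matches(a: &[Entry], b: &[Entry]) -> Result<(), Mismatch> {
    if a.len() != b.len() {
        return Err(Mismatch::Count { self_elements: a.len(), other_elements: b.len() });
    }
    for ea in a {
        for loc in ea.location..ea.location + ea.num_elements {
            // Find the element of `b` whose location range covers `loc`.
            let eb = b
                .iter()
                .find(|e| loc >= e.location && loc < e.location + e.num_elements)
                .ok_or(Mismatch::Missing { location: loc })?;
            if ea.ty != eb.ty {
                return Err(Mismatch::Type { location: loc });
            }
        }
    }
    Ok(())
}
```

Element order does not matter, because lookups are by location rather than by position.
*/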

/// Entry of a shader interface definition.
#[derive(Debug, Clone)]
pub struct ShaderInterfaceEntry {
    /// The location slot that the variable starts at.
    pub location: u32,

    /// The component slot that the variable starts at. Must be in the range 0..=3.
    pub component: u32,

    /// Name of the element, or `None` if the name is unknown.
    pub name: Option<Cow<'static, str>>,

    /// The type of the variable.
    pub ty: ShaderInterfaceEntryType,
}

/// The type of a variable in a shader interface.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShaderInterfaceEntryType {
    /// The base numeric type.
    pub base_type: ShaderScalarType,

    /// The number of vector components. Must be in the range 1..=4.
    pub num_components: u32,

    /// The number of array elements or matrix columns.
    pub num_elements: u32,

    /// Whether the base type is 64 bits wide. If true, each item of the base type takes up two
    /// component slots instead of one.
    pub is_64bit: bool,
}

impl ShaderInterfaceEntryType {
    pub(crate) fn num_locations(&self) -> u32 {
        assert!(!self.is_64bit); // TODO: implement
        self.num_elements
    }
}
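
/*
The `TODO` in `num_locations` concerns 64-bit types. Under the Vulkan specification's
location-assignment rules, one location holds four 32-bit components, so 64-bit three- and
four-component vectors consume two consecutive locations. Arrays and matrices consume that many
locations per element or column. The sketch below shows one possible way to account for this. The
free function `num_locations` and its parameters are hypothetical, and this is not vulkano's
implementation.

```rust
// Hypothetical helper: locations consumed by an interface variable made of `num_elements`
// array elements (or matrix columns) of a vector with `num_components` components.
pub fn num_locations(num_components: u32, num_elements: u32, is_64bit: bool) -> u32 {
    assert!((1..=4).contains(&num_components));
    // Count the variable's size in 32-bit components; a 64-bit component takes two.
    let components_32 = num_components * if is_64bit { 2 } else { 1 };
    // A location holds four 32-bit components: round up to whole locations.
    let per_element = (components_32 + 3) / 4;
    per_element * num_elements
}
```

With this rule, a `dmat4` (four columns of 64-bit four-component vectors) would consume
`num_locations(4, 4, true) == 8` locations.
*/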

/// The numeric base type of a shader variable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ShaderScalarType {
    Float,
    Sint,
    Uint,
}

// https://registry.khronos.org/vulkan/specs/1.3-extensions/html/chap43.html#formats-numericformat
impl From<NumericType> for ShaderScalarType {
    #[inline]
    fn from(val: NumericType) -> Self {
        match val {
            NumericType::SFLOAT => Self::Float,
            NumericType::UFLOAT => Self::Float,
            NumericType::SINT => Self::Sint,
            NumericType::UINT => Self::Uint,
            NumericType::SNORM => Self::Float,
            NumericType::UNORM => Self::Float,
            NumericType::SSCALED => Self::Float,
            NumericType::USCALED => Self::Float,
            NumericType::SRGB => Self::Float,
        }
    }
}
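
/*
The mapping above follows the rule that normalized, scaled and sRGB formats are read in shaders as
floating-point values, so only `SINT` and `UINT` formats need integer shader types. The sketch
below restates that rule as a hypothetical standalone function over a local `Numeric` enum (not
vulkano's `NumericType`). A hypothetical `format_suits_shader` check shows one place where such a
rule could be used: comparing an image format against a shader's declared sampled type.

```rust
// Hypothetical local mirror of the format numeric types.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Numeric { SFloat, UFloat, SInt, UInt, SNorm, UNorm, SScaled, UScaled, Srgb }

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Scalar { Float, Sint, Uint }

pub fn scalar_type(n: Numeric) -> Scalar {
    match n {
        Numeric::SInt => Scalar::Sint,
        Numeric::UInt => Scalar::Uint,
        // Everything else, including normalized and scaled integer formats,
        // is converted to floating point when a shader reads it.
        _ => Scalar::Float,
    }
}

// Hypothetical check: can an image whose format has numeric type `n` be bound where
// the shader expects values of type `expected`?
pub fn format_suits_shader(n: Numeric, expected: Scalar) -> bool {
    scalar_type(n) == expected
}
```

The original spells out every variant instead of using a wildcard. Adding a new numeric type then
causes a compile error there rather than silently mapping to `Float`.
*/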

/// An error that can occur when the interfaces of two shader stages do not match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShaderInterfaceMismatchError {
    /// The number of elements is not the same between the two shader interfaces.
    ElementsCountMismatch {
        /// Number of elements in the first interface.
        self_elements: u32,
        /// Number of elements in the second interface.
        other_elements: u32,
    },

    /// An element is missing from one of the interfaces.
    MissingElement {
        /// Location of the missing element.
        location: u32,
    },

    /// The type of an element does not match.
    TypeMismatch {
        /// Location of the element that mismatches.
        location: u32,
        /// Type in the first interface.
        self_ty: ShaderInterfaceEntryType,
        /// Type in the second interface.
        other_ty: ShaderInterfaceEntryType,
    },
}

impl Error for ShaderInterfaceMismatchError {}

impl Display for ShaderInterfaceMismatchError {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
        write!(
            f,
            "{}",
            match self {
                ShaderInterfaceMismatchError::ElementsCountMismatch { .. } => {
                    "the number of elements does not match"
                }
                ShaderInterfaceMismatchError::MissingElement { .. } => "an element is missing",
                ShaderInterfaceMismatchError::TypeMismatch { .. } => {
                    "the type of an element does not match"
                }
            }
        )
    }
}

vulkan_bitflags_enum! {
    #[non_exhaustive]

    /// A set of [`ShaderStage`] values.
    ShaderStages impl {
        /// Returns a `ShaderStages` value that contains all graphics stages.
        #[inline]
        pub const fn all_graphics() -> ShaderStages {
            ShaderStages::VERTEX
                .union(ShaderStages::TESSELLATION_CONTROL)
                .union(ShaderStages::TESSELLATION_EVALUATION)
                .union(ShaderStages::GEOMETRY)
                .union(ShaderStages::FRAGMENT)
        }
    },

    /// A shader stage within a pipeline.
    ShaderStage,

    = ShaderStageFlags(u32);

    // TODO: document
    VERTEX, Vertex = VERTEX,

    // TODO: document
    TESSELLATION_CONTROL, TessellationControl = TESSELLATION_CONTROL,

    // TODO: document
    TESSELLATION_EVALUATION, TessellationEvaluation = TESSELLATION_EVALUATION,

    // TODO: document
    GEOMETRY, Geometry = GEOMETRY,

    // TODO: document
    FRAGMENT, Fragment = FRAGMENT,

    // TODO: document
    COMPUTE, Compute = COMPUTE,

    // TODO: document
    RAYGEN, Raygen = RAYGEN_KHR {
        device_extensions: [khr_ray_tracing_pipeline, nv_ray_tracing],
    },

    // TODO: document
    ANY_HIT, AnyHit = ANY_HIT_KHR {
        device_extensions: [khr_ray_tracing_pipeline, nv_ray_tracing],
    },

    // TODO: document
    CLOSEST_HIT, ClosestHit = CLOSEST_HIT_KHR {
        device_extensions: [khr_ray_tracing_pipeline, nv_ray_tracing],
    },

    // TODO: document
    MISS, Miss = MISS_KHR {
        device_extensions: [khr_ray_tracing_pipeline, nv_ray_tracing],
    },

    // TODO: document
    INTERSECTION, Intersection = INTERSECTION_KHR {
        device_extensions: [khr_ray_tracing_pipeline, nv_ray_tracing],
    },

    // TODO: document
    CALLABLE, Callable = CALLABLE_KHR {
        device_extensions: [khr_ray_tracing_pipeline, nv_ray_tracing],
    },

    // TODO: document
    TASK, Task = TASK_EXT {
        device_extensions: [ext_mesh_shader, nv_mesh_shader],
    },

    // TODO: document
    MESH, Mesh = MESH_EXT {
        device_extensions: [ext_mesh_shader, nv_mesh_shader],
    },

    // TODO: document
    SUBPASS_SHADING, SubpassShading = SUBPASS_SHADING_HUAWEI {
        device_extensions: [huawei_subpass_shading],
    },
}
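
/*
`ShaderStages` is a set of flags, and `all_graphics` builds it from `const` unions. To illustrate
that pattern without the `vulkan_bitflags_enum!` macro, here is a hypothetical hand-written
`Stages` newtype over `u32`. The bit values of the first six stages match Vulkan's
`VkShaderStageFlagBits`, so the union of the five graphics stages equals
`VK_SHADER_STAGE_ALL_GRAPHICS` (`0x1F`). The macro's actual expansion may differ.

```rust
use std::ops::BitOr;

// Hypothetical hand-written flag set, standing in for the macro-generated `ShaderStages`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Stages(pub u32);

impl Stages {
    pub const VERTEX: Stages = Stages(0x01);
    pub const TESSELLATION_CONTROL: Stages = Stages(0x02);
    pub const TESSELLATION_EVALUATION: Stages = Stages(0x04);
    pub const GEOMETRY: Stages = Stages(0x08);
    pub const FRAGMENT: Stages = Stages(0x10);
    pub const COMPUTE: Stages = Stages(0x20);

    // Operator traits such as `BitOr` cannot be called in a `const fn` on stable Rust,
    // so an inherent `const` method is needed to build sets at compile time.
    pub const fn union(self, other: Stages) -> Stages {
        Stages(self.0 | other.0)
    }

    pub const fn all_graphics() -> Stages {
        Stages::VERTEX
            .union(Stages::TESSELLATION_CONTROL)
            .union(Stages::TESSELLATION_EVALUATION)
            .union(Stages::GEOMETRY)
            .union(Stages::FRAGMENT)
    }

    // True if the two sets share at least one stage.
    pub const fn intersects(self, other: Stages) -> bool {
        self.0 & other.0 != 0
    }
}

// Non-const code can still use the `|` operator.
impl BitOr for Stages {
    type Output = Stages;
    fn bitor(self, rhs: Stages) -> Stages {
        self.union(rhs)
    }
}
```
*/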

impl From<ShaderExecution> for ShaderStage {
    #[inline]
    fn from(val: ShaderExecution) -> Self {
        match val {
            ShaderExecution::Vertex => Self::Vertex,
            ShaderExecution::TessellationControl => Self::TessellationControl,
            ShaderExecution::TessellationEvaluation => Self::TessellationEvaluation,
            ShaderExecution::Geometry(_) => Self::Geometry,
            ShaderExecution::Fragment(_) => Self::Fragment,
            ShaderExecution::Compute => Self::Compute,
            ShaderExecution::RayGeneration => Self::Raygen,
            ShaderExecution::AnyHit => Self::AnyHit,
            ShaderExecution::ClosestHit => Self::ClosestHit,
            ShaderExecution::Miss => Self::Miss,
            ShaderExecution::Intersection => Self::Intersection,
            ShaderExecution::Callable => Self::Callable,
            ShaderExecution::Task => Self::Task,
            ShaderExecution::Mesh => Self::Mesh,
            ShaderExecution::SubpassShading => Self::SubpassShading,
        }
    }
}

impl From<ShaderStages> for PipelineStages {
    #[inline]
    fn from(stages: ShaderStages) -> PipelineStages {
        let mut result = PipelineStages::empty();

        if stages.intersects(ShaderStages::VERTEX) {
            result |= PipelineStages::VERTEX_SHADER
        }

        if stages.intersects(ShaderStages::TESSELLATION_CONTROL) {
            result |= PipelineStages::TESSELLATION_CONTROL_SHADER
        }

        if stages.intersects(ShaderStages::TESSELLATION_EVALUATION) {
            result |= PipelineStages::TESSELLATION_EVALUATION_SHADER
        }

        if stages.intersects(ShaderStages::GEOMETRY) {
            result |= PipelineStages::GEOMETRY_SHADER
        }

        if stages.intersects(ShaderStages::FRAGMENT) {
            result |= PipelineStages::FRAGMENT_SHADER
        }

        if stages.intersects(ShaderStages::COMPUTE) {
            result |= PipelineStages::COMPUTE_SHADER
        }

        if stages.intersects(
            ShaderStages::RAYGEN
                | ShaderStages::ANY_HIT
                | ShaderStages::CLOSEST_HIT
                | ShaderStages::MISS
                | ShaderStages::INTERSECTION
                | ShaderStages::CALLABLE,
        ) {
            result |= PipelineStages::RAY_TRACING_SHADER
        }

        if stages.intersects(ShaderStages::TASK) {
            result |= PipelineStages::TASK_SHADER;
        }

        if stages.intersects(ShaderStages::MESH) {
            result |= PipelineStages::MESH_SHADER;
        }

        if stages.intersects(ShaderStages::SUBPASS_SHADING) {
            result |= PipelineStages::SUBPASS_SHADING;
        }

        result
    }
}
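
/*
The conversion above is a lookup in which several shader stages can collapse into one pipeline
stage: all six ray tracing shader stages map to the single `RAY_TRACING_SHADER` pipeline stage.
The sketch below shows a table-driven version of the same idea for a subset of the stages. It uses
hypothetical flag constants with illustrative bit values that are not Vulkan's.

```rust
// Hypothetical shader-stage bits (illustrative values only).
pub const VERTEX: u32 = 1 << 0;
pub const FRAGMENT: u32 = 1 << 1;
pub const COMPUTE: u32 = 1 << 2;
pub const RAYGEN: u32 = 1 << 3;
pub const MISS: u32 = 1 << 4;
pub const CLOSEST_HIT: u32 = 1 << 5;

// Hypothetical pipeline-stage bits (illustrative values only).
pub const VERTEX_SHADER: u32 = 1 << 0;
pub const FRAGMENT_SHADER: u32 = 1 << 1;
pub const COMPUTE_SHADER: u32 = 1 << 2;
pub const RAY_TRACING_SHADER: u32 = 1 << 3;

// Each row maps a group of shader stages to the pipeline stage that executes them.
const TABLE: &[(u32, u32)] = &[
    (VERTEX, VERTEX_SHADER),
    (FRAGMENT, FRAGMENT_SHADER),
    (COMPUTE, COMPUTE_SHADER),
    (RAYGEN | MISS | CLOSEST_HIT, RAY_TRACING_SHADER),
];

pub fn pipeline_stages(shader_stages: u32) -> u32 {
    TABLE
        .iter()
        // Keep rows whose shader-stage group intersects the input.
        .filter(|&&(shader, _)| shader_stages & shader != 0)
        // OR together the corresponding pipeline stages.
        .fold(0, |acc, &(_, pipeline)| acc | pipeline)
}
```

A table keeps the mapping in one place. The `if` chain in the original is just as readable and
names vulkano's real flags directly.
*/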

fn check_spirv_version(device: &Device, mut version: Version) -> Result<(), ShaderSupportError> {
    version.patch = 0; // Ignore the patch version

    match version {
        Version::V1_0 => {}
        Version::V1_1 | Version::V1_2 | Version::V1_3 => {
            if !(device.api_version() >= Version::V1_1) {
                return Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.1",
                ]));
            }
        }
        Version::V1_4 => {
            if !(device.api_version() >= Version::V1_2 || device.enabled_extensions().khr_spirv_1_4)
            {
                return Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.2",
                    "extension `khr_spirv_1_4`",
                ]));
            }
        }
        Version::V1_5 => {
            if !(device.api_version() >= Version::V1_2) {
                return Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.2",
                ]));
            }
        }
        Version::V1_6 => {
            if !(device.api_version() >= Version::V1_3) {
                return Err(ShaderSupportError::RequirementsNotMet(&[
                    "Vulkan API version 1.3",
                ]));
            }
        }
        _ => return Err(ShaderSupportError::NotSupportedByVulkan),
    }
    Ok(())
}
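
/*
`check_spirv_version` encodes which Vulkan API version, or extension, each SPIR-V version
requires, ignoring the patch number. The sketch below is a standalone restatement of the same
table. It assumes a hypothetical `spirv_supported` function that takes versions as `(major, minor)`
tuples and reduces the `khr_spirv_1_4` extension to a boolean. It is not vulkano's API.

```rust
// Hypothetical standalone restatement of `check_spirv_version`.
// Versions are `(major, minor)` pairs; tuples compare lexicographically.
pub fn spirv_supported(
    spirv: (u32, u32),
    api: (u32, u32),
    khr_spirv_1_4: bool,
) -> Result<(), &'static str> {
    let ok = match spirv {
        (1, 0) => true,
        (1, 1..=3) => api >= (1, 1),
        // SPIR-V 1.4 is core in Vulkan 1.2, or available through the extension.
        (1, 4) => api >= (1, 2) || khr_spirv_1_4,
        (1, 5) => api >= (1, 2),
        (1, 6) => api >= (1, 3),
        _ => return Err("SPIR-V version not supported by Vulkan"),
    };
    if ok {
        Ok(())
    } else {
        Err("requirements not met")
    }
}
```
*/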