1use std::{
16 ffi::{c_void, CStr},
17 fmt::Debug,
18 mem::MaybeUninit,
19};
20
21use ash::vk;
22
23use smallvec::SmallVec;
24use thiserror::Error;
25
26pub mod generated;
27pub use generated::*;
28
29use crate::vk_utils::ptr_as_uninit_mut;
30
31#[derive(Error, Debug)]
32pub enum TryFromExtensionError {
33 #[error("unknown extension `{0}`")]
34 UnknownExtension(String),
35}
36
37#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
44pub enum Feature {
45 Core(ApiVersion),
47 Extension(Extension),
49}
50
51impl From<ApiVersion> for Feature {
52 fn from(value: ApiVersion) -> Self {
53 Self::Core(value)
54 }
55}
56
57impl From<Extension> for Feature {
58 fn from(value: Extension) -> Self {
59 Self::Extension(value)
60 }
61}
62
/// An API version made of a major and a minor component.
///
/// The ordering is derived, so versions compare by `major` first and by `minor` second; the
/// field declaration order must therefore not be changed.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ApiVersion {
    /// The major component of the version.
    pub major: u8,
    /// The minor component of the version.
    pub minor: u16,
}
74
75impl ApiVersion {
76 pub const V1_0: Self = Self { major: 1, minor: 0 };
80 pub const V1_1: Self = Self { major: 1, minor: 1 };
84}
85
86impl From<u32> for ApiVersion {
87 fn from(value: u32) -> Self {
88 Self {
89 major: vk::api_version_major(value)
90 .try_into()
91 .expect("The major version must be no more than 7 bits."),
92 minor: vk::api_version_minor(value)
93 .try_into()
94 .expect("The minor version must be no more than 10 bits."),
95 }
96 }
97}
98
99impl From<ApiVersion> for u32 {
100 fn from(value: ApiVersion) -> Self {
101 vk::make_api_version(0, value.major.into(), value.minor.into(), 0)
102 }
103}
104
105pub(crate) struct VulkanCommand {
106 pub name: &'static str,
107 pub features: SmallVec<[Feature; 2]>,
108 pub hooked: bool,
109 pub proc: vk::PFN_vkVoidFunction,
110}
111
112fn get_instance_proc_addr_loader(
113 get_instance_proc_addr: vk::PFN_vkGetInstanceProcAddr,
114 instance: &ash::Instance,
115) -> impl Fn(&CStr) -> *const c_void + '_ {
116 move |name| {
117 let fp = unsafe { get_instance_proc_addr(instance.handle(), name.as_ptr()) };
120 if let Some(fp) = fp {
121 return fp as *const _;
122 }
123 std::ptr::null()
124 }
125}
126
127fn get_device_proc_addr_loader(
128 get_device_proc_addr: vk::PFN_vkGetDeviceProcAddr,
129 device: &ash::Device,
130) -> impl Fn(&CStr) -> *const c_void + '_ {
131 move |name| {
132 let fp = unsafe { get_device_proc_addr(device.handle(), name.as_ptr()) };
135 match fp {
136 Some(fp) => fp as *const _,
137 None => std::ptr::null(),
138 }
139 }
140}
141
142#[deny(unsafe_op_in_unsafe_fn)]
166unsafe fn maybe_uninit_slice_from_raw_parts_mut<'a, T>(
167 p_out_array: *mut T,
168 p_size: *const (impl TryInto<usize, Error = impl Debug> + Copy),
169) -> Option<&'a mut [MaybeUninit<T>]> {
170 let size: usize = unsafe { p_size.as_ref() }
171 .copied()?
172 .try_into()
173 .expect("size mut be within the range of usize");
174 if p_out_array.is_null() {
175 return None;
176 }
177 Some(unsafe { uninit_slice_from_raw_parts_mut(p_out_array, size) })
178}
179
/// Forms a slice from a pointer and a length, treating a null pointer as "no slice".
///
/// Returns `None` if `data` is null, regardless of `len`. Returns an empty slice if `data` is
/// non-null and `len` is 0; in that case `data` is never dereferenced, so it may be dangling
/// or unaligned.
///
/// # Safety
///
/// If `data` is non-null and `len` is non-zero, `data` must be valid for reads of `len`
/// properly aligned, initialized elements of `T` for the lifetime `'a`, and the memory must
/// not be mutated for that lifetime.
///
/// # Panics
///
/// Panics if `data` is non-null and `len` cannot be converted to `usize`.
#[deny(unsafe_op_in_unsafe_fn)]
unsafe fn maybe_slice_from_raw_parts<'a, T>(
    data: *const T,
    len: impl TryInto<usize, Error = impl Debug>,
) -> Option<&'a [T]> {
    if data.is_null() {
        return None;
    }
    let len: usize = len
        .try_into()
        .expect("len must be within the range of usize");
    if len == 0 {
        // `std::slice::from_raw_parts` requires an aligned pointer even for empty slices, so
        // hand out a static empty slice instead of touching `data`.
        return Some(&[]);
    }
    // SAFETY: `data` is non-null and the caller guarantees it is valid for `len` elements.
    Some(unsafe { std::slice::from_raw_parts(data, len) })
}
221
222#[deny(unsafe_op_in_unsafe_fn)]
245unsafe fn uninit_slice_from_raw_parts_mut<'a, T>(
246 p_data: *mut T,
247 size: impl TryInto<usize, Error = impl Debug>,
248) -> &'a mut [MaybeUninit<T>] {
249 let size: usize = size
250 .try_into()
251 .expect("size mut be within the range of usize");
252 if size == 0 {
253 return &mut [];
254 }
255 let first_element =
256 unsafe { ptr_as_uninit_mut(p_data) }.expect("the input data pointer should not be null");
257 unsafe { std::slice::from_raw_parts_mut(first_element, size) }
258}
259
260#[deny(unsafe_op_in_unsafe_fn)]
261unsafe fn bool_iterator_from_raw_parts(
262 ptr: *const vk::Bool32,
263 size: impl TryInto<usize, Error = impl Debug>,
264) -> impl Iterator<Item = bool> {
265 let size: usize = size
266 .try_into()
267 .expect("size mut be within the range of usize");
268 let slice = if size == 0 {
269 &[]
270 } else {
271 unsafe { std::slice::from_raw_parts(ptr, size) }
272 };
273 slice.iter().map(|v| *v == vk::TRUE)
274}
275
// Unit tests for the raw-pointer helpers and the extension conversion error.
//
// The `should_panic` tests pin a substring of the expected panic message so that they cannot
// pass because of an unrelated panic.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bool_iterator_from_raw_parts_result_should_match() {
        let expected_value = vec![true, true, false, true, false];
        let input = expected_value
            .iter()
            .map(|v| if *v { vk::TRUE } else { vk::FALSE })
            .collect::<Vec<_>>();
        let result = unsafe { bool_iterator_from_raw_parts(input.as_ptr(), expected_value.len()) }
            .collect::<Vec<_>>();
        assert_eq!(result, expected_value);
    }

    #[test]
    fn bool_iterator_from_raw_parts_empty_with_null_ptr() {
        let result =
            unsafe { bool_iterator_from_raw_parts(std::ptr::null(), 0) }.collect::<Vec<_>>();
        assert!(result.is_empty());
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_mut_result_should_match() {
        const LEN: usize = 10;
        let expected_value: [i32; LEN] = [81, 95, 43, 65, 34, 47, 65, 62, 47, 82];
        let mut input = [MaybeUninit::<i32>::uninit(); LEN];
        let input_ptr = input.as_mut_ptr() as *mut i32;
        let output = unsafe { maybe_uninit_slice_from_raw_parts_mut(input_ptr, &LEN) }
            .expect("for valid input should return in Some");
        assert_eq!(output.len(), LEN);
        for (i, output_element) in output.iter_mut().enumerate() {
            output_element.write(expected_value[i]);
        }

        // Writes through the returned slice must be visible in the original storage.
        for (i, input_element) in input.iter().enumerate() {
            assert_eq!(
                *unsafe { input_element.assume_init_ref() },
                expected_value[i]
            );
        }
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_mut_null_data_ptr() {
        let input_ptr: *mut i32 = std::ptr::null_mut();
        let output = unsafe { maybe_uninit_slice_from_raw_parts_mut(input_ptr, &10) };
        assert!(output.is_none());
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_zero_length_unaligned_data_ptr() {
        let mut u8_array = [0u8; 2];
        // Two adjacent byte addresses: at least one of them is misaligned for `i32`.
        let input_ptrs = [
            &mut u8_array[0] as *mut _ as *mut i32,
            &mut u8_array[1] as *mut _ as *mut i32,
        ];
        for input_ptr in input_ptrs {
            let output = unsafe { maybe_uninit_slice_from_raw_parts_mut(input_ptr, &0) }
                .expect("for valid input pointer, Some should be returned");
            assert!(output.is_empty());
        }
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_mut_null_size_ptr() {
        let mut data = MaybeUninit::<u32>::uninit();
        let output = unsafe {
            maybe_uninit_slice_from_raw_parts_mut(data.as_mut_ptr(), std::ptr::null::<usize>())
        };
        assert!(output.is_none());
    }

    #[test]
    #[should_panic(expected = "be within the range of usize")]
    fn maybe_uninit_slice_from_raw_parts_mut_invalid_size_value() {
        let mut data = MaybeUninit::<u32>::uninit();
        unsafe { maybe_uninit_slice_from_raw_parts_mut(data.as_mut_ptr(), &-1) };
    }

    #[test]
    fn uninit_slice_from_raw_parts_mut_zero_size_and_invalid_data_address() {
        let mut u8_array = [0u8; 2];
        // Two adjacent byte addresses: at least one of them is misaligned for `i32`.
        let input_ptrs = [
            &mut u8_array[0] as *mut _ as *mut i32,
            &mut u8_array[1] as *mut _ as *mut i32,
        ];
        for input_ptr in input_ptrs {
            let output = unsafe { uninit_slice_from_raw_parts_mut(input_ptr, 0) };
            assert!(output.is_empty());
        }

        let output = unsafe { uninit_slice_from_raw_parts_mut(std::ptr::null_mut::<u8>(), 0) };
        assert!(output.is_empty());
    }

    #[test]
    #[should_panic(expected = "the input data pointer should not be null")]
    fn uninit_slice_from_raw_parts_mut_valid_size_and_null_data_address() {
        unsafe { uninit_slice_from_raw_parts_mut(std::ptr::null_mut::<u8>(), 10) };
    }

    #[test]
    #[should_panic(expected = "be within the range of usize")]
    fn uninit_slice_from_raw_parts_mut_invalid_size_value() {
        let mut data = MaybeUninit::<u32>::uninit();
        unsafe { uninit_slice_from_raw_parts_mut(data.as_mut_ptr(), -1) };
    }

    #[test]
    fn uninit_slice_from_raw_parts_mut_result_should_match() {
        const LEN: usize = 10;
        let expected_value: [i32; LEN] = [14, 45, 60, 97, 35, 21, 13, 42, 11, 12];
        let mut input = [MaybeUninit::<i32>::uninit(); LEN];
        let input_ptr = input.as_mut_ptr() as *mut i32;
        let output = unsafe { uninit_slice_from_raw_parts_mut(input_ptr, LEN) };
        assert_eq!(output.len(), LEN);
        for (i, output_element) in output.iter_mut().enumerate() {
            output_element.write(expected_value[i]);
        }

        // Writes through the returned slice must be visible in the original storage.
        for (i, input_element) in input.iter().enumerate() {
            assert_eq!(
                *unsafe { input_element.assume_init_ref() },
                expected_value[i]
            );
        }
    }

    #[test]
    fn maybe_slice_from_raw_parts_null_data() {
        let res = unsafe { maybe_slice_from_raw_parts(std::ptr::null::<u8>(), 0) };
        assert!(res.is_none());
        let res = unsafe { maybe_slice_from_raw_parts(std::ptr::null::<u8>(), 10) };
        assert!(res.is_none());
    }

    #[test]
    fn maybe_slice_from_raw_parts_zero_length_invalid_data_ptr() {
        let u8_array = [0u8; 2];
        // Two adjacent byte addresses: at least one of them is misaligned for `u32`.
        let input_ptrs = u8_array
            .iter()
            .map(|element| element as *const _ as *const u32)
            .collect::<Vec<_>>();
        for input_ptr in input_ptrs {
            let res = unsafe { maybe_slice_from_raw_parts(input_ptr, 0) }
                .expect("should always return Some for non-null pointers");
            assert!(res.is_empty());
        }
    }

    #[test]
    fn maybe_slice_from_raw_parts_reflects_original_array() {
        let input: [u32; 10] = [47, 63, 14, 13, 8, 45, 52, 97, 21, 10];
        let result = unsafe { maybe_slice_from_raw_parts(input.as_ptr(), input.len()) }
            .expect("should always return Some for non-null pointers");
        assert_eq!(result, &input);
    }

    #[test]
    #[should_panic(expected = "be within the range of usize")]
    fn maybe_slice_from_raw_parts_bad_len() {
        unsafe { maybe_slice_from_raw_parts(std::ptr::NonNull::<u8>::dangling().as_ptr(), -1) };
    }

    #[test]
    fn extension_try_from_should_return_error_on_unknown_extension() {
        let unknown_extension = "VK_UNKNOWN_unknown";
        let err = Extension::try_from(unknown_extension).unwrap_err();
        assert!(err.to_string().contains(unknown_extension));
    }
}