vulkan_layer/
global_simple_intercept.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::{
16    ffi::{c_void, CStr},
17    fmt::Debug,
18    mem::MaybeUninit,
19};
20
21use ash::vk;
22
23use smallvec::SmallVec;
24use thiserror::Error;
25
26pub mod generated;
27pub use generated::*;
28
29use crate::vk_utils::ptr_as_uninit_mut;
30
31#[derive(Error, Debug)]
32pub enum TryFromExtensionError {
33    #[error("unknown extension `{0}`")]
34    UnknownExtension(String),
35}
36
37/// A union type of extensions and core API version.
38///
39/// In `vk.xml`, Vulkan commands and types are grouped under different API version and extensions.
40/// The tag of those group XML elements is `feature` or `extension`. One command will have only one
41/// single correspondent feature. This type is mostly used to tell if a command should be returned
42/// by `vkGet*ProcAddr` given the supported/enabled Vulkan API version and extensions.
43#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
44pub enum Feature {
45    /// Vulkan core API interface.
46    Core(ApiVersion),
47    /// Vulkan extension interface.
48    Extension(Extension),
49}
50
51impl From<ApiVersion> for Feature {
52    fn from(value: ApiVersion) -> Self {
53        Self::Core(value)
54    }
55}
56
57impl From<Extension> for Feature {
58    fn from(value: Extension) -> Self {
59        Self::Extension(value)
60    }
61}
62
/// Vulkan API version number, reduced to its major and minor components.
///
/// Can be used to store the result decoded from
/// [`VK_MAKE_API_VERSION`](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_MAKE_API_VERSION.html).
/// The variant and patch components are not stored.
///
/// The derived ordering compares `major` first and `minor` second, so the field order below is
/// significant.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ApiVersion {
    /// The major version number. At most 7 bits.
    pub major: u8,
    /// The minor version number. At most 10 bits.
    pub minor: u16,
}
74
75impl ApiVersion {
76    /// Vulkan version 1.0. The initial release of the Vulkan API.
77    ///
78    /// <https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_VERSION_1_0.html>
79    pub const V1_0: Self = Self { major: 1, minor: 0 };
80    /// Vulkan version 1.1.
81    ///
82    /// <https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_VERSION_1_1.html>
83    pub const V1_1: Self = Self { major: 1, minor: 1 };
84}
85
86impl From<u32> for ApiVersion {
87    fn from(value: u32) -> Self {
88        Self {
89            major: vk::api_version_major(value)
90                .try_into()
91                .expect("The major version must be no more than 7 bits."),
92            minor: vk::api_version_minor(value)
93                .try_into()
94                .expect("The minor version must be no more than 10 bits."),
95        }
96    }
97}
98
99impl From<ApiVersion> for u32 {
100    fn from(value: ApiVersion) -> Self {
101        vk::make_api_version(0, value.major.into(), value.minor.into(), 0)
102    }
103}
104
105pub(crate) struct VulkanCommand {
106    pub name: &'static str,
107    pub features: SmallVec<[Feature; 2]>,
108    pub hooked: bool,
109    pub proc: vk::PFN_vkVoidFunction,
110}
111
112fn get_instance_proc_addr_loader(
113    get_instance_proc_addr: vk::PFN_vkGetInstanceProcAddr,
114    instance: &ash::Instance,
115) -> impl Fn(&CStr) -> *const c_void + '_ {
116    move |name| {
117        // Safe because the VkInstance is valid, and a valid C string pointer is passed to the
118        // `p_name` parameter.
119        let fp = unsafe { get_instance_proc_addr(instance.handle(), name.as_ptr()) };
120        if let Some(fp) = fp {
121            return fp as *const _;
122        }
123        std::ptr::null()
124    }
125}
126
127fn get_device_proc_addr_loader(
128    get_device_proc_addr: vk::PFN_vkGetDeviceProcAddr,
129    device: &ash::Device,
130) -> impl Fn(&CStr) -> *const c_void + '_ {
131    move |name| {
132        // Safe because the VkDevice is valid, and a valid C string pointer is passed to the
133        // `p_name` parameter.
134        let fp = unsafe { get_device_proc_addr(device.handle(), name.as_ptr()) };
135        match fp {
136            Some(fp) => fp as *const _,
137            None => std::ptr::null(),
138        }
139    }
140}
141
142/// Converts a raw pointer to a reference of a slice of maybe-uninit. In contrast to
143/// `from_raw_parts_mut`, this does not require that the value has to be initialized and the input
144/// pointer can be null and may not be aligned if the input size is 0. If either `p_out_array` or
145/// `p_size` is null, [`None`] is returned. Otherwise, [`Some`] is returned.
146///
147/// # Safety
148/// Behavior is undefined if any of the following conditions are violated:
149///
150/// * `p_data` must be valid for writes for `size * mem::size_of::<T>()` many bytes, and it must be
151///   properly aligned. This means in particular, the entire memory range of this slice must be
152///   contained within a single allocated object! Slices can never span across multiple allocated
153///   objects. `p_data` must be be aligned if `size` is not 0.
154/// * The memory referenced by the returned slice must not be accessed through any other pointer
155///   (not derived from the return value) for the duration of lifetime 'a. Both read and write
156///   accesses are forbidden.
157/// * The total size `size * mem::size_of::<T>()` of the slice must be no larger than
158///   [`std::isize::MAX`], and adding that size to data must not “wrap around” the address space.
159///   See the safety documentation of
160///   [`pointer::offset`](https://doc.rust-lang.org/std/primitive.pointer.html#method.offset).
161/// * `p_size` must be either null or points to a valid data to read. See details at [`pointer::as_ref`](https://doc.rust-lang.org/std/primitive.pointer.html#safety).
162///
163/// # Panics
164/// * Panics if `size` can't be converted to usize.
165#[deny(unsafe_op_in_unsafe_fn)]
166unsafe fn maybe_uninit_slice_from_raw_parts_mut<'a, T>(
167    p_out_array: *mut T,
168    p_size: *const (impl TryInto<usize, Error = impl Debug> + Copy),
169) -> Option<&'a mut [MaybeUninit<T>]> {
170    let size: usize = unsafe { p_size.as_ref() }
171        .copied()?
172        .try_into()
173        .expect("size mut be within the range of usize");
174    if p_out_array.is_null() {
175        return None;
176    }
177    Some(unsafe { uninit_slice_from_raw_parts_mut(p_out_array, size) })
178}
179
/// Forms a slice from a pointer and a length, tolerating a null pointer.
///
/// In contrast to [`std::slice::from_raw_parts`], `data` does not need to be non-null or aligned
/// when `len` is 0. If `data` is null, [`None`] is returned regardless of `len`.
///
/// # Safety
///
/// If `data` is null or `len` is 0, there is no safety requirement.
///
/// Otherwise, the following conditions must not be violated:
/// * `data` must be valid for reads for `len * mem::size_of::<T>()` many bytes, and it must be
///   properly aligned. This means in particular: the entire memory range of this slice must be
///   contained within a single allocated object! Slices can never span across multiple allocated
///   objects.
/// * `data` must point to `len` consecutive properly initialized values of type `T`.
/// * The memory referenced by the returned slice must not be mutated for the duration of lifetime
///   `'a`, except inside an `UnsafeCell`.
/// * The total size `len * mem::size_of::<T>()` of the slice must be no larger than [`isize::MAX`],
///   and adding that size to `data` must not "wrap around" the address space. See the safety
///   documentation of
///   [`pointer::offset`](https://doc.rust-lang.org/std/primitive.pointer.html#method.offset).
///
/// # Panics
///
/// Panics if `data` is not null and `len` can't be converted to `usize`.
#[deny(unsafe_op_in_unsafe_fn)]
unsafe fn maybe_slice_from_raw_parts<'a, T>(
    data: *const T,
    len: impl TryInto<usize, Error = impl Debug>,
) -> Option<&'a [T]> {
    // A null pointer short-circuits before `len` is converted, so it never panics.
    if data.is_null() {
        return None;
    }
    let len: usize = len
        .try_into()
        .expect("len must be within the range of usize");
    if len == 0 {
        // An empty slice is formed without touching `data`, which may be dangling or unaligned.
        return Some(&[]);
    }
    // SAFETY: `data` is non-null and `len` is non-zero, so the caller guarantees `data` points to
    // `len` initialized, properly aligned values that stay unmutated for the lifetime `'a`.
    Some(unsafe { std::slice::from_raw_parts(data, len) })
}
221
222/// Converts a raw pointer to a reference of a slice of maybe-uninit. In contrast to
223/// `from_raw_parts_mut`, this does not require that the value has to be initialized and the input
224/// pointer can be null and may not be aligned if the input size is 0.
225///
226/// # Safety
227/// Behavior is undefined if any of the following conditions are violated:
228///
229/// * `p_data` must be valid for writes for `size * mem::size_of::<T>()` many bytes, and it must be
230///   properly aligned. This means in particular, the entire memory range of this slice must be
231///   contained within a single allocated object! Slices can never span across multiple allocated
232///   objects. `p_data` must be be aligned if `size` is not 0.
233/// * The memory referenced by the returned slice must not be accessed through any other pointer
234///   (not derived from the return value) for the duration of lifetime 'a. Both read and write
235///   accesses are forbidden.
236/// * The total size `size * mem::size_of::<T>()` of the slice must be no larger than
237///   [`std::isize::MAX`], and adding that size to data must not “wrap around” the address space.
238///   See the safety documentation of
239///   [`pointer::offset`](https://doc.rust-lang.org/std/primitive.pointer.html#method.offset).
240///
241/// # Panics
242/// * Panics if `size` is not 0 and `p_data` is null.
243/// * Panics if `size` can't be converted to usize.
244#[deny(unsafe_op_in_unsafe_fn)]
245unsafe fn uninit_slice_from_raw_parts_mut<'a, T>(
246    p_data: *mut T,
247    size: impl TryInto<usize, Error = impl Debug>,
248) -> &'a mut [MaybeUninit<T>] {
249    let size: usize = size
250        .try_into()
251        .expect("size mut be within the range of usize");
252    if size == 0 {
253        return &mut [];
254    }
255    let first_element =
256        unsafe { ptr_as_uninit_mut(p_data) }.expect("the input data pointer should not be null");
257    unsafe { std::slice::from_raw_parts_mut(first_element, size) }
258}
259
260#[deny(unsafe_op_in_unsafe_fn)]
261unsafe fn bool_iterator_from_raw_parts(
262    ptr: *const vk::Bool32,
263    size: impl TryInto<usize, Error = impl Debug>,
264) -> impl Iterator<Item = bool> {
265    let size: usize = size
266        .try_into()
267        .expect("size mut be within the range of usize");
268    let slice = if size == 0 {
269        &[]
270    } else {
271        unsafe { std::slice::from_raw_parts(ptr, size) }
272    };
273    slice.iter().map(|v| *v == vk::TRUE)
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bool_iterator_from_raw_parts_result_should_match() {
        let expected_value = vec![true, true, false, true, false];
        let input = expected_value
            .iter()
            .map(|v| if *v { vk::TRUE } else { vk::FALSE })
            .collect::<Vec<_>>();
        let result = unsafe { bool_iterator_from_raw_parts(input.as_ptr(), expected_value.len()) }
            .collect::<Vec<_>>();
        assert_eq!(result, expected_value);
    }

    #[test]
    fn bool_iterator_from_raw_parts_empty_with_null_ptr() {
        let result =
            unsafe { bool_iterator_from_raw_parts(std::ptr::null(), 0) }.collect::<Vec<_>>();
        assert!(result.is_empty());
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_mut_result_should_match() {
        const LEN: usize = 10;
        let expected_value: [i32; LEN] = [81, 95, 43, 65, 34, 47, 65, 62, 47, 82];
        let mut input = [MaybeUninit::<i32>::uninit(); LEN];
        let input_ptr = input.as_mut_ptr() as *mut i32;
        let output = unsafe { maybe_uninit_slice_from_raw_parts_mut(input_ptr, &LEN) }
            .expect("for valid input should return in Some");
        assert_eq!(output.len(), LEN);
        for (i, output_element) in output.iter_mut().enumerate() {
            output_element.write(expected_value[i]);
        }

        for (i, input_element) in input.iter().enumerate() {
            assert_eq!(
                *unsafe { input_element.assume_init_ref() },
                expected_value[i]
            );
        }
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_mut_null_data_ptr() {
        let input_ptr: *mut i32 = std::ptr::null_mut();
        let output = unsafe { maybe_uninit_slice_from_raw_parts_mut(input_ptr, &10) };
        assert!(output.is_none());
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_zero_length_unaligned_data_ptr() {
        // Some address in the u8_array must be unaligned with u32.
        let mut u8_array = [0u8; 2];
        let input_ptrs = [
            &mut u8_array[0] as *mut _ as *mut i32,
            &mut u8_array[1] as *mut _ as *mut i32,
        ];
        for input_ptr in input_ptrs {
            let output = unsafe { maybe_uninit_slice_from_raw_parts_mut(input_ptr, &0) }
                .expect("for valid input pointer, Some should be returned");
            assert!(output.is_empty());
        }
    }

    #[test]
    fn maybe_uninit_slice_from_raw_parts_mut_null_size_ptr() {
        let mut data = MaybeUninit::<u32>::uninit();
        let output = unsafe {
            maybe_uninit_slice_from_raw_parts_mut(data.as_mut_ptr(), std::ptr::null::<usize>())
        };
        assert!(output.is_none());
    }

    #[test]
    #[should_panic]
    fn maybe_uninit_slice_from_raw_parts_mut_invalid_size_value() {
        let mut data = MaybeUninit::<u32>::uninit();
        unsafe { maybe_uninit_slice_from_raw_parts_mut(data.as_mut_ptr(), &-1) };
    }

    #[test]
    fn uninit_slice_from_raw_parts_mut_zero_size_and_invalid_data_address() {
        // Some address in the u8_array must be unaligned with u32.
        let mut u8_array = [0u8; 2];
        let input_ptrs = [
            &mut u8_array[0] as *mut _ as *mut i32,
            &mut u8_array[1] as *mut _ as *mut i32,
        ];
        for input_ptr in input_ptrs {
            let output = unsafe { uninit_slice_from_raw_parts_mut(input_ptr, 0) };
            assert!(output.is_empty());
        }

        let output = unsafe { uninit_slice_from_raw_parts_mut(std::ptr::null_mut::<u8>(), 0) };
        assert!(output.is_empty());
    }

    #[test]
    #[should_panic]
    fn uninit_slice_from_raw_parts_mut_valid_size_and_null_data_address() {
        unsafe { uninit_slice_from_raw_parts_mut(std::ptr::null_mut::<u8>(), 10) };
    }

    #[test]
    #[should_panic]
    fn uninit_slice_from_raw_parts_mut_invalid_size_value() {
        let mut data = MaybeUninit::<u32>::uninit();
        unsafe { uninit_slice_from_raw_parts_mut(data.as_mut_ptr(), -1) };
    }

    #[test]
    fn uninit_slice_from_raw_parts_mut_result_should_match() {
        const LEN: usize = 10;
        let expected_value: [i32; LEN] = [14, 45, 60, 97, 35, 21, 13, 42, 11, 12];
        let mut input = [MaybeUninit::<i32>::uninit(); LEN];
        let input_ptr = input.as_mut_ptr() as *mut i32;
        let output = unsafe { uninit_slice_from_raw_parts_mut(input_ptr, LEN) };
        assert_eq!(output.len(), LEN);
        for (i, output_element) in output.iter_mut().enumerate() {
            output_element.write(expected_value[i]);
        }

        for (i, input_element) in input.iter().enumerate() {
            assert_eq!(
                *unsafe { input_element.assume_init_ref() },
                expected_value[i]
            );
        }
    }

    #[test]
    fn maybe_slice_from_raw_parts_null_data() {
        let res = unsafe { maybe_slice_from_raw_parts(std::ptr::null::<u8>(), 0) };
        assert!(res.is_none());
        let res = unsafe { maybe_slice_from_raw_parts(std::ptr::null::<u8>(), 10) };
        assert!(res.is_none());
    }

    #[test]
    fn maybe_slice_from_raw_parts_zero_length_invalid_data_ptr() {
        let u8_array = [0u8; 2];
        // Some address must not be aligned.
        let input_ptrs = u8_array
            .iter()
            .map(|element| element as *const _ as *const u32)
            .collect::<Vec<_>>();
        for input_ptr in input_ptrs {
            let res = unsafe { maybe_slice_from_raw_parts(input_ptr, 0) }
                .expect("should always return Some for non-null pointers");
            assert!(res.is_empty());
        }
    }

    #[test]
    fn maybe_slice_from_raw_parts_reflects_original_array() {
        let input: [u32; 10] = [47, 63, 14, 13, 8, 45, 52, 97, 21, 10];
        let result = unsafe { maybe_slice_from_raw_parts(input.as_ptr(), input.len()) }
            .expect("should always return Some for non-null pointers");
        assert_eq!(result, &input);
    }

    #[test]
    #[should_panic]
    fn maybe_slice_from_raw_parts_bad_len() {
        unsafe { maybe_slice_from_raw_parts(std::ptr::NonNull::<u8>::dangling().as_ptr(), -1) };
    }

    #[test]
    fn extension_try_from_should_return_error_on_unknown_extension() {
        let unknown_extension = "VK_UNKNOWN_unknown";
        let err = Extension::try_from(unknown_extension).unwrap_err();
        assert!(err.to_string().contains(unknown_extension));
    }

    // `ApiVersion` does not implement `Debug`, so whole-value comparisons use `assert!`.
    #[test]
    fn api_version_u32_conversion_keeps_major_and_minor() {
        let decoded = ApiVersion::from(vk::make_api_version(0, 1, 3, 250));
        assert_eq!(decoded.major, 1);
        assert_eq!(decoded.minor, 3);
        // The patch component is discarded, so re-encoding yields patch 0.
        assert_eq!(u32::from(decoded), vk::make_api_version(0, 1, 3, 0));

        assert!(ApiVersion::from(vk::API_VERSION_1_0) == ApiVersion::V1_0);
        assert!(ApiVersion::from(vk::API_VERSION_1_1) == ApiVersion::V1_1);
        assert_eq!(u32::from(ApiVersion::V1_0), vk::API_VERSION_1_0);
        assert_eq!(u32::from(ApiVersion::V1_1), vk::API_VERSION_1_1);
    }

    #[test]
    fn api_version_and_feature_ordering() {
        assert!(ApiVersion::V1_0 < ApiVersion::V1_1);
        assert!(Feature::from(ApiVersion::V1_0) < Feature::from(ApiVersion::V1_1));
    }
}