1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
/// Handle to a compile-time-sized allocation of `T` in the PTX `.shared`
/// state space, obtained from `ThreadBlockShared::new_uninit`.
///
/// The handle stores nothing but a raw pointer; `#[repr(transparent)]` gives it
/// the same layout as `*mut T`. On `device` builds the pointer is the generic
/// address of a `.shared` variable declared with inline PTX. On `host` builds it
/// is a dangling, non-null placeholder that must never be dereferenced.
///
/// The pointee is not initialized when the handle is created.
#[repr(transparent)]
pub struct ThreadBlockShared<T: 'static> {
    // Only the `device`-gated accessors read this field, so on other builds it
    // would otherwise be reported as dead code.
    #[cfg_attr(not(feature = "device"), allow(dead_code))]
    shared: *mut T,
}

impl<T: 'static> ThreadBlockShared<T> {
    /// Creates a handle to a fresh, uninitialized allocation of `T`.
    ///
    /// - `host` builds: the handle holds `NonNull::dangling()`, a non-null,
    ///   well-aligned placeholder that must not be dereferenced.
    /// - `device` builds: inline PTX declares a `.shared` byte array of
    ///   `size_of::<T>()` bytes aligned to `align_of::<T>()`. `cvta.shared.u64`
    ///   then converts its address into a generic address, which is stored in
    ///   the handle.
    ///
    /// No initial value is written to the memory.
    // NOTE(review): with both `host` and `device` enabled, both cfg blocks below
    // are compiled. The first then sits in statement position with a non-`()`
    // type and fails to compile; presumably the two features are mutually
    // exclusive — confirm where that is enforced.
    #[cfg(any(feature = "host", feature = "device"))]
    #[must_use]
    // Presumably forced inline so that each call site expands its own copy of the
    // `asm!` and therefore declares its own `.shared` variable — TODO confirm.
    #[expect(clippy::inline_always)]
    // The host body is const-evaluable, so clippy suggests `const fn`. The device
    // body uses `asm!` and cannot be const; the expectation presumably keeps one
    // signature for both builds.
    #[cfg_attr(feature = "host", expect(clippy::missing_const_for_fn))]
    #[inline(always)]
    pub fn new_uninit() -> Self {
        #[cfg(feature = "host")]
        {
            Self {
                shared: core::ptr::NonNull::dangling().as_ptr(),
            }
        }

        #[cfg(feature = "device")]
        {
            let shared: *mut T;

            // SAFETY: the asm only declares a `.shared` variable and writes its
            // converted address into the `reg` output register. It reads or writes
            // no other memory or Rust-visible state.
            unsafe {
                // The output register's name is spliced into the symbol name,
                // presumably to keep symbols distinct across inlined expansions.
                // NOTE(review): uniqueness relies on register naming — confirm.
                core::arch::asm!(
                    ".shared .align {align} .b8 {reg}_rust_cuda_static_shared[{size}];",
                    "cvta.shared.u64 {reg}, {reg}_rust_cuda_static_shared;",
                    reg = out(reg64) shared,
                    align = const(core::mem::align_of::<T>()),
                    size = const(core::mem::size_of::<T>()),
                );
            }

            Self { shared }
        }
    }

    /// Returns the raw generic-address pointer to the allocation.
    ///
    /// `new_uninit` does not initialize the pointee; callers must write it
    /// before reading through this pointer.
    #[cfg(feature = "device")]
    #[must_use]
    pub const fn as_mut_ptr(&self) -> *mut T {
        self.shared
    }
}

impl<T: 'static, const N: usize> ThreadBlockShared<[T; N]> {
    /// Returns a raw pointer to the element or subslice selected by `index`
    /// within the `[T; N]` allocation, without bounds checking.
    ///
    /// `index` may be any [`core::slice::SliceIndex`] over `[T]`, e.g. a single
    /// `usize` or a range. The returned pointer is not dereferenced here.
    ///
    /// # Safety
    ///
    /// The provided `index` must not be out of bounds.
    ///
    /// An out-of-bounds `index` is undefined behavior even if the returned
    /// pointer is never used.
    #[cfg(feature = "device")]
    #[inline]
    #[must_use]
    pub unsafe fn index_mut_unchecked<I: core::slice::SliceIndex<[T]>>(
        &self,
        index: I,
    ) -> *mut <I as core::slice::SliceIndex<[T]>>::Output {
        // View the `[T; N]` allocation as a raw `[T]` slice pointer of length `N`,
        // so that any `SliceIndex<[T]>` (single index or range) can be applied.
        let elements = core::ptr::slice_from_raw_parts_mut(self.shared.cast::<T>(), N);

        // SAFETY: the caller guarantees `index` is in bounds for a slice of
        // length `N`, which is exactly the length of the array `self.shared`
        // points to.
        unsafe { elements.get_unchecked_mut(index) }
    }
}