use core::alloc::Layout;
use const_type_layout::TypeGraphLayout;
#[expect(clippy::module_name_repetitions)]
#[repr(transparent)]
pub struct ThreadBlockSharedSlice<T: 'static + TypeGraphLayout> {
shared: *mut [T],
}
impl<T: 'static + TypeGraphLayout> ThreadBlockSharedSlice<T> {
#[cfg(feature = "host")]
#[must_use]
pub fn new_uninit_with_len(len: usize) -> Self {
Self {
shared: Self::dangling_slice_with_len(len),
}
}
#[cfg(feature = "host")]
#[must_use]
pub fn with_len(mut self, len: usize) -> Self {
self.shared = Self::dangling_slice_with_len(len);
self
}
#[cfg(feature = "host")]
#[must_use]
pub fn with_len_mut(&mut self, len: usize) -> &mut Self {
self.shared = Self::dangling_slice_with_len(len);
self
}
#[cfg(feature = "host")]
fn dangling_slice_with_len(len: usize) -> *mut [T] {
core::ptr::slice_from_raw_parts_mut(core::ptr::NonNull::dangling().as_ptr(), len)
}
#[must_use]
pub fn len(&self) -> usize {
core::ptr::metadata(self.shared)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn layout(&self) -> Layout {
unsafe { Layout::for_value_raw(self.shared) }
}
#[cfg(feature = "device")]
#[must_use]
pub const fn as_mut_ptr(&self) -> *mut T {
self.shared.cast()
}
#[cfg(feature = "device")]
#[must_use]
pub const fn as_mut_slice_ptr(&self) -> *mut [T] {
self.shared
}
#[cfg(feature = "device")]
#[inline]
#[must_use]
pub unsafe fn index_mut_unchecked<I: core::slice::SliceIndex<[T]>>(
&self,
index: I,
) -> *mut <I as core::slice::SliceIndex<[T]>>::Output {
self.shared.get_unchecked_mut(index)
}
}
#[cfg(feature = "device")]
impl<T: 'static + TypeGraphLayout> ThreadBlockSharedSlice<T> {
pub(crate) unsafe fn with_uninit_for_len<F: FnOnce(&mut Self) -> Q, Q>(
len: usize,
inner: F,
) -> Q {
let base: *mut u8;
unsafe {
core::arch::asm!(
"mov.u64 {base}, %rust_cuda_dynamic_shared;",
base = out(reg64) base,
);
}
let aligned_base = base.byte_add(base.align_offset(core::mem::align_of::<T>()));
let data: *mut T = aligned_base.cast();
let new_base = data.add(len).cast::<u8>();
unsafe {
core::arch::asm!(
"mov.u64 %rust_cuda_dynamic_shared, {new_base};",
new_base = in(reg64) new_base,
);
}
let shared = core::ptr::slice_from_raw_parts_mut(data, len);
inner(&mut Self { shared })
}
}
#[cfg(feature = "device")]
pub unsafe fn init() {
#[expect(clippy::multiple_unsafe_ops_per_block)]
unsafe {
core::arch::asm!(".reg .u64 %rust_cuda_dynamic_shared;");
core::arch::asm!(
"cvta.shared.u64 %rust_cuda_dynamic_shared, rust_cuda_dynamic_shared_base;",
);
}
}
#[cfg(feature = "device")]
core::arch::global_asm!(".extern .shared .align 8 .b8 rust_cuda_dynamic_shared_base[];");
#[cfg(feature = "host")]
pub struct SharedMemorySize {
last_align: usize,
total_size: usize,
}
#[cfg(feature = "host")]
impl SharedMemorySize {
#[must_use]
pub const fn new() -> Self {
Self {
last_align: 8,
total_size: 0,
}
}
pub fn add(&mut self, layout: core::alloc::Layout) {
if layout.align() > self.last_align {
let pessimistic_padding = layout.align() - self.last_align;
self.total_size += pessimistic_padding;
}
self.last_align = layout.align();
self.total_size += layout.size();
}
pub const fn total(self) -> usize {
self.total_size
}
}