use core::mem::MaybeUninit;
use const_type_layout::{TypeGraphLayout, TypeLayout};
#[cfg(feature = "host")]
use rustacuda::error::CudaResult;
use crate::{
lend::{CudaAsRust, RustToCuda, RustToCudaAsync, RustToCudaProxy},
safety::PortableBitSemantics,
utils::{adapter::RustToCudaWithPortableBitCopySemantics, ffi::DeviceAccessible},
};
#[cfg(feature = "host")]
use crate::{
alloc::{CombinedCudaAlloc, CudaAlloc},
utils::r#async::{Async, CompletionFnMut, NoCompletion},
};
/// Representation used for an [`Option`] whose payload type has the CUDA
/// representation `T`.
///
/// The struct is `#[repr(C)]` and pairs a possibly-uninitialised payload with
/// an explicit `present` flag. Within this file it is built with
/// `present: false` and an uninitialised `maybe` for `None`, and with
/// `present: true` and an initialised `maybe` for `Some(_)`.
#[doc(hidden)]
#[expect(clippy::module_name_repetitions)]
#[derive(TypeLayout)]
#[repr(C)]
pub struct OptionCudaRepresentation<T: CudaAsRust> {
/// Payload of the option. Only initialised when `present` is `true`; it
/// must not be read otherwise.
maybe: MaybeUninit<DeviceAccessible<T>>,
/// `true` iff the original option was `Some(_)`, i.e. iff `maybe` holds an
/// initialised value.
present: bool,
}
/// Lends an `Option<T>` by lending the contained value, if there is one.
///
/// The allocation is an `Option` of the inner allocation and the
/// representation is an [`OptionCudaRepresentation`] of the inner
/// representation.
unsafe impl<T: RustToCuda> RustToCuda for Option<T> {
    type CudaAllocation = Option<<T as RustToCuda>::CudaAllocation>;
    type CudaRepresentation = OptionCudaRepresentation<<T as RustToCuda>::CudaRepresentation>;

    /// Borrows the option.
    ///
    /// * `None` yields a representation with `present: false`, an
    ///   uninitialised payload, and `None` as the front of the combined
    ///   allocation.
    /// * `Some(inner)` borrows `inner` with `alloc`, stores the resulting
    ///   representation as the payload with `present: true`, and re-wraps the
    ///   front of the inner combined allocation in `Some(_)`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by borrowing the inner value.
    #[cfg(feature = "host")]
    unsafe fn borrow<A: CudaAlloc>(
        &self,
        alloc: A,
    ) -> CudaResult<(
        DeviceAccessible<Self::CudaRepresentation>,
        CombinedCudaAlloc<Self::CudaAllocation, A>,
    )> {
        // Guard clause: nothing to lend, so no inner allocation is created.
        let Some(inner) = self else {
            let absent = OptionCudaRepresentation {
                maybe: MaybeUninit::uninit(),
                present: false,
            };

            return Ok((
                DeviceAccessible::from(absent),
                CombinedCudaAlloc::new(None, alloc),
            ));
        };

        let (inner_repr, inner_alloc) = inner.borrow(alloc)?;

        // Swap the inner allocation front for its `Some(_)`-wrapped form so
        // that it matches `Self::CudaAllocation`.
        let (front, tail) = inner_alloc.split();

        let occupied = OptionCudaRepresentation {
            maybe: MaybeUninit::new(inner_repr),
            present: true,
        };

        Ok((
            DeviceAccessible::from(occupied),
            CombinedCudaAlloc::new(Some(front), tail),
        ))
    }

    /// Restores the option from its combined allocation and returns the tail
    /// allocation.
    ///
    /// The inner value is only restored when the option is `Some(_)` and the
    /// allocation front is `Some(_)` as well; in every other combination the
    /// tail is returned as-is.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by restoring the inner value.
    #[cfg(feature = "host")]
    unsafe fn restore<A: CudaAlloc>(
        &mut self,
        alloc: CombinedCudaAlloc<Self::CudaAllocation, A>,
    ) -> CudaResult<A> {
        let (front, tail) = alloc.split();

        if let (Some(inner), Some(inner_front)) = (self, front) {
            inner.restore(CombinedCudaAlloc::new(inner_front, tail))
        } else {
            Ok(tail)
        }
    }
}
// Asynchronous lending of an `Option<T>`: both methods delegate to the
// contained value when there is one and otherwise complete immediately.
unsafe impl<T: RustToCudaAsync> RustToCudaAsync for Option<T> {
type CudaAllocationAsync = Option<<T as RustToCudaAsync>::CudaAllocationAsync>;
/// Asynchronously borrows the option on `stream`.
///
/// * `None` yields an `Async::ready` representation with `present: false`,
///   an uninitialised payload, and `None` as the allocation front.
/// * `Some(value)` borrows `value` asynchronously, unwraps the resulting
///   `Async` into the inner representation plus an optional completion, and
///   wraps the representation as the payload with `present: true`. The
///   result is `Async::pending` (with `NoCompletion`) when the inner borrow
///   carried a completion, and `Async::ready` otherwise.
///
/// # Errors
///
/// Propagates errors from the inner `borrow_async`, from unwrapping the
/// inner `Async`, and from `Async::pending`.
#[cfg(feature = "host")]
unsafe fn borrow_async<'stream, A: CudaAlloc>(
&self,
alloc: A,
stream: crate::host::Stream<'stream>,
) -> CudaResult<(
Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
)> {
let (cuda_repr, alloc) = match self {
// Nothing to lend: the representation is available immediately.
None => (
Async::ready(
DeviceAccessible::from(OptionCudaRepresentation {
maybe: MaybeUninit::uninit(),
present: false,
}),
stream,
),
CombinedCudaAlloc::new(None, alloc),
),
Some(value) => {
let (cuda_repr, alloc) = value.borrow_async(alloc, stream)?;
// Take the inner representation out of its `Async` wrapper so it can
// be embedded in the option representation below.
// NOTE(review): the safety contract of `Async::unwrap_unchecked` is
// defined outside this file — presumably the value must not be used
// before the stream has caught up; confirm.
let (cuda_repr, completion) = unsafe { cuda_repr.unwrap_unchecked()? };
// Re-wrap the inner allocation front in `Some(_)` to match
// `Self::CudaAllocationAsync`.
let (alloc_front, alloc_tail) = alloc.split();
let alloc = CombinedCudaAlloc::new(Some(alloc_front), alloc_tail);
let option_cuda_repr = DeviceAccessible::from(OptionCudaRepresentation {
maybe: MaybeUninit::new(cuda_repr),
present: true,
});
// Carry the pending/ready state of the inner borrow over to the
// wrapped representation.
let r#async = if matches!(completion, Some(NoCompletion)) {
Async::pending(option_cuda_repr, stream, NoCompletion)?
} else {
Async::ready(option_cuda_repr, stream)
};
(r#async, alloc)
},
};
Ok((cuda_repr, alloc))
}
/// Asynchronously restores the option behind `this` on `stream` and returns
/// it together with the tail allocation.
///
/// The inner value is only restored when the option is `Some(_)` and the
/// allocation front is `Some(_)` as well; otherwise `this` is returned
/// unchanged as `Async::ready`. If the inner restore produces a completion
/// callback, the returned `Async` is pending and its callback forwards to
/// the inner callback.
///
/// # Errors
///
/// Propagates errors from the inner `restore_async`, from unwrapping the
/// inner `Async`, and from `Async::pending`.
#[cfg(feature = "host")]
unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
mut this: owning_ref::BoxRefMut<'a, O, Self>,
alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
stream: crate::host::Stream<'stream>,
) -> CudaResult<(
Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
A,
)> {
let (alloc_front, alloc_tail) = alloc.split();
if let (Some(_), Some(alloc_front)) = (&mut *this, alloc_front) {
// `map_mut` below consumes `this`, but the `Option`-typed handle has to
// be handed back to the caller. Keep a bitwise copy of the handle in a
// `ManuallyDrop` so that it is never dropped implicitly while the mapped
// handle is still alive.
let this_backup = unsafe { std::mem::ManuallyDrop::new(std::ptr::read(&this)) };
let (r#async, alloc_tail) = RustToCudaAsync::restore_async(
// SAFETY: the enclosing `if let` has just matched `*this` against
// `Some(_)`, so `as_mut()` returns `Some(_)` here.
this.map_mut(|value| unsafe { value.as_mut().unwrap_unchecked() }),
CombinedCudaAlloc::new(alloc_front, alloc_tail),
stream,
)?;
let (value, on_completion) = unsafe { r#async.unwrap_unchecked()? };
// `value` is the mapped handle and `this_backup` is a bitwise copy of the
// handle it was created from. Forget the former and revive the latter so
// that exactly one of the two is ever dropped.
// NOTE(review): on the `?` early returns above `this_backup` is never
// dropped — presumably the owner is then released through the mapped
// handle held by the callee; confirm.
std::mem::forget(value);
let this = std::mem::ManuallyDrop::into_inner(this_backup);
if let Some(on_completion) = on_completion {
// Adapt the inner completion callback, which expects the inner value, to
// a callback that takes the whole option.
let r#async = Async::<_, CompletionFnMut<'a, Self>>::pending(
this,
stream,
Box::new(|this: &mut Self| {
// Only forward when the option still holds a value at completion time.
if let Some(value) = this {
on_completion(value)?;
}
Ok(())
}),
)?;
Ok((r#async, alloc_tail))
} else {
let r#async = Async::ready(this, stream);
Ok((r#async, alloc_tail))
}
} else {
// Either the option is `None` or no inner allocation was recorded: there
// is nothing to restore.
let r#async = Async::ready(this, stream);
Ok((r#async, alloc_tail))
}
}
}
/// Converts an [`OptionCudaRepresentation`] back into an `Option` of the
/// payload's Rust representation.
unsafe impl<T: CudaAsRust> CudaAsRust for OptionCudaRepresentation<T> {
    type RustRepresentation = Option<<T as CudaAsRust>::RustRepresentation>;

    /// Returns `None` when the `present` flag is unset; otherwise converts
    /// the payload with its own `as_rust` and wraps the result in `Some(_)`.
    ///
    /// # Safety
    ///
    /// The payload is assumed to be initialised whenever `present` is `true`.
    #[cfg(feature = "device")]
    unsafe fn as_rust(this: &DeviceAccessible<Self>) -> Self::RustRepresentation {
        // The payload is only read once the flag says it is initialised.
        if !this.present {
            return None;
        }

        Some(CudaAsRust::as_rust(this.maybe.assume_init_ref()))
    }
}
/// Lets an `Option<T>` be viewed as, and converted back from, an `Option` of
/// the `RustToCudaWithPortableBitCopySemantics<T>` adapter.
impl<T: Copy + PortableBitSemantics + TypeGraphLayout> RustToCudaProxy<Option<T>>
    for Option<RustToCudaWithPortableBitCopySemantics<T>>
{
    /// Reinterprets a shared reference to `Option<T>` as a reference to the
    /// proxy type without copying.
    fn from_ref(val: &Option<T>) -> &Self {
        let proxy: *const Self = core::ptr::from_ref(val).cast();

        // SAFETY: `proxy` comes from a valid shared reference, so it is
        // non-null, aligned for `Option<T>`, and live for the returned
        // lifetime. NOTE(review): this additionally relies on `Option<T>` and
        // `Self` sharing one layout, i.e. on the adapter being a transparent
        // wrapper around `T` — it is defined outside this file; confirm.
        unsafe { &*proxy }
    }

    /// Reinterprets an exclusive reference to `Option<T>` as an exclusive
    /// reference to the proxy type without copying.
    fn from_mut(val: &mut Option<T>) -> &mut Self {
        let proxy: *mut Self = core::ptr::from_mut(val).cast();

        // SAFETY: `proxy` comes from a valid exclusive reference, so it is
        // non-null, aligned for `Option<T>`, and uniquely borrowed for the
        // returned lifetime. NOTE(review): this additionally relies on
        // `Option<T>` and `Self` sharing one layout, i.e. on the adapter being
        // a transparent wrapper around `T` — it is defined outside this file;
        // confirm.
        unsafe { &mut *proxy }
    }

    /// Unwraps the adapter around the contained value, if any.
    fn into(self) -> Option<T> {
        self.map(RustToCudaWithPortableBitCopySemantics::into_inner)
    }
}