use std::ops::{Deref, DerefMut};
use rustacuda::{
error::CudaResult,
memory::{AsyncCopyDestination, CopyDestination, DeviceBox, LockedBox},
};
use crate::{
alloc::{EmptyCudaAlloc, NoCudaAlloc},
host::{CudaDropWrapper, HostAndDeviceConstRef, HostAndDeviceMutRef, Stream},
lend::{RustToCuda, RustToCudaAsync},
safety::SafeMutableAliasing,
utils::{
adapter::DeviceCopyWithPortableBitSemantics,
ffi::DeviceAccessible,
r#async::{Async, AsyncProj, CompletionFnMut, NoCompletion},
},
};
/// A value that can be exchanged between the host and a CUDA device, in its
/// *on-host* state.
///
/// Alongside the host `value`, the wrapper keeps a pre-allocated device
/// buffer and a pre-allocated page-locked host staging buffer for the value's
/// CUDA representation, so moving the value to and from the device reuses the
/// same allocations instead of re-allocating on every transfer.
pub struct ExchangeWrapperOnHost<T: RustToCuda<CudaAllocation: EmptyCudaAlloc>> {
    /// The wrapped host value; boxed so ownership can be handed between the
    /// on-host and on-device wrapper states without moving `T` itself.
    value: Box<T>,
    /// Device-side buffer that receives the CUDA representation of `value`
    /// when it is moved to the device.
    device_box: CudaDropWrapper<
        DeviceBox<
            DeviceCopyWithPortableBitSemantics<
                DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
            >,
        >,
    >,
    /// Page-locked (pinned) host staging buffer for the CUDA representation,
    /// from which (async) copies into `device_box` are performed.
    locked_cuda_repr: CudaDropWrapper<
        LockedBox<
            DeviceCopyWithPortableBitSemantics<
                DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
            >,
        >,
    >,
}
/// A value that can be exchanged between the host and a CUDA device, in its
/// *on-device* state.
///
/// Mirrors [`ExchangeWrapperOnHost`] field-for-field; the two types exist so
/// the host/device location of the value is tracked in the type system, and
/// moving between them just transfers ownership of the same allocations.
pub struct ExchangeWrapperOnDevice<T: RustToCuda<CudaAllocation: EmptyCudaAlloc>> {
    /// The host-side value backing the device copy; retained so the wrapper
    /// can be restored into its on-host state.
    value: Box<T>,
    /// Device-side buffer currently holding the CUDA representation.
    device_box: CudaDropWrapper<
        DeviceBox<
            DeviceCopyWithPortableBitSemantics<
                DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
            >,
        >,
    >,
    /// Page-locked host staging buffer holding the last representation that
    /// was copied to `device_box`.
    locked_cuda_repr: CudaDropWrapper<
        LockedBox<
            DeviceCopyWithPortableBitSemantics<
                DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
            >,
        >,
    >,
}
impl<T: RustToCuda<CudaAllocation: EmptyCudaAlloc>> ExchangeWrapperOnHost<T> {
    /// Wraps `value` for host/device exchange, allocating the device buffer
    /// and the page-locked host staging buffer for its CUDA representation.
    ///
    /// # Errors
    /// Returns a CUDA error if either allocation fails or if borrowing the
    /// CUDA representation of `value` fails.
    pub fn new(value: T) -> CudaResult<Self> {
        // The device buffer starts uninitialised; it is first written by
        // `move_to_device` / `move_to_device_async` before any device use.
        let device_box = CudaDropWrapper::from(unsafe { DeviceBox::uninitialized() }?);
        let (cuda_repr, _null_alloc) = unsafe { value.borrow(NoCudaAlloc) }?;
        let locked_cuda_repr = unsafe {
            let mut uninit = CudaDropWrapper::from(LockedBox::<
                DeviceCopyWithPortableBitSemantics<
                    DeviceAccessible<<T as RustToCuda>::CudaRepresentation>,
                >,
            >::uninitialized()?);
            // Initialise the pinned staging buffer in place with the freshly
            // borrowed CUDA representation (write, not assign, since the
            // memory is still uninitialised at this point).
            uninit
                .as_mut_ptr()
                .write(DeviceCopyWithPortableBitSemantics::from(cuda_repr));
            uninit
        };
        Ok(Self {
            value: Box::new(value),
            device_box,
            locked_cuda_repr,
        })
    }

    /// Synchronously copies the CUDA representation of the wrapped value to
    /// the device and transitions the wrapper into its on-device state.
    ///
    /// # Errors
    /// Returns a CUDA error if borrowing the CUDA representation or the
    /// host-to-device copy fails.
    pub fn move_to_device(mut self) -> CudaResult<ExchangeWrapperOnDevice<T>> {
        let (cuda_repr, null_alloc) = unsafe { self.value.borrow(NoCudaAlloc) }?;
        // Refresh the pinned staging buffer, then copy it into the device box.
        **self.locked_cuda_repr = DeviceCopyWithPortableBitSemantics::from(cuda_repr);
        self.device_box.copy_from(&**self.locked_cuda_repr)?;
        // `CudaAllocation: EmptyCudaAlloc` guarantees the borrow produced no
        // extra allocation that would need to be kept alive.
        let _: NoCudaAlloc = null_alloc.into();
        Ok(ExchangeWrapperOnDevice {
            value: self.value,
            device_box: self.device_box,
            locked_cuda_repr: self.locked_cuda_repr,
        })
    }
}
impl<T: RustToCudaAsync<CudaAllocationAsync: EmptyCudaAlloc, CudaAllocation: EmptyCudaAlloc>>
    ExchangeWrapperOnHost<T>
{
    /// Asynchronously copies the CUDA representation of the wrapped value to
    /// the device on `stream`, returning an [`Async`] that resolves into the
    /// on-device wrapper.
    ///
    /// # Errors
    /// Returns a CUDA error if borrowing the CUDA representation or
    /// enqueueing the host-to-device copy fails.
    // The explicit `'stream` lifetime documents that the returned `Async` is
    // tied to the stream the copy was enqueued on.
    #[expect(clippy::needless_lifetimes)] pub fn move_to_device_async<'stream>(
        mut self,
        stream: Stream<'stream>,
    ) -> CudaResult<Async<'static, 'stream, ExchangeWrapperOnDevice<T>, NoCompletion>> {
        let (cuda_repr, _null_alloc) = unsafe { self.value.borrow_async(NoCudaAlloc, stream) }?;
        // NOTE(review): unwrapping without waiting here assumes the async
        // borrow's result may be used immediately on the same stream —
        // consistent with the copy below being enqueued on `stream`.
        let (cuda_repr, _completion): (_, Option<NoCompletion>) =
            unsafe { cuda_repr.unwrap_unchecked()? };
        // Stage the representation in the pinned buffer, then enqueue an
        // asynchronous copy into the device box on `stream`.
        **self.locked_cuda_repr = DeviceCopyWithPortableBitSemantics::from(cuda_repr);
        unsafe {
            self.device_box
                .async_copy_from(&*self.locked_cuda_repr, &stream)
        }?;
        // The wrapper is only safe to use once the stream copy completes, so
        // it is returned as a pending `Async` with no extra completion work.
        Async::pending(
            ExchangeWrapperOnDevice {
                value: self.value,
                device_box: self.device_box,
                locked_cuda_repr: self.locked_cuda_repr,
            },
            stream,
            NoCompletion,
        )
    }
}
impl<T: RustToCuda<CudaAllocation: EmptyCudaAlloc>> Deref for ExchangeWrapperOnHost<T> {
    type Target = T;

    /// Gives shared access to the wrapped host value.
    fn deref(&self) -> &Self::Target {
        &*self.value
    }
}
impl<T: RustToCuda<CudaAllocation: EmptyCudaAlloc>> DerefMut for ExchangeWrapperOnHost<T> {
    /// Gives exclusive access to the wrapped host value.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.value
    }
}
impl<T: RustToCuda<CudaAllocation: EmptyCudaAlloc>> ExchangeWrapperOnDevice<T> {
    /// Synchronously restores the wrapped value from the device and
    /// transitions the wrapper back into its on-host state.
    ///
    /// # Errors
    /// Returns a CUDA error if restoring the host value fails.
    pub fn move_to_host(mut self) -> CudaResult<ExchangeWrapperOnHost<T>> {
        let null_alloc = NoCudaAlloc.into();
        // `EmptyCudaAlloc` ensures `restore` hands back only the empty
        // allocation token, which can be discarded.
        let _null_alloc: NoCudaAlloc = unsafe { self.value.restore(null_alloc) }?;
        Ok(ExchangeWrapperOnHost {
            value: self.value,
            device_box: self.device_box,
            locked_cuda_repr: self.locked_cuda_repr,
        })
    }

    /// Returns a linked host/device shared reference to the device copy of
    /// the value's CUDA representation.
    #[must_use]
    pub fn as_ref(
        &self,
    ) -> HostAndDeviceConstRef<DeviceAccessible<<T as RustToCuda>::CudaRepresentation>> {
        // NOTE(review): soundness relies on `device_box` and
        // `locked_cuda_repr` holding the same representation, as established
        // by `move_to_device[_async]` — confirm no other writer exists.
        unsafe {
            HostAndDeviceConstRef::new_unchecked(
                &self.device_box,
                (**self.locked_cuda_repr).into_ref(),
            )
        }
    }
}
impl<T: RustToCudaAsync<CudaAllocationAsync: EmptyCudaAlloc, CudaAllocation: EmptyCudaAlloc>>
    ExchangeWrapperOnDevice<T>
{
    /// Asynchronously restores the wrapped value from the device on `stream`,
    /// returning an [`Async`] that resolves into the on-host wrapper.
    ///
    /// If the restore produces a completion callback, the returned `Async` is
    /// pending and runs that callback on the restored host value when it is
    /// completed; otherwise it is immediately ready.
    ///
    /// # Errors
    /// Returns a CUDA error if enqueueing or performing the restore fails.
    #[expect(clippy::needless_lifetimes)] pub fn move_to_host_async<'stream>(
        self,
        stream: Stream<'stream>,
    ) -> CudaResult<
        Async<
            'static,
            'stream,
            ExchangeWrapperOnHost<T>,
            CompletionFnMut<'static, ExchangeWrapperOnHost<T>>,
        >,
    > {
        let null_alloc = NoCudaAlloc.into();
        // `restore_async` takes ownership of the boxed value through an
        // owning mutable reference and hands it back afterwards.
        let value = owning_ref::BoxRefMut::new(self.value);
        let (r#async, _null_alloc): (_, NoCudaAlloc) =
            unsafe { RustToCudaAsync::restore_async(value, null_alloc, stream) }?;
        let (value, on_complete) = unsafe { r#async.unwrap_unchecked()? };
        let value = value.into_owner();
        if let Some(on_complete) = on_complete {
            // Restore is still in flight: completing the returned `Async`
            // must first run `on_complete` on the restored host value.
            Async::<_, CompletionFnMut<ExchangeWrapperOnHost<T>>>::pending(
                ExchangeWrapperOnHost {
                    value,
                    device_box: self.device_box,
                    locked_cuda_repr: self.locked_cuda_repr,
                },
                stream,
                Box::new(|on_host: &mut ExchangeWrapperOnHost<T>| on_complete(&mut on_host.value)),
            )
        } else {
            // No completion callback was produced: the value is already
            // usable on the host.
            Ok(Async::ready(
                ExchangeWrapperOnHost {
                    value,
                    device_box: self.device_box,
                    locked_cuda_repr: self.locked_cuda_repr,
                },
                stream,
            ))
        }
    }
}
impl<
    'a,
    'stream,
    T: RustToCudaAsync<CudaAllocationAsync: EmptyCudaAlloc, CudaAllocation: EmptyCudaAlloc>,
> Async<'a, 'stream, ExchangeWrapperOnDevice<T>, NoCompletion>
{
    /// Asynchronously restores the wrapped value from the device on `stream`,
    /// chaining onto this still-pending move-to-device `Async`.
    ///
    /// The returned `Async` is pending if the restore produced a completion
    /// callback, or if this `Async` itself still had a pending
    /// [`NoCompletion`] to wait on; otherwise it is immediately ready.
    ///
    /// # Errors
    /// Returns a CUDA error if unwrapping this `Async` or enqueueing the
    /// restore fails.
    pub fn move_to_host_async(
        self,
        stream: Stream<'stream>,
    ) -> CudaResult<
        Async<
            'static,
            'stream,
            ExchangeWrapperOnHost<T>,
            CompletionFnMut<'static, ExchangeWrapperOnHost<T>>,
        >,
    > {
        // NOTE(review): `completion` is remembered below so that a pending
        // predecessor still forces the result to be treated as pending.
        let (this, completion): (_, Option<NoCompletion>) = unsafe { self.unwrap_unchecked()? };
        let null_alloc = NoCudaAlloc.into();
        // Hand the boxed value to `restore_async` via an owning mutable
        // reference; it is returned once the restore has been enqueued.
        let value = owning_ref::BoxRefMut::new(this.value);
        let (r#async, _null_alloc): (_, NoCudaAlloc) =
            unsafe { RustToCudaAsync::restore_async(value, null_alloc, stream) }?;
        let (value, on_complete) = unsafe { r#async.unwrap_unchecked()? };
        let value = value.into_owner();
        let on_host = ExchangeWrapperOnHost {
            value,
            device_box: this.device_box,
            locked_cuda_repr: this.locked_cuda_repr,
        };
        if let Some(on_complete) = on_complete {
            // The restore supplied a completion callback: run it on the
            // restored host value when the `Async` is completed.
            Async::<_, CompletionFnMut<ExchangeWrapperOnHost<T>>>::pending(
                on_host,
                stream,
                Box::new(|on_host: &mut ExchangeWrapperOnHost<T>| on_complete(&mut on_host.value)),
            )
        } else if matches!(completion, Some(NoCompletion)) {
            // The predecessor `Async` was still pending: stay pending with a
            // no-op completion so the stream is still waited on.
            Async::<_, CompletionFnMut<ExchangeWrapperOnHost<T>>>::pending(
                on_host,
                stream,
                Box::new(|_on_host: &mut ExchangeWrapperOnHost<T>| Ok(())),
            )
        } else {
            Ok(Async::ready(on_host, stream))
        }
    }

    /// Returns an async-projected, linked host/device shared reference to the
    /// device copy of the value's CUDA representation.
    #[must_use]
    pub fn as_ref_async(
        &self,
    ) -> AsyncProj<
        '_,
        'stream,
        HostAndDeviceConstRef<DeviceAccessible<<T as RustToCuda>::CudaRepresentation>>,
    > {
        let this = unsafe { self.as_ref().unwrap_unchecked() };
        // Shared projection: no completion-use callback is needed (`None`).
        unsafe {
            AsyncProj::new(
                HostAndDeviceConstRef::new_unchecked(
                    &*(this.device_box),
                    (**(this.locked_cuda_repr)).into_ref(),
                ),
                None,
            )
        }
    }

    /// Returns an async-projected, linked host/device exclusive reference to
    /// the device copy of the value's CUDA representation.
    ///
    /// Only available when `T: SafeMutableAliasing`, i.e. when mutable
    /// aliasing of the shared representation is declared safe.
    #[must_use]
    pub fn as_mut_async(
        &mut self,
    ) -> AsyncProj<
        '_,
        'stream,
        HostAndDeviceMutRef<'_, DeviceAccessible<<T as RustToCuda>::CudaRepresentation>>,
    >
    where
        T: SafeMutableAliasing,
    {
        // Exclusive projection: also obtain the use-callback so the `Async`
        // is notified when the mutable projection is actually used.
        let (this, use_callback) = unsafe { self.as_mut().unwrap_unchecked_with_use() };
        unsafe {
            AsyncProj::new(
                HostAndDeviceMutRef::new_unchecked(
                    &mut *(this.device_box),
                    (**(this.locked_cuda_repr)).into_mut(),
                ),
                use_callback,
            )
        }
    }
}