#[cfg(feature = "host")]
use std::{borrow::BorrowMut, future::Future, future::IntoFuture, marker::PhantomData, task::Poll};
#[cfg(feature = "host")]
use rustacuda::{
error::CudaError, error::CudaResult, event::Event, event::EventFlags,
stream::StreamWaitEventFlags,
};
#[cfg(feature = "host")]
use crate::host::{CudaDropWrapper, Stream};
/// Completion marker for asynchronous computations that need no follow-up
/// work once they have finished.
///
/// Its [`Completion`] implementation always succeeds and never requests
/// synchronisation when a pending computation is dropped.
#[cfg(feature = "host")]
pub struct NoCompletion;
/// Boxed completion callback that is handed mutable access to the completed
/// value of type `T`.
///
/// Note that, despite the alias's name, the boxed closure is an [`FnOnce`]
/// and is thus consumed when it is called.
#[cfg(feature = "host")]
pub type CompletionFnMut<'a, T> = Box<dyn FnOnce(&mut T) -> CudaResult<()> + 'a>;
/// Action that is run on a value of type `T` once the asynchronous
/// computation producing it has finished.
///
/// This trait is sealed: its `Sealed` supertrait lives in a private module,
/// so it can only be implemented inside this module.
#[cfg(feature = "host")]
pub trait Completion<T: ?Sized + BorrowMut<Self::Completed>>: sealed::Sealed {
    /// The type that `T` is mutably borrowed as when the completion is run.
    type Completed: ?Sized;

    /// Creates a completion that does nothing when it is run.
    fn no_op() -> Self;

    /// Whether dropping a still-pending computation should first block on it
    /// so that this completion can still be run.
    #[doc(hidden)]
    fn synchronize_on_drop(&self) -> bool;

    /// Consumes this completion and runs it on the `completed` value.
    ///
    /// # Errors
    /// Returns a [`rustacuda::error::CudaError`] iff running the completion
    /// fails.
    fn complete(self, completed: &mut Self::Completed) -> CudaResult<()>;
}
// Private helper module: because `sealed` is not `pub`, the `Sealed` trait
// cannot be named (and thus not implemented) outside of the parent module.
#[cfg(feature = "host")]
mod sealed {
/// Marker supertrait that restricts which types may implement `Completion`.
pub trait Sealed {}
}
// The trivial completion: the completed value is left untouched.
#[cfg(feature = "host")]
impl<T: ?Sized> Completion<T> for NoCompletion {
    type Completed = T;

    /// The marker itself already is the no-op completion.
    #[inline]
    fn no_op() -> Self {
        NoCompletion
    }

    /// Nothing has to run on completion, so dropping never needs to block.
    #[inline]
    fn synchronize_on_drop(&self) -> bool {
        false
    }

    /// Succeeds immediately without looking at the completed value.
    #[inline]
    fn complete(self, _untouched: &mut Self::Completed) -> CudaResult<()> {
        Ok(())
    }
}
// Allows `NoCompletion` to implement the sealed `Completion` trait.
#[cfg(feature = "host")]
impl sealed::Sealed for NoCompletion {}
// A boxed callback completes a value `T` by being called on its `B` borrow.
#[cfg(feature = "host")]
impl<'a, T: ?Sized + BorrowMut<B>, B: ?Sized> Completion<T> for CompletionFnMut<'a, B> {
    type Completed = B;

    /// Boxes a callback that ignores the value and reports success.
    #[inline]
    fn no_op() -> Self {
        Box::new(|_ignored| Ok(()))
    }

    /// A callback is always expected to run, so a pending computation has to
    /// be synchronised on when it is dropped.
    #[inline]
    fn synchronize_on_drop(&self) -> bool {
        true
    }

    /// Calls the boxed callback on the completed value and forwards its
    /// result.
    #[inline]
    fn complete(self, completed: &mut Self::Completed) -> CudaResult<()> {
        let callback = self;
        callback(completed)
    }
}
// Allows boxed completion callbacks to implement the sealed `Completion` trait.
#[cfg(feature = "host")]
impl<'a, T: ?Sized> sealed::Sealed for CompletionFnMut<'a, T> {}
// An optional completion behaves like its inner completion when present and
// like a no-op when absent.
#[cfg(feature = "host")]
impl<T: ?Sized + BorrowMut<C::Completed>, C: Completion<T>> Completion<T> for Option<C> {
    type Completed = C::Completed;

    /// The absent completion is the no-op.
    #[inline]
    fn no_op() -> Self {
        None
    }

    /// Defers to the inner completion; an absent completion never requires
    /// synchronisation.
    #[inline]
    fn synchronize_on_drop(&self) -> bool {
        match self {
            Some(inner) => inner.synchronize_on_drop(),
            None => false,
        }
    }

    /// Runs the inner completion if there is one, and succeeds otherwise.
    #[inline]
    fn complete(self, completed: &mut Self::Completed) -> CudaResult<()> {
        match self {
            Some(inner) => inner.complete(completed),
            None => Ok(()),
        }
    }
}
// Allows any `Option<C>` to implement the sealed `Completion` trait; the
// `Completion` impl itself further requires the inner `C` to be a completion.
#[cfg(feature = "host")]
impl<C> sealed::Sealed for Option<C> {}
/// A value of type `T` whose computation on a [`Stream`] may still be
/// pending.
///
/// Once the computation has finished, the completion of type `C` (by default
/// [`NoCompletion`]) is run on the value before it is handed back.
#[cfg(feature = "host")]
pub struct Async<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T> = NoCompletion> {
/// The stream that the value's work has been submitted to.
stream: Stream<'stream>,
/// The wrapped value.
value: T,
/// Whether the computation is still processing or has already completed.
status: AsyncStatus<'a, T, C>,
/// Ties the otherwise unused `'a` lifetime to this type without storing data.
_capture: PhantomData<&'a ()>,
}
/// Internal state of an asynchronous computation.
#[cfg(feature = "host")]
enum AsyncStatus<'a, T: BorrowMut<C::Completed>, C: Completion<T>> {
/// The computation has been submitted and its result has not yet been
/// consumed.
Processing {
/// Receives the result that the stream callback is invoked with.
receiver: oneshot::Receiver<CudaResult<()>>,
/// The completion that is to be run on the value once the result arrived.
completion: C,
/// Event recorded for this computation, if any. It is only stored and
/// never read in this file.
// NOTE(review): presumably held to keep the event alive while the work is
// in flight - confirm.
event: Option<CudaDropWrapper<Event>>,
/// Ties the `'a` and `T` parameters to this variant without storing data.
_capture: PhantomData<&'a T>,
},
/// The computation has finished with the stored `result`.
Completed {
result: CudaResult<()>,
},
}
#[cfg(feature = "host")]
impl<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> Async<'a, 'stream, T, C> {
    /// Wraps a `value` that is already ready, i.e. which has no pending work
    /// on `stream`.
    #[must_use]
    pub const fn ready(value: T, stream: Stream<'stream>) -> Self {
        Self {
            stream,
            value,
            status: AsyncStatus::Completed { result: Ok(()) },
            _capture: PhantomData::<&'a ()>,
        }
    }

    /// Wraps a still-pending `value` whose work has been submitted to
    /// `stream`, and on which `completion` is run once it can be handed back.
    ///
    /// A callback is added to `stream` which forwards the result that it is
    /// invoked with to this [`Async`].
    ///
    /// # Errors
    /// Returns a [`rustacuda::error::CudaError`] iff adding the callback to
    /// `stream` fails.
    pub fn pending(value: T, stream: Stream<'stream>, completion: C) -> CudaResult<Self> {
        let (sender, receiver) = oneshot::channel();
        stream.add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?;

        Ok(Self {
            stream,
            value,
            status: AsyncStatus::Processing {
                receiver,
                completion,
                event: None,
                _capture: PhantomData::<&'a T>,
            },
            _capture: PhantomData::<&'a ()>,
        })
    }

    /// Blocks until the computation has completed, runs its completion, and
    /// returns the inner value.
    ///
    /// If the computation has already completed, its stored result is returned
    /// without blocking.
    ///
    /// # Errors
    /// Returns a [`rustacuda::error::CudaError`] iff the computation or its
    /// completion failed, or [`CudaError::AlreadyAcquired`] if the stream
    /// callback was dropped without ever reporting a result.
    pub fn synchronize(self) -> CudaResult<T> {
        let (_stream, mut value, status) = self.destructure_into_parts();

        let (receiver, completion) = match status {
            AsyncStatus::Completed { result } => return result.map(|()| value),
            AsyncStatus::Processing {
                receiver,
                completion,
                event: _,
                _capture,
            } => (receiver, completion),
        };

        match receiver.recv() {
            Ok(Ok(())) => (),
            Ok(Err(err)) => return Err(err),
            Err(oneshot::RecvError) => return Err(CudaError::AlreadyAcquired),
        }

        completion.complete(value.borrow_mut())?;

        Ok(value)
    }

    /// Moves the asynchronous computation to a different `stream`.
    ///
    /// If the computation is found to have already finished (checked without
    /// blocking), its completion is run right away. In any case, an event is
    /// recorded on the old stream and the new `stream` is made to wait on it
    /// before a new callback is added to the new `stream`.
    ///
    /// # Errors
    /// Returns a [`rustacuda::error::CudaError`] iff the computation or its
    /// completion failed, an error occurs inside CUDA, or
    /// [`CudaError::AlreadyAcquired`] if the old stream callback was dropped
    /// without ever reporting a result.
    pub fn move_to_stream<'stream_new>(
        self,
        stream: Stream<'stream_new>,
    ) -> CudaResult<Async<'a, 'stream_new, T, C>> {
        let (old_stream, mut value, status) = self.destructure_into_parts();

        let completion = match status {
            AsyncStatus::Completed { result } => {
                result?;
                C::no_op()
            },
            AsyncStatus::Processing {
                receiver,
                completion,
                event: _,
                _capture,
            } => match receiver.try_recv() {
                Ok(Ok(())) => {
                    completion.complete(value.borrow_mut())?;
                    C::no_op()
                },
                Ok(Err(err)) => return Err(err),
                // Still in flight: carry the completion over to the new stream
                Err(oneshot::TryRecvError::Empty) => completion,
                Err(oneshot::TryRecvError::Disconnected) => return Err(CudaError::AlreadyAcquired),
            },
        };

        // Make the new stream wait for everything enqueued on the old stream
        let event = CudaDropWrapper::from(Event::new(EventFlags::DISABLE_TIMING)?);
        event.record(&old_stream)?;
        stream.wait_event(&event, StreamWaitEventFlags::DEFAULT)?;

        let (sender, receiver) = oneshot::channel();
        stream.add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?;

        Ok(Async {
            stream,
            value,
            status: AsyncStatus::Processing {
                receiver,
                completion,
                event: Some(event),
                _capture: PhantomData::<&'a T>,
            },
            _capture: PhantomData::<&'a ()>,
        })
    }

    /// Unwraps the inner value without waiting for its computation to
    /// complete.
    ///
    /// If the computation is still pending, its not-yet-run completion is
    /// returned alongside the value.
    ///
    /// # Safety
    /// No synchronisation is performed, so the returned value may not yet have
    /// completed its asynchronous work. The caller must ensure that the value
    /// is only used in ways that cannot race with this pending work.
    ///
    /// # Errors
    /// Returns the [`rustacuda::error::CudaError`] that the computation has
    /// already completed with, if any.
    pub unsafe fn unwrap_unchecked(self) -> CudaResult<(T, Option<C>)> {
        let (_stream, value, status) = self.destructure_into_parts();

        match status {
            AsyncStatus::Completed { result: Ok(()) } => Ok((value, None)),
            AsyncStatus::Completed { result: Err(err) } => Err(err),
            AsyncStatus::Processing {
                receiver: _,
                completion,
                event: _,
                _capture,
            } => Ok((value, Some(completion))),
        }
    }

    /// Projects to a shared reference to the inner value, wrapped in an
    /// [`AsyncProj`] without a use callback.
    pub const fn as_ref(&self) -> AsyncProj<'_, 'stream, &T> {
        // Safety: the projection borrows this async for as long as it lives
        unsafe { AsyncProj::new(&self.value, None) }
    }

    /// Projects to a mutable reference to the inner value, wrapped in an
    /// [`AsyncProj`].
    ///
    /// The projection carries a use callback which, each time it is invoked,
    /// adds a new callback to this computation's stream, records a new event
    /// on it, and marks this [`Async`] as processing again.
    pub fn as_mut(&mut self) -> AsyncProj<'_, 'stream, &mut T> {
        // Safety: the projection mutably borrows this async for as long as it
        //         lives
        unsafe {
            AsyncProj::new(
                &mut self.value,
                Some(Box::new(|| {
                    // A computation that has already failed is not re-armed
                    if let AsyncStatus::Completed { result } = &self.status {
                        (*result)?;
                    }

                    // All fallible CUDA calls are made *before* the status is
                    // touched: if one of them fails, the pending completion
                    // stays in place instead of being swapped for a no-op and
                    // silently dropped on the early return
                    let event = CudaDropWrapper::from(Event::new(EventFlags::DISABLE_TIMING)?);

                    let (sender, receiver) = oneshot::channel();

                    self.stream
                        .add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?;
                    event.record(&self.stream)?;

                    // Carry a still-pending completion over to the new status
                    let completion = match &mut self.status {
                        AsyncStatus::Completed { result: _ } => C::no_op(),
                        AsyncStatus::Processing {
                            receiver: _,
                            completion,
                            event: _,
                            _capture,
                        } => std::mem::replace(completion, C::no_op()),
                    };

                    self.status = AsyncStatus::Processing {
                        receiver,
                        completion,
                        event: Some(event),
                        _capture: PhantomData::<&'a T>,
                    };

                    Ok(())
                })),
            )
        }
    }

    /// Splits `self` into its stream, value, and status without running the
    /// [`Drop`] implementation of [`Async`].
    #[must_use]
    fn destructure_into_parts(self) -> (Stream<'stream>, T, AsyncStatus<'a, T, C>) {
        // Wrapping in `ManuallyDrop` ensures that `self` itself is never
        // dropped, so the fields read out below are not dropped twice
        let this = std::mem::ManuallyDrop::new(self);

        let stream = this.stream;
        // Safety: `this` is never dropped or used again, so ownership of the
        //         value is moved out exactly once
        let value = unsafe { std::ptr::read(&this.value) };
        // Safety: as above, ownership of the status is moved out exactly once
        let status = unsafe { std::ptr::read(&this.status) };

        (stream, value, status)
    }
}
// Projection helper for asynchronous `HostAndDeviceConstRef`s.
#[cfg(feature = "host")]
impl<
'a,
'stream,
T: crate::safety::PortableBitSemantics + const_type_layout::TypeGraphLayout,
C: Completion<crate::host::HostAndDeviceConstRef<'a, T>>,
> Async<'a, 'stream, crate::host::HostAndDeviceConstRef<'a, T>, C>
where
crate::host::HostAndDeviceConstRef<'a, T>: BorrowMut<C::Completed>,
{
/// Projects to the result of calling `as_ref` on the inner
/// `HostAndDeviceConstRef`, wrapped in an [`AsyncProj`] without a use
/// callback.
pub const fn extract_ref(
&self,
) -> AsyncProj<'_, 'stream, crate::host::HostAndDeviceConstRef<'_, T>> {
// Safety: the projection borrows this async for as long as it lives
unsafe { AsyncProj::new(self.value.as_ref(), None) }
}
}
// Projection helpers for asynchronous `HostAndDeviceMutRef`s.
#[cfg(feature = "host")]
impl<
        'a,
        'stream,
        T: crate::safety::PortableBitSemantics + const_type_layout::TypeGraphLayout,
        C: Completion<crate::host::HostAndDeviceMutRef<'a, T>>,
    > Async<'a, 'stream, crate::host::HostAndDeviceMutRef<'a, T>, C>
where
    crate::host::HostAndDeviceMutRef<'a, T>: BorrowMut<C::Completed>,
{
    /// Projects to the result of calling `as_ref` on the inner
    /// `HostAndDeviceMutRef`, wrapped in an [`AsyncProj`] without a use
    /// callback.
    pub fn extract_ref(&self) -> AsyncProj<'_, 'stream, crate::host::HostAndDeviceConstRef<'_, T>> {
        // Safety: the projection borrows this async for as long as it lives
        unsafe { AsyncProj::new(self.value.as_ref(), None) }
    }

    /// Projects to the result of calling `as_mut` on the inner
    /// `HostAndDeviceMutRef`, wrapped in an [`AsyncProj`].
    ///
    /// The projection carries a use callback which, each time it is invoked,
    /// adds a new callback to this computation's stream, records a new event
    /// on it, and marks this [`Async`] as processing again.
    pub fn extract_mut(
        &mut self,
    ) -> AsyncProj<'_, 'stream, crate::host::HostAndDeviceMutRef<'_, T>> {
        // Safety: the projection mutably borrows this async for as long as it
        //         lives
        unsafe {
            AsyncProj::new(
                self.value.as_mut(),
                Some(Box::new(|| {
                    // A computation that has already failed is not re-armed
                    if let AsyncStatus::Completed { result } = &self.status {
                        (*result)?;
                    }

                    // All fallible CUDA calls are made *before* the status is
                    // touched: if one of them fails, the pending completion
                    // stays in place instead of being swapped for a no-op and
                    // silently dropped on the early return
                    let event = CudaDropWrapper::from(Event::new(EventFlags::DISABLE_TIMING)?);

                    let (sender, receiver) = oneshot::channel();

                    self.stream
                        .add_callback(Box::new(|result| std::mem::drop(sender.send(result))))?;
                    event.record(&self.stream)?;

                    // Carry a still-pending completion over to the new status
                    let completion = match &mut self.status {
                        AsyncStatus::Completed { result: _ } => C::no_op(),
                        AsyncStatus::Processing {
                            receiver: _,
                            completion,
                            event: _,
                            _capture,
                        } => std::mem::replace(completion, C::no_op()),
                    };

                    self.status = AsyncStatus::Processing {
                        receiver,
                        completion,
                        event: Some(event),
                        _capture: PhantomData,
                    };

                    Ok(())
                })),
            )
        }
    }
}
// Dropping a still-processing computation blocks on it iff its completion
// requests synchronisation on drop.
#[cfg(feature = "host")]
impl<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> Drop for Async<'a, 'stream, T, C> {
fn drop(&mut self) {
// Take the status out, leaving a completed placeholder behind; there is
// nothing left to do unless the computation was still processing
let AsyncStatus::Processing {
receiver,
completion,
event: _,
_capture,
} = std::mem::replace(&mut self.status, AsyncStatus::Completed { result: Ok(()) })
else {
return;
};
// Only block on the stream callback if the completion asks for it, and only
// run the completion if the computation succeeded; errors cannot be
// reported from `drop` and are discarded
if completion.synchronize_on_drop() && receiver.recv() == Ok(Ok(())) {
let _ = completion.complete(self.value.borrow_mut());
}
}
}
/// Future that an [`Async`] is converted into when it is awaited.
///
/// The completion is stored separately from the status, which therefore
/// always uses [`NoCompletion`].
#[cfg(feature = "host")]
struct AsyncFuture<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> {
/// Ties the `'stream` lifetime to this future without storing the stream.
_stream: PhantomData<Stream<'stream>>,
/// The inner value, which is taken out once the future resolves successfully.
value: Option<T>,
/// The completion that is still to be run on the value, if any.
completion: Option<C>,
/// Whether the computation is still processing or has already completed.
status: AsyncStatus<'a, T, NoCompletion>,
}
// Resolves to the inner value once the stream callback has reported a result.
#[cfg(feature = "host")]
impl<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> Future
for AsyncFuture<'a, 'stream, T, C>
{
type Output = CudaResult<T>;
/// Polls the receiver of the stream callback's result and, once the
/// computation has succeeded, runs the completion and yields the value.
///
/// Resolves to an error if the computation or its completion failed, or to
/// [`CudaError::AlreadyAcquired`] if the callback's sender was dropped
/// without reporting a result or the value has already been taken out.
fn poll(
self: core::pin::Pin<&mut Self>,
cx: &mut core::task::Context<'_>,
) -> Poll<Self::Output> {
// Safety: no field is handed out as pinned here: the receiver is re-pinned
//         with `Pin::new`, which requires it to be `Unpin`, and the value
//         and completion are only moved out of their `Option`s
let this = unsafe { self.get_unchecked_mut() };
match &mut this.status {
AsyncStatus::Processing {
receiver,
completion: _,
event: _,
_capture,
} => match std::pin::Pin::new(receiver).poll(cx) {
Poll::Ready(Ok(Ok(()))) => (),
Poll::Ready(Ok(Err(err))) => return Poll::Ready(Err(err)),
Poll::Ready(Err(oneshot::RecvError)) => {
return Poll::Ready(Err(CudaError::AlreadyAcquired))
},
Poll::Pending => return Poll::Pending,
},
AsyncStatus::Completed { result: Ok(()) } => (),
AsyncStatus::Completed { result: Err(err) } => return Poll::Ready(Err(*err)),
}
// The computation has succeeded: hand out the value, but only once
let Some(mut value) = this.value.take() else {
return Poll::Ready(Err(CudaError::AlreadyAcquired));
};
// Run the pending completion, if any, before yielding the value
if let Some(completion) = this.completion.take() {
completion.complete(value.borrow_mut())?;
}
Poll::Ready(Ok(value))
}
}
// Awaiting an `Async` resolves to its inner value once the computation has
// finished and its completion has been run.
#[cfg(feature = "host")]
impl<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> IntoFuture
    for Async<'a, 'stream, T, C>
{
    type Output = CudaResult<T>;
    type IntoFuture = impl Future<Output = Self::Output>;

    /// Splits the completion off from the status and moves both, together
    /// with the value, into the future that drives the computation.
    fn into_future(self) -> Self::IntoFuture {
        let (_stream, value, status) = self.destructure_into_parts();

        // The future stores the completion separately, so the status is
        // rebuilt with `NoCompletion` in its place
        let (completion, status): (Option<C>, AsyncStatus<'a, T, NoCompletion>) = match status {
            AsyncStatus::Processing {
                receiver,
                completion,
                event,
                _capture,
            } => {
                let in_flight = AsyncStatus::Processing::<T, NoCompletion> {
                    receiver,
                    completion: NoCompletion,
                    event,
                    _capture: PhantomData::<&'a T>,
                };
                (Some(completion), in_flight)
            },
            AsyncStatus::Completed { result } => {
                let finished = AsyncStatus::Completed::<T, NoCompletion> { result };
                (None, finished)
            },
        };

        AsyncFuture {
            _stream: PhantomData::<Stream<'stream>>,
            value: Some(value),
            completion,
            status,
        }
    }
}
// Dropping a future whose computation is still processing blocks on it iff
// its completion requests synchronisation on drop.
#[cfg(feature = "host")]
impl<'a, 'stream, T: BorrowMut<C::Completed>, C: Completion<T>> Drop
for AsyncFuture<'a, 'stream, T, C>
{
fn drop(&mut self) {
// Nothing to complete if the value has already been handed out by `poll`
let Some(mut value) = self.value.take() else {
return;
};
// Take the status out, leaving a completed placeholder behind; there is
// nothing left to do unless the computation was still processing
let AsyncStatus::Processing {
receiver,
completion: NoCompletion,
event: _,
_capture,
} = std::mem::replace(&mut self.status, AsyncStatus::Completed { result: Ok(()) })
else {
return;
};
// Without a pending completion there is nothing to run
let Some(completion) = self.completion.take() else {
return;
};
// Only block on the stream callback if the completion asks for it, and only
// run the completion if the computation succeeded; errors cannot be
// reported from `drop` and are discarded
if completion.synchronize_on_drop() && receiver.recv() == Ok(Ok(())) {
let _ = completion.complete(value.borrow_mut());
}
}
}
/// Projection of an asynchronous value of type `T`, which optionally carries
/// a callback that is invoked to record a mutable use of the value.
#[cfg(feature = "host")]
#[expect(clippy::module_name_repetitions)]
pub struct AsyncProj<'a, 'stream, T: 'a> {
/// Ties the `'a` lifetime to this projection without storing data.
_capture: PhantomData<&'a ()>,
/// Ties the `'stream` lifetime to this projection without storing the stream.
_stream: PhantomData<Stream<'stream>>,
/// The projected value.
value: T,
/// Callback invoked by `record_mut_use`; `None` if no use has to be recorded.
use_callback: Option<Box<dyn FnMut() -> CudaResult<()> + 'a>>,
}
#[cfg(feature = "host")]
impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, T> {
    /// Wraps `value` as a projection with an optional `use_callback`.
    ///
    /// # Safety
    /// No synchronisation is performed here. In this file, projections are
    /// only created by borrowing from an `Async`.
    // NOTE(review): other callers presumably have to tie the projection to the
    // value's stream in an equivalent way - confirm against the call sites.
    #[must_use]
    pub(crate) const unsafe fn new(
        value: T,
        use_callback: Option<Box<dyn FnMut() -> CudaResult<()> + 'a>>,
    ) -> Self {
        AsyncProj {
            value,
            use_callback,
            _stream: PhantomData,
            _capture: PhantomData,
        }
    }

    /// Unwraps the projected value, discarding any use callback.
    ///
    /// # Safety
    /// No synchronisation is performed, so the returned value may not yet have
    /// completed its asynchronous work.
    pub(crate) unsafe fn unwrap_unchecked(self) -> T {
        let Self { value, .. } = self;
        value
    }

    /// Unwraps the projected value together with its use callback, if any.
    ///
    /// # Safety
    /// No synchronisation is performed, so the returned value may not yet have
    /// completed its asynchronous work.
    #[expect(clippy::type_complexity)]
    pub(crate) unsafe fn unwrap_unchecked_with_use(
        self,
    ) -> (T, Option<Box<dyn FnMut() -> CudaResult<()> + 'a>>) {
        let Self {
            value,
            use_callback,
            ..
        } = self;
        (value, use_callback)
    }
}
#[cfg(feature = "host")]
impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, T> {
    /// Projects to a shared reference to the value; the new projection has no
    /// use callback.
    #[must_use]
    pub const fn proj_ref<'b>(&'b self) -> AsyncProj<'b, 'stream, &'b T>
    where
        'a: 'b,
    {
        let value: &'b T = &self.value;

        AsyncProj {
            value,
            use_callback: None,
            _stream: PhantomData,
            _capture: PhantomData,
        }
    }

    /// Projects to a mutable reference to the value; the new projection's use
    /// callback, if any, forwards to the one of this projection.
    #[must_use]
    pub fn proj_mut<'b>(&'b mut self) -> AsyncProj<'b, 'stream, &'b mut T>
    where
        'a: 'b,
    {
        // Re-box a mutable borrow of this projection's callback so that the
        // new projection can invoke it without taking ownership
        let use_callback = match self.use_callback.as_mut() {
            Some(parent) => {
                let forwarded: Box<dyn FnMut() -> CudaResult<()>> = Box::new(parent);
                Some(forwarded)
            },
            None => None,
        };

        AsyncProj {
            value: &mut self.value,
            use_callback,
            _stream: PhantomData,
            _capture: PhantomData,
        }
    }

    /// Invokes the use callback, if there is one, and forwards its result;
    /// succeeds without doing anything otherwise.
    pub(crate) fn record_mut_use(&mut self) -> CudaResult<()> {
        match self.use_callback.as_mut() {
            Some(record) => record(),
            None => Ok(()),
        }
    }
}
#[cfg(feature = "host")]
impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, &'a T> {
    /// Re-projects the shared reference for the shorter lifetime `'b`; the new
    /// projection has no use callback.
    #[must_use]
    pub const fn as_ref<'b>(&'b self) -> AsyncProj<'b, 'stream, &'b T>
    where
        'a: 'b,
    {
        let value: &'b T = self.value;

        AsyncProj {
            value,
            use_callback: None,
            _stream: PhantomData,
            _capture: PhantomData,
        }
    }

    /// Returns the projected shared reference.
    ///
    /// # Safety
    /// No synchronisation is performed, so the referenced value may not yet
    /// have completed its asynchronous work.
    pub(crate) const unsafe fn unwrap_ref_unchecked(&self) -> &T {
        let value: &T = self.value;
        value
    }
}
#[cfg(feature = "host")]
impl<'a, 'stream, T: 'a> AsyncProj<'a, 'stream, &'a mut T> {
    /// Reborrows the mutable reference as a shared one; the new projection has
    /// no use callback.
    #[must_use]
    pub fn as_ref<'b>(&'b self) -> AsyncProj<'b, 'stream, &'b T>
    where
        'a: 'b,
    {
        AsyncProj {
            value: &*self.value,
            use_callback: None,
            _stream: PhantomData,
            _capture: PhantomData,
        }
    }

    /// Reborrows the mutable reference for the shorter lifetime `'b`; the new
    /// projection's use callback, if any, forwards to the one of this
    /// projection.
    #[must_use]
    pub fn as_mut<'b>(&'b mut self) -> AsyncProj<'b, 'stream, &'b mut T>
    where
        'a: 'b,
    {
        // Re-box a mutable borrow of this projection's callback so that the
        // new projection can invoke it without taking ownership
        let use_callback = match self.use_callback.as_mut() {
            Some(parent) => {
                let forwarded: Box<dyn FnMut() -> CudaResult<()>> = Box::new(parent);
                Some(forwarded)
            },
            None => None,
        };

        AsyncProj {
            value: &mut *self.value,
            use_callback,
            _stream: PhantomData,
            _capture: PhantomData,
        }
    }

    /// Returns a shared reborrow of the projected reference.
    ///
    /// # Safety
    /// No synchronisation is performed, so the referenced value may not yet
    /// have completed its asynchronous work.
    #[expect(dead_code)]
    pub(crate) unsafe fn unwrap_ref_unchecked(&self) -> &T {
        &*self.value
    }

    /// Returns a mutable reborrow of the projected reference.
    ///
    /// # Safety
    /// No synchronisation is performed, so the referenced value may not yet
    /// have completed its asynchronous work.
    #[expect(dead_code)]
    pub(crate) unsafe fn unwrap_mut_unchecked(&mut self) -> &mut T {
        &mut *self.value
    }
}