#[cfg(any(feature = "host", feature = "device"))]
use core::{
    mem::MaybeUninit,
    ops::{Deref, DerefMut},
};
use const_type_layout::TypeLayout;
use const_type_layout::TypeGraphLayout;
use crate::safety::{PortableBitSemantics, StackOnly};
#[cfg(any(feature = "host", feature = "device"))]
use crate::{
    alloc::NoCudaAlloc,
    lend::{RustToCuda, RustToCudaAsync},
};
#[cfg(feature = "host")]
use crate::{
    alloc::{CombinedCudaAlloc, CudaAlloc},
    utils::ffi::DeviceAccessible,
    utils::r#async::{Async, CompletionFnMut},
};
#[cfg(any(feature = "host", feature = "device"))]
use self::common::CudaExchangeBufferCudaRepresentation;
#[cfg(any(feature = "host", feature = "device"))]
mod common;
#[cfg(feature = "device")]
mod device;
#[cfg(feature = "host")]
mod host;
/// Buffer that dereferences to a slice of `CudaExchangeItem<T, M2D, M2H>`
/// and can be lent to CUDA through the `RustToCuda` / `RustToCudaAsync`
/// implementations in this module.
///
/// All storage and transfer logic lives in a feature-selected backend: the
/// `host` backend whenever the `host` feature is enabled, and the `device`
/// backend when only the `device` feature is enabled.
///
/// `M2D` and `M2H` are forwarded unchanged to every item and decide which
/// accessors the items expose on the host and on the device.
/// NOTE(review): the names presumably abbreviate "move to device" and
/// "move to host" — confirm against the backend modules.
#[cfg(any(feature = "host", feature = "device"))]
#[expect(clippy::module_name_repetitions)]
pub struct CudaExchangeBuffer<T, const M2D: bool, const M2H: bool>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Host-side backend; takes precedence when both features are enabled.
    #[cfg(feature = "host")]
    inner: host::CudaExchangeBufferHost<T, M2D, M2H>,
    /// Device-side backend; only used when the `host` feature is disabled.
    #[cfg(all(feature = "device", not(feature = "host")))]
    inner: device::CudaExchangeBufferDevice<T, M2D, M2H>,
}
// Declares the buffer shareable across threads whenever its element type is
// itself `Sync`.
//
// NOTE(review): the soundness argument depends on the backend stored in
// `inner` not handing out unsynchronised mutable access through `&self`;
// that backend is defined in the `host` / `device` modules and is not
// visible from here — confirm there.
#[cfg(any(feature = "host", feature = "device"))]
unsafe impl<T, const M2D: bool, const M2H: bool> Sync for CudaExchangeBuffer<T, M2D, M2H>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout + Sync,
{
}
#[cfg(feature = "host")]
impl<T, const M2D: bool, const M2H: bool> CudaExchangeBuffer<T, M2D, M2H>
where
    T: Clone + StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Builds a new buffer by handing `elem` and `capacity` to the host
    /// backend's constructor.
    ///
    /// NOTE(review): presumably the buffer ends up holding `capacity` clones
    /// of `elem` (hence the `Clone` bound) — confirm in the `host` module.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the host backend's constructor returns.
    pub fn new(elem: &T, capacity: usize) -> rustacuda::error::CudaResult<Self> {
        let inner = host::CudaExchangeBufferHost::<T, M2D, M2H>::new(elem, capacity)?;

        Ok(Self { inner })
    }
}
#[cfg(feature = "host")]
impl<T, const M2D: bool, const M2H: bool> CudaExchangeBuffer<T, M2D, M2H>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Builds a new buffer from an owned vector by handing `vec` to the host
    /// backend's `from_vec` constructor.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the host backend's constructor returns.
    pub fn from_vec(vec: Vec<T>) -> rustacuda::error::CudaResult<Self> {
        let inner = host::CudaExchangeBufferHost::<T, M2D, M2H>::from_vec(vec)?;

        Ok(Self { inner })
    }
}
/// Read-only slice view over the buffer's items, provided by the backend.
#[cfg(any(feature = "host", feature = "device"))]
impl<T, const M2D: bool, const M2H: bool> Deref for CudaExchangeBuffer<T, M2D, M2H>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    type Target = [CudaExchangeItem<T, M2D, M2H>];

    fn deref(&self) -> &Self::Target {
        // The typed binding triggers the backend's own deref coercion.
        let items: &Self::Target = &self.inner;
        items
    }
}
/// Mutable slice view over the buffer's items, provided by the backend.
#[cfg(any(feature = "host", feature = "device"))]
impl<T, const M2D: bool, const M2H: bool> DerefMut for CudaExchangeBuffer<T, M2D, M2H>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        // The typed binding triggers the backend's own mutable deref coercion.
        let items: &mut Self::Target = &mut self.inner;
        items
    }
}
// Synchronous lending of the buffer to CUDA. The buffer contributes
// `NoCudaAlloc` as its own allocation type and forwards both operations to
// the host backend stored in `inner`; the methods only exist when the `host`
// feature is enabled.
#[cfg(any(feature = "host", feature = "device"))]
unsafe impl<T, const M2D: bool, const M2H: bool> RustToCuda for CudaExchangeBuffer<T, M2D, M2H>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    type CudaAllocation = NoCudaAlloc;
    type CudaRepresentation = CudaExchangeBufferCudaRepresentation<T, M2D, M2H>;

    /// Forwards to the backend's `borrow`, passing `alloc` through.
    ///
    /// # Safety
    ///
    /// Same contract as the `RustToCuda::borrow` trait method.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backend.
    #[cfg(feature = "host")]
    unsafe fn borrow<A>(
        &self,
        alloc: A,
    ) -> rustacuda::error::CudaResult<(
        DeviceAccessible<Self::CudaRepresentation>,
        CombinedCudaAlloc<Self::CudaAllocation, A>,
    )>
    where
        A: CudaAlloc,
    {
        self.inner.borrow(alloc)
    }

    /// Forwards to the backend's `restore`, passing `alloc` through.
    ///
    /// # Safety
    ///
    /// Same contract as the `RustToCuda::restore` trait method.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backend.
    #[cfg(feature = "host")]
    unsafe fn restore<A>(
        &mut self,
        alloc: CombinedCudaAlloc<Self::CudaAllocation, A>,
    ) -> rustacuda::error::CudaResult<A>
    where
        A: CudaAlloc,
    {
        self.inner.restore(alloc)
    }
}
// Asynchronous lending of the buffer to CUDA. Both methods delegate to the
// host backend stored in `inner` and only exist when the `host` feature is
// enabled.
#[cfg(any(feature = "host", feature = "device"))]
unsafe impl<T: StackOnly + PortableBitSemantics + TypeGraphLayout, const M2D: bool, const M2H: bool>
    RustToCudaAsync for CudaExchangeBuffer<T, M2D, M2H>
{
    type CudaAllocationAsync = NoCudaAlloc;
    /// Forwards to the backend's `borrow_async`, passing `alloc` and
    /// `stream` through.
    ///
    /// # Safety
    ///
    /// Same contract as the `RustToCudaAsync::borrow_async` trait method.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backend.
    #[cfg(feature = "host")]
    unsafe fn borrow_async<'stream, A: CudaAlloc>(
        &self,
        alloc: A,
        stream: crate::host::Stream<'stream>,
    ) -> rustacuda::error::CudaResult<(
        Async<'_, 'stream, DeviceAccessible<Self::CudaRepresentation>>,
        CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
    )> {
        self.inner.borrow_async(alloc, stream)
    }
    /// Restores the buffer through the backend's `restore_async` and hands
    /// the original `this` handle back, wrapped in an `Async`.
    ///
    /// The backend works on a handle projected onto `inner`, whereas this
    /// method must return a handle to the whole `CudaExchangeBuffer`. The
    /// body therefore keeps a bitwise duplicate of `this`, lets the backend
    /// consume the projected handle, and then re-wraps the duplicate together
    /// with an adapted completion callback.
    ///
    /// # Safety
    ///
    /// Same contract as the `RustToCudaAsync::restore_async` trait method.
    ///
    /// # Errors
    ///
    /// Propagates errors from the backend's `restore_async`, from unwrapping
    /// its `Async` result, and from creating the pending `Async`.
    #[cfg(feature = "host")]
    unsafe fn restore_async<'a, 'stream, A: CudaAlloc, O>(
        this: owning_ref::BoxRefMut<'a, O, Self>,
        alloc: CombinedCudaAlloc<Self::CudaAllocationAsync, A>,
        stream: crate::host::Stream<'stream>,
    ) -> rustacuda::error::CudaResult<(
        Async<'a, 'stream, owning_ref::BoxRefMut<'a, O, Self>, CompletionFnMut<'a, Self>>,
        A,
    )> {
        // Bitwise duplicate of the `this` handle. `ManuallyDrop` keeps the
        // duplicate from being dropped implicitly while the original handle
        // is consumed by `map_mut` below.
        // NOTE(review): on any early `?` return below the duplicate stays
        // inside `ManuallyDrop` and is never dropped; what happens to the
        // consumed handle on those paths is decided by the callee — confirm.
        let this_backup = unsafe { std::mem::ManuallyDrop::new(std::ptr::read(&this)) };
        // Project the handle onto the backend field and let the backend
        // perform the actual asynchronous restore.
        let (r#async, alloc_tail) = host::CudaExchangeBufferHost::restore_async(
            this.map_mut(|this| &mut this.inner),
            alloc,
            stream,
        )?;
        // Take the backend's `Async` apart into the projected handle and its
        // optional completion callback.
        // NOTE(review): the precondition of `unwrap_unchecked` is defined by
        // `Async` and is not visible here — confirm it holds at this point.
        let (inner, on_completion) = unsafe { r#async.unwrap_unchecked()? };
        // The projected handle and `this_backup` originate from the same
        // `this`; forget the projection so only the duplicate remains live.
        std::mem::forget(inner);
        let this = std::mem::ManuallyDrop::into_inner(this_backup);
        if let Some(on_completion) = on_completion {
            // The backend's callback expects the backend type, so adapt it to
            // a callback over `Self` that forwards to the `inner` field.
            let r#async = Async::<_, CompletionFnMut<'a, Self>>::pending(
                this,
                stream,
                Box::new(|this: &mut Self| on_completion(&mut this.inner)),
            )?;
            Ok((r#async, alloc_tail))
        } else {
            // The backend supplied no completion callback: hand the handle
            // back via `Async::ready`.
            let r#async = Async::ready(this, stream);
            Ok((r#async, alloc_tail))
        }
    }
}
/// Transparent wrapper around a single `T` whose accessors depend on the
/// `M2D` / `M2H` flags and on the enabled crate feature:
///
/// - `M2H == true`: `read` with the `host` feature, `write` with the
///   `device` feature.
/// - `M2D == true`: `read` with the `device` feature, `write` with the
///   `host` feature.
/// - `M2D == true` and `M2H == true`: additionally implements `AsMut<T>`.
/// - `M2D == false`, `M2H == true`: `as_scratch` / `as_scratch_mut` with the
///   `host` feature, `as_uninit` / `as_uninit_mut` with the `device` feature.
/// - `M2D == true`, `M2H == false`: `as_scratch` / `as_scratch_mut` with the
///   `device` feature, `as_uninit` / `as_uninit_mut` with the `host` feature.
///
/// NOTE(review): `M2D` / `M2H` presumably abbreviate "move to device" and
/// "move to host" — confirm against the buffer backends.
// `repr(transparent)` gives the wrapper the same layout as `T`; the
// `as_uninit` accessors rely on this for their pointer casts.
#[repr(transparent)]
#[derive(Clone, Copy, TypeLayout)]
pub struct CudaExchangeItem<
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
    const M2D: bool,
    const M2H: bool,
>(T);
/// Accessors for items with `M2H == true`: readable with the `host` feature,
/// writable with the `device` feature.
impl<T, const M2D: bool> CudaExchangeItem<T, M2D, true>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Borrows the wrapped value (only with the `host` feature).
    #[cfg(feature = "host")]
    pub const fn read(&self) -> &T {
        let Self(value) = self;
        value
    }

    /// Overwrites the wrapped value with `value` (only with the `device`
    /// feature).
    #[cfg(feature = "device")]
    pub fn write(&mut self, value: T) {
        let Self(slot) = self;
        *slot = value;
    }
}
/// Accessors for items with `M2D == true`: readable with the `device`
/// feature, writable with the `host` feature.
impl<T, const M2H: bool> CudaExchangeItem<T, true, M2H>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Borrows the wrapped value (only with the `device` feature).
    #[cfg(feature = "device")]
    pub const fn read(&self) -> &T {
        let Self(value) = self;
        value
    }

    /// Overwrites the wrapped value with `value` (only with the `host`
    /// feature).
    #[cfg(feature = "host")]
    pub fn write(&mut self, value: T) {
        let Self(slot) = self;
        *slot = value;
    }
}
/// Items with both flags set grant unrestricted mutable access to the
/// wrapped value, independent of the enabled feature.
impl<T> AsMut<T> for CudaExchangeItem<T, true, true>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    fn as_mut(&mut self) -> &mut T {
        let Self(value) = self;
        value
    }
}
/// Scratch access for items with `M2D == false` and `M2H == true`, available
/// with the `host` feature.
impl<T> CudaExchangeItem<T, false, true>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Borrows the wrapped value.
    #[cfg(feature = "host")]
    pub const fn as_scratch(&self) -> &T {
        let Self(value) = self;
        value
    }

    /// Mutably borrows the wrapped value.
    #[cfg(feature = "host")]
    pub fn as_scratch_mut(&mut self) -> &mut T {
        let Self(value) = self;
        value
    }
}
/// Scratch access for items with `M2D == true` and `M2H == false`, available
/// with the `device` feature.
impl<T> CudaExchangeItem<T, true, false>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Borrows the wrapped value.
    #[cfg(feature = "device")]
    pub const fn as_scratch(&self) -> &T {
        let Self(value) = self;
        value
    }

    /// Mutably borrows the wrapped value.
    #[cfg(feature = "device")]
    pub fn as_scratch_mut(&mut self) -> &mut T {
        let Self(value) = self;
        value
    }
}
/// Possibly-uninitialised view for items with `M2D == true` and
/// `M2H == false`, available with the `host` feature.
impl<T> CudaExchangeItem<T, true, false>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Reinterprets the item as a shared `MaybeUninit<T>`.
    #[cfg(feature = "host")]
    pub const fn as_uninit(&self) -> &MaybeUninit<T> {
        let slot: *const MaybeUninit<T> = core::ptr::from_ref(self).cast();
        // SAFETY: `CudaExchangeItem` is `repr(transparent)` over `T`, and
        // `MaybeUninit<T>` has the same size and alignment as `T`; the
        // pointer is derived from a live `&self`.
        unsafe { &*slot }
    }

    /// Reinterprets the item as an exclusive `MaybeUninit<T>`.
    #[cfg(feature = "host")]
    pub fn as_uninit_mut(&mut self) -> &mut MaybeUninit<T> {
        let slot: *mut MaybeUninit<T> = core::ptr::from_mut(self).cast();
        // SAFETY: same layout argument as in `as_uninit`; the pointer is
        // derived from a live `&mut self`, so the access is exclusive.
        unsafe { &mut *slot }
    }
}
/// Possibly-uninitialised view for items with `M2D == false` and
/// `M2H == true`, available with the `device` feature.
impl<T> CudaExchangeItem<T, false, true>
where
    T: StackOnly + PortableBitSemantics + TypeGraphLayout,
{
    /// Reinterprets the item as a shared `MaybeUninit<T>`.
    #[cfg(feature = "device")]
    pub const fn as_uninit(&self) -> &MaybeUninit<T> {
        let slot: *const MaybeUninit<T> = core::ptr::from_ref(self).cast();
        // SAFETY: `CudaExchangeItem` is `repr(transparent)` over `T`, and
        // `MaybeUninit<T>` has the same size and alignment as `T`; the
        // pointer is derived from a live `&self`.
        unsafe { &*slot }
    }

    /// Reinterprets the item as an exclusive `MaybeUninit<T>`.
    #[cfg(feature = "device")]
    pub fn as_uninit_mut(&mut self) -> &mut MaybeUninit<T> {
        let slot: *mut MaybeUninit<T> = core::ptr::from_mut(self).cast();
        // SAFETY: same layout argument as in `as_uninit`; the pointer is
        // derived from a live `&mut self`, so the access is exclusive.
        unsafe { &mut *slot }
    }
}