use proc_macro2::TokenStream;
use quote::quote;
use syn::spanned::Spanned;
use crate::kernel::{
wrapper::{FuncIdent, FunctionInputs, ImplGenerics},
KERNEL_TYPE_LAYOUT_HASH_SEED_IDENT, KERNEL_TYPE_LAYOUT_IDENT, KERNEL_TYPE_USE_END_CANARY,
KERNEL_TYPE_USE_START_CANARY,
};
/// Generates the device-side `extern "ptx-kernel"` entry point for a kernel.
///
/// The emitted function is only compiled for `target_os = "cuda"`. It is named
/// `func_ident_hash`, is `#[no_mangle]`, and takes the specialised FFI parameters
/// produced by `specialise_ffi_input_types`. Its body does the following:
///
/// 1. Declares one `extern "C"` static per parameter with `#[deny(improper_ctypes)]`,
///    so each FFI type is checked by that lint.
/// 2. Calls `#crate_path::utils::shared::init()`.
/// 3. Between two `asm!` canary strings, emits `#[no_mangle]` statics holding a
///    per-expansion hash seed and one layout hash per FFI type. They are then read
///    with `read_volatile`. Presumably this keeps them in the PTX so the host can
///    find them. NOTE(review): confirm against the host-side loader.
/// 4. Converts every FFI argument back into its device type through
///    `CudaKernelParameter::with_ffi_as_device`, then calls `func_ident` with
///    `func_params`.
pub(in super::super) fn quote_cuda_wrapper(
    crate_path: &syn::Path,
    inputs @ FunctionInputs { func_inputs }: &FunctionInputs,
    func @ FuncIdent {
        func_ident,
        func_ident_hash,
        ..
    }: &FuncIdent,
    impl_generics @ ImplGenerics {
        impl_generics: generics,
        ..
    }: &ImplGenerics,
    func_attrs: &[syn::Attribute],
    func_params: &[syn::Ident],
) -> TokenStream {
    let (ffi_inputs, ffi_types) =
        specialise_ffi_input_types(crate_path, inputs, func, impl_generics);
    let ffi_types_len = ffi_types.len();

    // Start from the innermost expression, the call into the user's kernel, and
    // wrap it once per parameter. Iterating in reverse makes parameter 0 the
    // outermost `with_ffi_as_device` wrapper.
    let mut wrapped_call = quote! {
        #func_ident(#(#func_params),*)
    };
    for (index, syn::PatType { pat, ty, .. }) in func_inputs.iter().enumerate().rev() {
        let param_ty = quote::quote_spanned! { ty.span()=>
            #crate_path::device::specialise_kernel_param_type!(#ty for #generics in #func_ident)
        };
        wrapped_call = quote::quote_spanned! { ty.span()=>
            unsafe {
                <
                    #param_ty as #crate_path::kernel::CudaKernelParameter
                >::with_ffi_as_device::<_, #index>(
                    #pat, |#pat: <
                        #param_ty as #crate_path::kernel::CudaKernelParameter
                    >::DeviceType::<'_>| { #wrapped_call }
                )
            }
        };
    }

    // Copies of the parameter names re-spanned to the macro definition site, so the
    // generated extern statics do not collide with the user's identifiers.
    let private_func_params: Vec<syn::Ident> = func_params
        .iter()
        .cloned()
        .map(|mut ident| {
            ident.set_span(proc_macro::Span::def_site().into());
            ident
        })
        .collect();

    let layout_ident = syn::Ident::new(KERNEL_TYPE_LAYOUT_IDENT, func_ident.span());
    let seed_ident = syn::Ident::new(KERNEL_TYPE_LAYOUT_HASH_SEED_IDENT, func_ident.span());

    // `RandomState::new()` is randomly keyed, so each macro expansion gets its own seed.
    let seed_hasher = std::hash::RandomState::new();
    let seed: u64 = std::hash::BuildHasher::hash_one(&seed_hasher, 0xd236_cae6_da79_5f77_u64);

    // One extern static per parameter. `deny(improper_ctypes)` makes the compiler
    // reject FFI types that the lint considers not FFI-safe.
    let ffi_type_checks = quote! {
        extern "C" { #(
            #[allow(dead_code)]
            #[deny(improper_ctypes)]
            static #private_func_params: #ffi_types;
        )* }
    };

    // Seed and per-parameter layout hashes, bracketed by the start and end canaries.
    let layout_statics = quote! {
        unsafe { ::core::arch::asm!(#KERNEL_TYPE_USE_START_CANARY); }
        #[no_mangle]
        static #seed_ident: [u8; 8] = #seed.to_le_bytes();
        unsafe { ::core::ptr::read_volatile(&#seed_ident) };
        #[no_mangle]
        static #layout_ident: [[u8; 8]; #ffi_types_len] = [#(
            #crate_path::deps::const_type_layout::hash_type_graph::<#ffi_types>(#seed).to_le_bytes()
        ),*];
        unsafe { ::core::ptr::read_volatile(&#layout_ident) };
        unsafe { ::core::arch::asm!(#KERNEL_TYPE_USE_END_CANARY); }
    };

    quote! {
        #[cfg(target_os = "cuda")]
        #[#crate_path::device::specialise_kernel_function(#func_ident)]
        #[no_mangle]
        #[allow(unused_unsafe)]
        #(#func_attrs)*
        pub unsafe extern "ptx-kernel" fn #func_ident_hash(#(#ffi_inputs),*) {
            #ffi_type_checks
            unsafe {
                #crate_path::utils::shared::init();
            }
            #layout_statics
            #wrapped_call
        }
    }
}
/// Builds the FFI-facing parameter list and the matching FFI types for a kernel.
///
/// Each input `pat: ty` is first specialised through
/// `specialise_kernel_param_type!(ty for impl_generics in func_ident)`. Its type is
/// then replaced with
/// `<Specialised as CudaKernelParameter>::FfiType<'static, 'static>`.
///
/// Returns `(ffi_params, ffi_types)`, two parallel vectors in input order. The
/// first holds the rewritten `FnArg`s, keeping the original attributes, pattern and
/// colon token. The second holds just the FFI types.
fn specialise_ffi_input_types(
    crate_path: &syn::Path,
    FunctionInputs { func_inputs }: &FunctionInputs,
    FuncIdent { func_ident, .. }: &FuncIdent,
    ImplGenerics { impl_generics, .. }: &ImplGenerics,
) -> (Vec<syn::FnArg>, Vec<syn::Type>) {
    let mut ffi_params = Vec::new();
    let mut ffi_types = Vec::new();

    for input in func_inputs.iter() {
        // Generated tokens point at the user's written type, so errors show up there.
        let ty_span = input.ty.span();
        let param_ty = &input.ty;

        let specialised = quote::quote_spanned! { ty_span=>
            #crate_path::device::specialise_kernel_param_type!(#param_ty for #impl_generics in #func_ident)
        };
        let ffi_ty: syn::Type = syn::parse_quote_spanned! { ty_span=>
            <#specialised as #crate_path::kernel::CudaKernelParameter>::FfiType<'static, 'static>
        };

        ffi_params.push(syn::FnArg::Typed(syn::PatType {
            attrs: input.attrs.clone(),
            pat: input.pat.clone(),
            colon_token: input.colon_token,
            ty: Box::new(ffi_ty.clone()),
        }));
        ffi_types.push(ffi_ty);
    }

    (ffi_params, ffi_types)
}