1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use std::env::VarError;

use proc_macro::TokenStream;
use quote::quote;

use crate::kernel::CHECK_SPECIALISATION;

#[expect(clippy::module_name_repetitions)]
pub fn specialise_kernel_function(attr: TokenStream, func: TokenStream) -> TokenStream {
    let mut func: syn::ItemFn = syn::parse(func).unwrap_or_else(|err| {
        abort_call_site!(
            "#[specialise_kernel_function(...)] must be wrapped around a function: {:?}",
            err
        )
    });

    let kernel: syn::Ident = match syn::parse(attr) {
        Ok(kernel) => kernel,
        Err(err) => abort_call_site!(
            "#[specialise_kernel_function(KERNEL)] expects KERNEL identifier: {:?}",
            err
        ),
    };

    let crate_name = proc_macro::tracked_env::var("CARGO_CRATE_NAME")
        .unwrap_or_else(|err| abort_call_site!("Failed to read crate name: {:?}", err));

    let specialisation_var = format!(
        "RUST_CUDA_DERIVE_SPECIALISE_{}_{}",
        crate_name.to_uppercase(),
        kernel.to_string().to_uppercase()
    );

    func.sig.ident = match proc_macro::tracked_env::var(&specialisation_var).as_deref() {
        Ok("") => quote::format_ident!("{}_kernel", func.sig.ident),
        Ok(CHECK_SPECIALISATION) => {
            let func_ident = quote::format_ident!("{}_{CHECK_SPECIALISATION}", func.sig.ident);

            return (quote! {
                #[cfg(target_os = "cuda")]
                #[no_mangle]
                pub unsafe extern "ptx-kernel" fn #func_ident() {}
            })
            .into();
        },
        Ok(specialisation) => {
            quote::format_ident!(
                "{}_kernel_{:016x}",
                func.sig.ident,
                seahash::hash(specialisation.as_bytes())
            )
        },
        Err(err @ VarError::NotUnicode(_)) => abort_call_site!(
            "Failed to read specialisation from {:?}: {:?}",
            &specialisation_var,
            err
        ),
        Err(VarError::NotPresent) => return quote!().into(),
    };

    (quote! { #func }).into()
}