1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
use proc_macro2::TokenStream;
use quote::quote;

use crate::kernel::wrapper::{DeclGenerics, FuncIdent, FunctionInputs, ImplGenerics};

/// Generates the host-side (non-CUDA) items that give a kernel its type.
///
/// The returned tokens contain two items, both gated on
/// `#[cfg(not(target_os = "cuda"))]`:
///
/// * a public type alias named `func_ident`, declared with the kernel's
///   generic parameters, that aliases an
///   `impl Fn(&mut Launcher<func_ident<..>>, <kernel parameter types>...)`;
/// * a module-private function (a copy of `func_ident` re-spanned to
///   `def_site`) with exactly that signature. It carries `func_attrs`, and its
///   body assigns the function item to the alias. That binding is presumably
///   the defining use of the `impl Fn` alias (TODO confirm). The body then
///   discards each identifier in `func_params`.
///
/// # Parameters
/// * `crate_path` - path under which `kernel::Launcher` is reachable in the
///   generated code.
/// * `DeclGenerics` - the kernel's declared generic parameters, plus the
///   optional `<` / `>` tokens that surround them.
/// * `ImplGenerics` - supplies `ty_generics`, which is turned into a
///   turbofish for naming the private function.
/// * `FunctionInputs` - the kernel's typed parameters (`pat: Type` pairs).
/// * `FuncIdent` - the kernel's name, reused for the type alias.
/// * `func_params` - identifiers emitted as `let _ = param;` in the private
///   function body, presumably the parameter names from `func_inputs` so
///   they don't count as unused (verify against caller).
/// * `func_attrs` - attributes forwarded onto the private function.
pub(in super::super) fn quote_host_kernel_ty(
    crate_path: &syn::Path,
    DeclGenerics {
        generic_kernel_params,
        generic_start_token,
        generic_close_token,
        ..
    }: &DeclGenerics,
    ImplGenerics { ty_generics, .. }: &ImplGenerics,
    FunctionInputs { func_inputs }: &FunctionInputs,
    FuncIdent { func_ident, .. }: &FuncIdent,
    func_params: &[syn::Ident],
    func_attrs: &[syn::Attribute],
) -> TokenStream {
    // Only the types of the kernel parameters appear in the `impl Fn(..)` alias.
    let cuda_kernel_param_tys = func_inputs
        .iter()
        .map(|syn::PatType { ty, .. }| &**ty)
        .collect::<Vec<_>>();

    // A mixed-site span gives the launcher parameter local-variable hygiene, so
    // it cannot collide with identifiers written by the macro's user.
    let launcher = syn::Ident::new("launcher", proc_macro2::Span::mixed_site());

    // Bare generic arguments (ident / lifetime only, no bounds or defaults),
    // used to instantiate the alias inside its own signature. Collected into a
    // Vec because the list is repeated several times below.
    let full_generics = generic_kernel_params
        .iter()
        .map(|param| match param {
            syn::GenericParam::Type(syn::TypeParam { ident, .. })
            | syn::GenericParam::Const(syn::ConstParam { ident, .. }) => quote!(#ident),
            syn::GenericParam::Lifetime(syn::LifetimeParam { lifetime, .. }) => quote!(#lifetime),
        })
        .collect::<Vec<_>>();

    // Same name as the alias, but with a def-site span, so this helper is not
    // meant to be nameable from the macro caller's code. `syn::Ident::clone`
    // (rather than `.clone()`) keeps this correct if the field is a reference.
    let mut private_func_ident = syn::Ident::clone(func_ident);
    private_func_ident.set_span(proc_macro::Span::def_site().into());

    let ty_turbofish = ty_generics.as_turbofish();

    quote! {
        #[cfg(not(target_os = "cuda"))]
        #[allow(non_camel_case_types)]
        pub type #func_ident #generic_start_token
            #generic_kernel_params
        #generic_close_token = impl Fn(
            &mut #crate_path::kernel::Launcher<#func_ident #generic_start_token
                #(#full_generics),*
            #generic_close_token>,
            #(#cuda_kernel_param_tys),*
        );

        #[cfg(not(target_os = "cuda"))]
        #(#func_attrs)*
        #[allow(clippy::too_many_arguments)]
        #[allow(clippy::used_underscore_binding)]
        fn #private_func_ident #generic_start_token
            #generic_kernel_params
        #generic_close_token (
            #launcher: &mut #crate_path::kernel::Launcher<#func_ident #generic_start_token
                #(#full_generics),*
            #generic_close_token>,
            #func_inputs
        ) {
            // Bind this function item to the alias. The alias is instantiated
            // with the same optional `<`/`>` tokens as in the signatures above,
            // so a kernel without generics no longer produces a stray `<>`.
            let _: #func_ident #generic_start_token
                #(#full_generics),*
            #generic_close_token = #private_func_ident #ty_turbofish;

            #(
                let _ = #func_params;
            )*
        }
    }
}