1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

mod field_copy;
mod field_ty;
mod generics;
mod r#impl;

/// Builds the identifier used for the generated CUDA representation struct.
///
/// The result is the name of the deriving struct, `rust_repr_ident`, with the
/// fixed suffix `CudaRepresentation` appended to it.
fn get_cuda_repr_ident(rust_repr_ident: &proc_macro2::Ident) -> proc_macro2::Ident {
    format_ident!("{name}CudaRepresentation", name = rust_repr_ident)
}

#[expect(clippy::module_name_repetitions, clippy::too_many_lines)]
pub fn impl_rust_to_cuda(ast: &syn::DeriveInput) -> proc_macro::TokenStream {
    let (mut struct_fields_cuda, struct_semi_cuda) = if let syn::Data::Struct(s) = &ast.data {
        (s.fields.clone(), s.semi_token)
    } else {
        abort_call_site!("You can only derive the `RustToCuda` trait on structs for now.");
    };

    let struct_name = &ast.ident;
    let struct_name_cuda = get_cuda_repr_ident(struct_name);

    let (
        struct_attrs_cuda,
        struct_generics_cuda,
        struct_generics_cuda_async,
        struct_layout_attrs,
        r2c_async_impl,
        crate_path,
    ) = generics::expand_cuda_struct_generics_where_requested_in_attrs(ast);

    let mut combined_cuda_alloc_type: TokenStream = quote! {
        #crate_path::alloc::NoCudaAlloc
    };
    let mut combined_cuda_alloc_async_type: TokenStream = quote! {
        #crate_path::alloc::NoCudaAlloc
    };
    let mut r2c_field_declarations: Vec<TokenStream> = Vec::new();
    let mut r2c_field_async_declarations: Vec<TokenStream> = Vec::new();
    let mut r2c_field_async_completions: Vec<syn::Ident> = Vec::new();
    let mut r2c_field_initialisations: Vec<TokenStream> = Vec::new();
    let mut r2c_field_destructors: Vec<TokenStream> = Vec::new();
    let mut r2c_field_async_destructors: Vec<TokenStream> = Vec::new();
    let mut r2c_field_async_completion_calls: Vec<TokenStream> = Vec::new();

    let mut c2r_field_initialisations: Vec<TokenStream> = Vec::new();

    match struct_fields_cuda {
        syn::Fields::Named(syn::FieldsNamed {
            named: ref mut fields,
            ..
        })
        | syn::Fields::Unnamed(syn::FieldsUnnamed {
            unnamed: ref mut fields,
            ..
        }) => {
            let mut r2c_field_destructors_reverse: Vec<TokenStream> = Vec::new();
            let mut r2c_field_async_destructors_reverse: Vec<TokenStream> = Vec::new();

            for (field_index, field) in fields.iter_mut().enumerate() {
                let cuda_repr_field_ty =
                    field_ty::swap_field_type_and_filter_attrs(&crate_path, field);

                (combined_cuda_alloc_type, combined_cuda_alloc_async_type) =
                    field_copy::impl_field_copy_init_and_expand_alloc_type(
                        &crate_path,
                        field,
                        field_index,
                        &cuda_repr_field_ty,
                        combined_cuda_alloc_type,
                        combined_cuda_alloc_async_type,
                        &mut r2c_field_declarations,
                        &mut r2c_field_async_declarations,
                        &mut r2c_field_async_completions,
                        &mut r2c_field_initialisations,
                        &mut r2c_field_destructors_reverse,
                        &mut r2c_field_async_destructors_reverse,
                        &mut r2c_field_async_completion_calls,
                        &mut c2r_field_initialisations,
                    );
            }

            // The fields must be deallocated in the reverse order of their allocation
            r2c_field_destructors.extend(r2c_field_destructors_reverse.into_iter().rev());
            r2c_field_async_destructors
                .extend(r2c_field_async_destructors_reverse.into_iter().rev());
        },
        syn::Fields::Unit => (),
    }

    let cuda_struct_declaration = r#impl::cuda_struct_declaration(
        &crate_path,
        &struct_attrs_cuda,
        &struct_layout_attrs,
        &ast.vis,
        &struct_name_cuda,
        &struct_generics_cuda,
        &struct_fields_cuda,
        struct_semi_cuda,
    );

    let rust_to_cuda_trait_impl = r#impl::rust_to_cuda_trait(
        &crate_path,
        struct_name,
        &struct_name_cuda,
        &struct_generics_cuda,
        &struct_fields_cuda,
        &combined_cuda_alloc_type,
        &r2c_field_declarations,
        &r2c_field_initialisations,
        &r2c_field_destructors,
    );

    let rust_to_cuda_async_trait_impl = if r2c_async_impl {
        r#impl::rust_to_cuda_async_trait(
            &crate_path,
            struct_name,
            &struct_name_cuda,
            &struct_generics_cuda_async,
            &struct_fields_cuda,
            &combined_cuda_alloc_async_type,
            &r2c_field_async_declarations,
            &r2c_field_async_completions,
            &r2c_field_initialisations,
            &r2c_field_async_destructors,
            &r2c_field_async_completion_calls,
        )
    } else {
        TokenStream::new()
    };

    let cuda_as_rust_trait_impl = r#impl::cuda_as_rust_trait(
        &crate_path,
        struct_name,
        &struct_name_cuda,
        &struct_generics_cuda,
        &struct_fields_cuda,
        &c2r_field_initialisations,
    );

    (quote! {
        #cuda_struct_declaration

        #rust_to_cuda_trait_impl

        #rust_to_cuda_async_trait_impl

        #cuda_as_rust_trait_impl
    })
    .into()
}