1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use std::{
    collections::HashMap,
    ffi::{CStr, CString},
};

use super::{
    regex::{
        const_base_register_regex, const_load_instruction_regex, const_marker_regex, register_regex,
    },
    PtxElement, PtxJITCompiler, PtxLoadWidth,
};

impl PtxJITCompiler {
    /// Pre-processes `ptx` so that selected parameter loads can later be
    /// turned into constant loads if requested.
    ///
    /// The PTX is first searched for injected rust-cuda-const-markers, which
    /// identify a dummy register `rxx` together with the parameter index it
    /// stands for. The base register `ryy` that was used in
    /// `ld.global.u32 rxx, [ryy];` is then associated with that parameter
    /// index. Every load instruction that reads from such a base register
    /// (at an optional byte offset, which defaults to 0) with a width of 8,
    /// 16, 32, or 64 bits is recorded as a [`PtxElement::ConstLoad`]. All
    /// other PTX text is kept verbatim in [`PtxElement::CopiedSource`]
    /// segments, in source order.
    ///
    /// The unmodified `ptx` is cached as the initial `last_ptx`, and the
    /// compiler starts without any `last_arguments`.
    #[must_use]
    pub fn new(ptx: &CStr) -> Self {
        // `CString::from(&CStr)` copies the bytes and the nul terminator, so
        //  no `unsafe` re-assertion of the no-interior-nul invariant is needed
        let last_ptx = CString::from(ptx);
        let ptx = ptx.to_bytes();

        // Map each dummy register rxx named by a const marker => param index
        let const_markers: HashMap<&[u8], usize> = const_marker_regex()
            .captures_iter(ptx)
            .filter_map(|const_marker| {
                let tmpreg = const_marker.name("tmpreg")?.as_bytes();
                let param = std::str::from_utf8(const_marker.name("param")?.as_bytes())
                    .ok()?
                    .parse::<usize>()
                    .ok()?;
                Some((tmpreg, param))
            })
            .collect();

        // Map each base register ryy which was used in `ld.global.u32 rxx, [ryy];`
        //  => param index of the marked dummy register rxx
        let const_base_registers: HashMap<&[u8], usize> = const_base_register_regex()
            .captures_iter(ptx)
            .filter_map(|const_base_register| {
                let tmpreg = const_base_register.name("tmpreg")?.as_bytes();
                let param = *const_markers.get(tmpreg)?;
                let basereg = const_base_register.name("basereg")?.as_bytes();
                Some((basereg, param))
            })
            .collect();

        // All loads from a marked base register ryy (with optional offset),
        //  paired with the byte range of their instruction text in `ptx`
        let const_loads = const_load_instruction_regex()
            .captures_iter(ptx)
            .filter_map(|const_load| {
                // Only consider instructions where the base register is ryy
                let basereg = const_load.name("basereg")?.as_bytes();
                let parameter_index = *const_base_registers.get(basereg)?;

                let load_width = match const_load.name("loadwidth")?.as_bytes() {
                    b"8" => PtxLoadWidth::B1,
                    b"16" => PtxLoadWidth::B2,
                    b"32" => PtxLoadWidth::B4,
                    b"64" => PtxLoadWidth::B8,
                    _ => return None,
                };

                let constreg = const_load.name("constreg")?.as_bytes();

                // A load without an explicit offset reads from offset 0
                let byte_offset = std::str::from_utf8(
                    const_load
                        .name("loadoffset")
                        .map_or(&b"0"[..], |s| s.as_bytes()),
                )
                .ok()?
                .parse()
                .ok()?;

                let instruction = const_load.name("instruction")?;

                let registers = register_regex()
                    .captures_iter(constreg)
                    .filter_map(|register| {
                        register
                            .name("register")
                            .map(|s| s.as_bytes().to_owned().into_boxed_slice())
                    })
                    .collect::<Vec<_>>()
                    .into_boxed_slice();

                Some((
                    instruction.range(),
                    PtxElement::ConstLoad {
                        ptx: instruction.as_bytes().to_owned().into_boxed_slice(),
                        parameter_index,
                        byte_offset,
                        load_width,
                        registers,
                    },
                ))
            });

        let mut ptx_slices: Vec<PtxElement> = Vec::new();
        let mut from_index = 0_usize;

        // `captures_iter` yields successive non-overlapping matches, so each
        //  instruction range starts at or after `from_index` and the slicing
        //  below cannot panic
        for (range, const_load) in const_loads {
            // Store the PTX source code before the load instruction verbatim
            ptx_slices.push(PtxElement::CopiedSource {
                ptx: ptx[from_index..range.start].to_vec().into_boxed_slice(),
            });

            // Store the load instruction with extracted parameters to
            //  generate a constant load if requested
            ptx_slices.push(const_load);

            from_index = range.end;
        }

        // Store the remainder of the PTX source code
        let remainder = &ptx[from_index..];
        if !remainder.is_empty() {
            ptx_slices.push(PtxElement::CopiedSource {
                ptx: remainder.to_vec().into_boxed_slice(),
            });
        }

        // Create the `PtxJITCompiler` which also caches the last PTX version
        Self {
            ptx_slices: ptx_slices.into_boxed_slice(),
            last_arguments: None,
            last_ptx,
        }
    }
}