1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
use std::{
collections::HashMap,
ffi::{CStr, CString},
};
use super::{
regex::{
const_base_register_regex, const_load_instruction_regex, const_marker_regex, register_regex,
},
PtxElement, PtxJITCompiler, PtxLoadWidth,
};
impl PtxJITCompiler {
    /// Preprocesses `ptx` so that kernel-parameter loads can later be
    /// replaced by constant loads.
    ///
    /// The source is scanned in three passes:
    ///
    /// 1. Injected rust-cuda const markers map a dummy register (`rxx`) to the
    ///    index of a kernel parameter.
    /// 2. Loads into a marked dummy register reveal the base register (`ryy`)
    ///    they read from, which is then associated with the same parameter.
    /// 3. Every load through such a base register (with an optional byte
    ///    offset) becomes a [`PtxElement::ConstLoad`]; the source in between
    ///    is kept verbatim as [`PtxElement::CopiedSource`].
    ///
    /// A `CopiedSource` element precedes every `ConstLoad` (it is empty when
    /// two such loads are adjacent or the PTX starts with one); the trailing
    /// remainder is only stored when it is non-empty. Load instructions whose
    /// width, offset, destination registers or instruction text could not be
    /// extracted are not split out and stay part of the copied source.
    ///
    /// The returned compiler caches the unmodified `ptx` as its last output
    /// and has no recorded arguments yet.
    #[must_use]
    pub fn new(ptx: &CStr) -> Self {
        let ptx_bytes = ptx.to_bytes();

        // Dummy register `rxx` => parameter index, taken from the injected
        // rust-cuda const markers. Later markers for the same register win.
        let const_markers: HashMap<&[u8], usize> = const_marker_regex()
            .captures_iter(ptx_bytes)
            .filter_map(|marker| {
                let tmpreg = marker.name("tmpreg")?.as_bytes();
                let param: usize = std::str::from_utf8(marker.name("param")?.as_bytes())
                    .ok()?
                    .parse()
                    .ok()?;
                Some((tmpreg, param))
            })
            .collect();

        // Base register `ryy` => parameter index, for every `ryy` that was used
        // in `ld.global.u32 rxx, [ryy];` with a marked dummy register `rxx`.
        let const_base_registers: HashMap<&[u8], usize> = const_base_register_regex()
            .captures_iter(ptx_bytes)
            .filter_map(|load| {
                let param = *const_markers.get(load.name("tmpreg")?.as_bytes())?;
                let basereg = load.name("basereg")?.as_bytes();
                Some((basereg, param))
            })
            .collect();

        // Every fully-parsed load from a known base register, paired with the
        // byte range of its instruction text inside `ptx_bytes`.
        let const_loads = const_load_instruction_regex()
            .captures_iter(ptx_bytes)
            .filter_map(|load| {
                let parameter_index =
                    *const_base_registers.get(load.name("basereg")?.as_bytes())?;
                let load_width = match load.name("loadwidth")?.as_bytes() {
                    b"8" => PtxLoadWidth::B1,
                    b"16" => PtxLoadWidth::B2,
                    b"32" => PtxLoadWidth::B4,
                    b"64" => PtxLoadWidth::B8,
                    _ => return None,
                };
                let constreg = load.name("constreg")?.as_bytes();
                // A load without an explicit offset reads at byte offset zero.
                let offset_digits = load
                    .name("loadoffset")
                    .map_or(&b"0"[..], |offset| offset.as_bytes());
                let byte_offset = std::str::from_utf8(offset_digits).ok()?.parse().ok()?;
                let instruction = load.name("instruction")?;
                let registers = register_regex()
                    .captures_iter(constreg)
                    .filter_map(|register| register.name("register"))
                    .map(|register| Box::<[u8]>::from(register.as_bytes()))
                    .collect();
                Some((
                    instruction.range(),
                    PtxElement::ConstLoad {
                        ptx: Box::from(instruction.as_bytes()),
                        parameter_index,
                        byte_offset,
                        load_width,
                        registers,
                    },
                ))
            });

        let mut ptx_slices: Vec<PtxElement> = Vec::new();
        let mut from_index = 0_usize;
        // `captures_iter` yields non-overlapping matches in order, so each
        // instruction starts at or after the end of the previous one and the
        // slice below is always well-formed.
        for (instruction_range, const_load) in const_loads {
            // Keep the PTX between the previous const load and this one verbatim.
            ptx_slices.push(PtxElement::CopiedSource {
                ptx: Box::from(&ptx_bytes[from_index..instruction_range.start]),
            });
            ptx_slices.push(const_load);
            from_index = instruction_range.end;
        }

        // Keep whatever follows the last const load (the whole PTX if none).
        let remainder = &ptx_bytes[from_index..];
        if !remainder.is_empty() {
            ptx_slices.push(PtxElement::CopiedSource {
                ptx: Box::from(remainder),
            });
        }

        Self {
            ptx_slices: ptx_slices.into_boxed_slice(),
            last_arguments: None,
            // Copying from the already-validated `CStr` needs no `unsafe`.
            last_ptx: CString::from(ptx),
        }
    }
}