1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
use crate::deps::alloc::{fmt, format};

/// Abort the CUDA kernel using the `trap` system call.
///
/// [`abort`] poisons the CUDA context and no more work can be performed in it.
#[expect(clippy::inline_always)]
#[inline(always)]
pub fn abort() -> ! {
    unsafe { ::core::arch::nvptx::trap() }
}

/// Exit the CUDA kernel early by executing the PTX `exit` instruction.
///
/// # Safety
///
/// [`exit`] leaves the kernel before it finishes. Any mutable data that is
/// visible outside this kernel launch (to the host or to a later kernel
/// launch) may therefore be left in an inconsistent state. The caller must
/// report the kernel failure back to the host and handle it by some other
/// means.
///
/// Prefer returning safely from the main kernel function instead.
#[expect(clippy::inline_always)]
#[inline(always)]
pub unsafe fn exit() -> ! {
    use ::core::arch::asm;

    // SAFETY: the caller upholds the contract documented above; the assembly
    // is declared `noreturn`, matching this function's `!` return type.
    unsafe { asm!("exit;", options(noreturn)) }
}

/// Prints to the CUDA kernel's standard output using the `vprintf` system call.
///
/// Drop-in replacement for the [`std::print!`] macro. The arguments are
/// packaged with [`format_args!`](core::format_args) and handed to the
/// [`print()`] function.
pub macro print {
    ($($tokens:tt)*) => {
        self::print(::core::format_args!($($tokens)*))
    },
}

/// Prints to the CUDA kernel's standard output using the `vprintf` system
/// call, followed by a newline.
///
/// Drop-in replacement for the [`std::println!`] macro. The arguments (if
/// any) are wrapped with a trailing `\n` and handed to the [`print()`]
/// function.
pub macro println {
    // No arguments: emit just the newline.
    () => {
        self::print(::core::format_args!("\n"))
    },
    // At least one token (the empty case is handled by the arm above).
    ($($tokens:tt)+) => {
        self::print(::core::format_args!("{}\n", ::core::format_args!($($tokens)+)))
    },
}

/// The [`print()`] function takes an [`Arguments`](core::fmt::Arguments) struct
/// and formats and prints it to the CUDA kernel's standard output using the
/// `vprintf` system call.
///
/// The [`Arguments`](core::fmt::Arguments) instance can be created with the
/// [`format_args!`](core::format_args) macro.
///
/// If `args` is a plain string literal with no arguments, it is printed
/// directly. Otherwise it is first expanded into a heap-allocated string.
/// The message is passed to `vprintf` through a `%.*s` conversion, so its
/// length bounds the read and no NUL terminator is required.
#[inline(always)]
pub fn print(args: ::core::fmt::Arguments) {
    use ::core::ffi::c_int;

    /// `vprintf` argument buffer for the `%.*s` conversion: an `int`
    /// precision followed by a pointer to the (not NUL-terminated) bytes.
    #[repr(C)]
    struct FormatArgs {
        msg_len: c_int,
        msg_ptr: *const u8,
    }

    let msg; // place to store the dynamically expanded format string
    #[expect(clippy::option_if_let_else)]
    let msg = if let Some(msg) = args.as_str() {
        msg
    } else {
        msg = fmt::format(args);
        msg.as_str()
    };

    let args = FormatArgs {
        // Clamp rather than wrap: a precision that turned into a negative
        // `int` would be treated as "no precision" and read past the end.
        msg_len: c_int::try_from(msg.len()).unwrap_or(c_int::MAX),
        msg_ptr: msg.as_ptr(),
    };

    // SAFETY: the format string is a NUL-terminated C string literal, and
    // `args` is a `#[repr(C)]` buffer whose fields match the `%.*s`
    // conversion (an `int` precision, then a pointer). The pointed-to bytes
    // (`msg`) outlive this call, and the precision never exceeds their length.
    unsafe {
        ::core::arch::nvptx::vprintf(
            c"%.*s".as_ptr().cast(),
            ::core::ptr::from_ref(&args).cast(),
        );
    }
}

/// Helper function to efficiently pretty-print a [`core::panic::PanicInfo`]
/// using the `vprintf` system call.
///
/// The output has the shape
/// `panicked at <file>:<line>:<column> on thread (x=.., y=.., z=..):\n<message>\n`.
/// If the panic carries no location, `<unknown panic location>` is printed
/// with line and column `0`.
///
/// If `allow_dynamic_message` is set,
/// [`alloc::format!`](crate::deps::alloc::format!) is used to print the
/// [`core::panic::PanicMessage`] message when
/// [`core::panic::PanicMessage::as_str`] returns [`None`]. Note that this may
/// pull in a large amount of string formatting and dynamic allocation code.
/// If unset, a default placeholder panic message is printed instead.
///
/// The file name and message are passed through `%.*s` conversions, so their
/// lengths bound the reads and no NUL terminators are required.
#[inline(always)]
pub fn pretty_print_panic_info(info: &::core::panic::PanicInfo, allow_dynamic_message: bool) {
    use ::core::ffi::c_int;

    /// `vprintf` argument buffer. Field order and types mirror the format
    /// string's conversions; each `%.*s` consumes an `int` precision followed
    /// by a pointer to the (not NUL-terminated) bytes.
    #[repr(C)]
    struct FormatArgs {
        file_len: c_int,
        file_ptr: *const u8,
        line: u32,
        column: u32,
        thread_idx_x: u32,
        thread_idx_y: u32,
        thread_idx_z: u32,
        msg_len: c_int,
        msg_ptr: *const u8,
    }

    let msg; // place to store the dynamically expanded format string
    #[expect(clippy::option_if_let_else)]
    let msg = if let Some(msg) = info.message().as_str() {
        msg
    } else if allow_dynamic_message {
        msg = format!("{}", info.message());
        msg.as_str()
    } else {
        "<dynamic panic message>"
    };

    // Query the (optional) location once instead of once per component.
    let (location_file, location_line, location_column) = info.location().map_or(
        ("<unknown panic location>", 0, 0),
        |location| (location.file(), location.line(), location.column()),
    );

    let thread_idx = crate::device::thread::Thread::this().idx();

    let args = FormatArgs {
        // Clamp rather than wrap: a precision that turned into a negative
        // `int` would be treated as "no precision" and read past the end.
        file_len: c_int::try_from(location_file.len()).unwrap_or(c_int::MAX),
        file_ptr: location_file.as_ptr(),
        line: location_line,
        column: location_column,
        thread_idx_x: thread_idx.x,
        thread_idx_y: thread_idx.y,
        thread_idx_z: thread_idx.z,
        msg_len: c_int::try_from(msg.len()).unwrap_or(c_int::MAX),
        msg_ptr: msg.as_ptr(),
    };

    // SAFETY: the format string is a NUL-terminated C string literal, and
    // `args` is a `#[repr(C)]` buffer whose fields match its conversions in
    // order and type. Both string pointers borrow data that outlives this
    // call, and each precision never exceeds the length of its string.
    unsafe {
        ::core::arch::nvptx::vprintf(
            c"panicked at %.*s:%u:%u on thread (x=%u, y=%u, z=%u):\n%.*s\n"
                .as_ptr()
                .cast(),
            ::core::ptr::from_ref(&args).cast(),
        );
    }
}

/// Helper function to efficiently pretty-print an error message (inside an
/// allocation error handler) using the `vprintf` system call.
///
/// The message reports the failed `layout`'s size and alignment, the caller's
/// source location (via `#[track_caller]`) and the current thread index. The
/// file name is passed through a `%.*s` conversion, so its length bounds the
/// read and no NUL terminator is required.
#[track_caller]
#[inline(always)]
pub fn pretty_print_alloc_error(layout: ::core::alloc::Layout) {
    use ::core::ffi::c_int;

    /// `vprintf` argument buffer. Field order and types mirror the format
    /// string's conversions: `%llu` consumes an `unsigned long long` (`u64`),
    /// and `%.*s` consumes an `int` precision followed by a pointer to the
    /// (not NUL-terminated) bytes.
    #[repr(C)]
    struct FormatArgs {
        size: u64,
        align: u64,
        file_len: c_int,
        file_ptr: *const u8,
        line: u32,
        column: u32,
        thread_idx_x: u32,
        thread_idx_y: u32,
        thread_idx_z: u32,
    }

    let location = ::core::panic::Location::caller();
    let file = location.file();
    let thread_idx = crate::device::thread::Thread::this().idx();

    let args = FormatArgs {
        size: u64::try_from(layout.size()).unwrap_or(u64::MAX),
        align: u64::try_from(layout.align()).unwrap_or(u64::MAX),
        // Clamp rather than wrap: a precision that turned into a negative
        // `int` would be treated as "no precision" and read past the end.
        file_len: c_int::try_from(file.len()).unwrap_or(c_int::MAX),
        file_ptr: file.as_ptr(),
        line: location.line(),
        column: location.column(),
        thread_idx_x: thread_idx.x,
        thread_idx_y: thread_idx.y,
        thread_idx_z: thread_idx.z,
    };

    // SAFETY: the format string is a NUL-terminated C string literal, and
    // `args` is a `#[repr(C)]` buffer whose fields match its conversions in
    // order and type. The file pointer borrows data that outlives this call,
    // and the precision never exceeds the file name's length.
    unsafe {
        ::core::arch::nvptx::vprintf(
            c"memory allocation of %llu bytes with alignment %llu failed at \
            %.*s:%u:%u on thread (x=%u, y=%u, z=%u)\n"
                .as_ptr()
                .cast(),
            ::core::ptr::from_ref(&args).cast(),
        );
    }
}