1use std::ffi::c_int;
24
25use ndarray::{ArrayView2, ArrayView3, ArrayViewMut2, ArrayViewMut3};
26
/// Quality target used when compressing with SPERR.
///
/// Each variant selects one SPERR compression mode and carries the target
/// value that mode should reach.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CompressionMode {
    /// Compress to a target number of bits per pixel.
    BitsPerPixel {
        /// Target bits per pixel.
        bpp: f64,
    },
    /// Compress to a target peak signal-to-noise ratio.
    PeakSignalToNoiseRatio {
        /// Target peak signal-to-noise ratio.
        psnr: f64,
    },
    /// Compress with a point-wise error tolerance.
    PointwiseError {
        /// Point-wise error tolerance.
        pwe: f64,
    },
    /// Compress with a fixed quantisation step.
    QuantisationStep {
        /// Quantisation step size.
        q: f64,
    },
}
51
52#[derive(Debug, thiserror::Error)]
53pub enum Error {
55 #[error("one or more parameters is invalid")]
57 InvalidParameter,
58 #[error("compressed data is missing the header")]
60 DecompressMissingHeader,
61 #[error("cannot decompress to an array with a different shape")]
63 DecompressShapeMismatch,
64 #[error("other error")]
66 Other,
67}
68
69impl CompressionMode {
70 const fn as_mode(self) -> c_int {
71 match self {
72 Self::BitsPerPixel { .. } => 1,
73 Self::PeakSignalToNoiseRatio { .. } => 2,
74 Self::PointwiseError { .. } => 3,
75 Self::QuantisationStep { .. } => -4,
76 }
77 }
78
79 const fn as_quality(self) -> f64 {
80 match self {
81 Self::BitsPerPixel { bpp: quality }
82 | Self::PeakSignalToNoiseRatio { psnr: quality }
83 | Self::PointwiseError { pwe: quality }
84 | Self::QuantisationStep { q: quality } => quality,
85 }
86 }
87}
88
89#[allow(clippy::missing_panics_doc)]
97pub fn compress_2d<T: Element>(
98 src: ArrayView2<T>,
99 mode: CompressionMode,
100) -> Result<Vec<u8>, Error> {
101 let src = src.as_standard_layout();
102
103 let mut dst = std::ptr::null_mut();
104 let mut dst_len = 0;
105
106 #[allow(unsafe_code)] let res = unsafe {
108 sperr_sys::sperr_comp_2d(
109 src.as_ptr().cast(),
110 T::IS_FLOAT.into(),
111 src.dim().1,
112 src.dim().0,
113 mode.as_mode(),
114 mode.as_quality(),
115 true.into(),
116 std::ptr::addr_of_mut!(dst),
117 std::ptr::addr_of_mut!(dst_len),
118 )
119 };
120
121 match res {
122 0 => (), #[allow(clippy::unreachable)]
124 1 => unreachable!("sperr_comp_2d: dst is not pointing to a NULL pointer"),
125 2 => return Err(Error::InvalidParameter),
126 -1 => return Err(Error::Other),
127 #[allow(clippy::panic)]
128 _ => panic!("sperr_comp_2d: unknown error kind {res}"),
129 }
130
131 #[allow(unsafe_code)] let compressed =
133 Vec::from(unsafe { std::slice::from_raw_parts(dst.cast_const().cast::<u8>(), dst_len) });
134
135 #[allow(unsafe_code)] unsafe {
137 sperr_sys::free_dst(dst);
138 }
139
140 Ok(compressed)
141}
142
143#[allow(clippy::missing_panics_doc)]
155pub fn decompress_into_2d<T: Element>(
156 compressed: &[u8],
157 mut decompressed: ArrayViewMut2<T>,
158) -> Result<(), Error> {
159 let Some((header, compressed)) = compressed.split_at_checked(10) else {
160 return Err(Error::DecompressMissingHeader);
161 };
162
163 let mut dim_x = 0;
164 let mut dim_y = 0;
165 let mut dim_z = 0;
166 let mut is_float = 0;
167
168 #[allow(unsafe_code)] unsafe {
170 sperr_sys::sperr_parse_header(
171 header.as_ptr().cast(),
172 std::ptr::addr_of_mut!(dim_x),
173 std::ptr::addr_of_mut!(dim_y),
174 std::ptr::addr_of_mut!(dim_z),
175 std::ptr::addr_of_mut!(is_float),
176 );
177 }
178
179 if (dim_z, dim_y, dim_x) != (1, decompressed.dim().0, decompressed.dim().1) {
180 return Err(Error::DecompressShapeMismatch);
181 }
182
183 let mut dst = std::ptr::null_mut();
184
185 #[allow(unsafe_code)] let res = unsafe {
187 sperr_sys::sperr_decomp_2d(
188 compressed.as_ptr().cast(),
189 compressed.len(),
190 T::IS_FLOAT.into(),
191 decompressed.dim().1,
192 decompressed.dim().0,
193 std::ptr::addr_of_mut!(dst),
194 )
195 };
196
197 match res {
198 0 => (), #[allow(clippy::unreachable)]
200 1 => unreachable!("sperr_decomp_2d: dst is not pointing to a NULL pointer"),
201 -1 => return Err(Error::Other),
202 #[allow(clippy::panic)]
203 _ => panic!("sperr_decomp_2d: unknown error kind {res}"),
204 }
205
206 #[allow(unsafe_code)] let dec =
208 unsafe { ArrayView2::from_shape_ptr(decompressed.dim(), dst.cast_const().cast::<T>()) };
209 decompressed.assign(&dec);
210
211 #[allow(unsafe_code)] unsafe {
213 sperr_sys::free_dst(dst);
214 }
215
216 Ok(())
217}
218
219#[allow(clippy::missing_panics_doc)]
228pub fn compress_3d<T: Element>(
229 src: ArrayView3<T>,
230 mode: CompressionMode,
231 chunks: (usize, usize, usize),
232) -> Result<Vec<u8>, Error> {
233 let src = src.as_standard_layout();
234
235 let mut dst = std::ptr::null_mut();
236 let mut dst_len = 0;
237
238 #[allow(unsafe_code)] let res = unsafe {
240 sperr_sys::sperr_comp_3d(
241 src.as_ptr().cast(),
242 T::IS_FLOAT.into(),
243 src.dim().2,
244 src.dim().1,
245 src.dim().0,
246 chunks.2,
247 chunks.1,
248 chunks.0,
249 mode.as_mode(),
250 mode.as_quality(),
251 0,
252 std::ptr::addr_of_mut!(dst),
253 std::ptr::addr_of_mut!(dst_len),
254 )
255 };
256
257 match res {
258 0 => (), #[allow(clippy::unreachable)]
260 1 => unreachable!("sperr_comp_3d: dst is not pointing to a NULL pointer"),
261 2 => return Err(Error::InvalidParameter),
262 -1 => return Err(Error::Other),
263 #[allow(clippy::panic)]
264 _ => panic!("sperr_comp_3d: unknown error kind {res}"),
265 }
266
267 #[allow(unsafe_code)] let compressed =
269 Vec::from(unsafe { std::slice::from_raw_parts(dst.cast_const().cast::<u8>(), dst_len) });
270
271 #[allow(unsafe_code)] unsafe {
273 sperr_sys::free_dst(dst);
274 }
275
276 Ok(compressed)
277}
278
279#[allow(clippy::missing_panics_doc)]
289pub fn decompress_into_3d<T: Element>(
290 compressed: &[u8],
291 mut decompressed: ArrayViewMut3<T>,
292) -> Result<(), Error> {
293 let mut dim_x = 0;
294 let mut dim_y = 0;
295 let mut dim_z = 0;
296 let mut is_float = 0;
297
298 #[allow(unsafe_code)] unsafe {
300 sperr_sys::sperr_parse_header(
301 compressed.as_ptr().cast(),
302 std::ptr::addr_of_mut!(dim_x),
303 std::ptr::addr_of_mut!(dim_y),
304 std::ptr::addr_of_mut!(dim_z),
305 std::ptr::addr_of_mut!(is_float),
306 );
307 }
308
309 if (dim_z, dim_y, dim_x)
310 != (
311 decompressed.dim().0,
312 decompressed.dim().1,
313 decompressed.dim().2,
314 )
315 {
316 return Err(Error::DecompressShapeMismatch);
317 }
318
319 let mut dst = std::ptr::null_mut();
320
321 #[allow(unsafe_code)] let res = unsafe {
323 sperr_sys::sperr_decomp_3d(
324 compressed.as_ptr().cast(),
325 compressed.len(),
326 T::IS_FLOAT.into(),
327 0,
328 std::ptr::addr_of_mut!(dim_x),
329 std::ptr::addr_of_mut!(dim_y),
330 std::ptr::addr_of_mut!(dim_z),
331 std::ptr::addr_of_mut!(dst),
332 )
333 };
334
335 match res {
336 0 => (), #[allow(clippy::unreachable)]
338 1 => unreachable!("sperr_decomp_3d: dst is not pointing to a NULL pointer"),
339 -1 => return Err(Error::Other),
340 #[allow(clippy::panic)]
341 _ => panic!("sperr_decomp_3d: unknown error kind {res}"),
342 }
343
344 #[allow(unsafe_code)] let dec =
346 unsafe { ArrayView3::from_shape_ptr(decompressed.dim(), dst.cast_const().cast::<T>()) };
347 decompressed.assign(&dec);
348
349 #[allow(unsafe_code)] unsafe {
351 sperr_sys::free_dst(dst);
352 }
353
354 Ok(())
355}
356
/// Scalar element types that can be compressed and decompressed with SPERR.
///
/// The trait is sealed: it is implemented for [`f32`] and [`f64`] only, and
/// other crates cannot implement it because its supertrait lives in a
/// private module.
pub trait Element: sealed::Element {}

impl Element for f32 {}
impl Element for f64 {}

mod sealed {
    /// Private supertrait of [`super::Element`] that carries the FFI type flag.
    pub trait Element: Copy {
        /// Precision flag passed to SPERR: `true` for `f32`, `false` for `f64`.
        const IS_FLOAT: bool;
    }

    impl Element for f32 {
        const IS_FLOAT: bool = true;
    }

    impl Element for f64 {
        const IS_FLOAT: bool = false;
    }
}
375
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use ndarray::{linspace, logspace, Array1, Array2, Array3};

    use super::*;

    /// Relative slack on the point-wise error bound, allowing for floating-point rounding.
    const PWE_SLACK: f64 = 1e-9;

    /// Builds `len` smooth, non-constant sample values.
    fn test_values(len: usize) -> Array1<f64> {
        linspace(1.0, 10.0, len).collect::<Array1<f64>>()
            + logspace(2.0, 0.0, 5.0, len)
                .rev()
                .collect::<Array1<f64>>()
    }

    /// Builds a `size x size x size` test volume.
    fn test_data_3d(size: usize) -> Array3<f64> {
        test_values(size * size * size)
            .into_shape_clone((size, size, size))
            .expect("create test data array")
    }

    /// Builds a `size x size` test slice.
    fn test_data_2d(size: usize) -> Array2<f64> {
        test_values(size * size)
            .into_shape_clone((size, size))
            .expect("create test data array")
    }

    /// Largest absolute element-wise difference between two equally-shaped arrays.
    fn max_abs_error<'a>(
        original: impl IntoIterator<Item = &'a f64>,
        reconstructed: impl IntoIterator<Item = &'a f64>,
    ) -> f64 {
        original
            .into_iter()
            .zip(reconstructed)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0, f64::max)
    }

    /// Round-trips a 128^3 volume through `compress_3d` and `decompress_into_3d`,
    /// returning `(original, reconstructed)`.
    fn compress_decompress(mode: CompressionMode) -> (Array3<f64>, Array3<f64>) {
        let data = test_data_3d(128);

        let compressed =
            compress_3d(data.view(), mode, (64, 64, 64)).expect("compression should not fail");

        let mut decompressed = Array3::<f64>::zeros(data.dim());
        decompress_into_3d(compressed.as_slice(), decompressed.view_mut())
            .expect("decompression should not fail");

        assert!(
            decompressed.iter().all(|x| x.is_finite()),
            "decompressed data should be finite"
        );

        (data, decompressed)
    }

    #[test]
    fn compress_decompress_bpp() {
        compress_decompress(CompressionMode::BitsPerPixel { bpp: 2.0 });
    }

    #[test]
    fn compress_decompress_psnr() {
        compress_decompress(CompressionMode::PeakSignalToNoiseRatio { psnr: 30.0 });
    }

    #[test]
    fn compress_decompress_pwe() {
        let pwe = 0.1;
        let (data, decompressed) = compress_decompress(CompressionMode::PointwiseError { pwe });
        let error = max_abs_error(&data, &decompressed);
        assert!(
            error <= pwe * (1.0 + PWE_SLACK),
            "point-wise error {error} exceeds tolerance {pwe}"
        );
    }

    #[test]
    fn compress_decompress_q() {
        compress_decompress(CompressionMode::QuantisationStep { q: 3.0 });
    }

    #[test]
    fn compress_decompress_2d_pwe() {
        let pwe = 0.1;
        let data = test_data_2d(64);

        let compressed = compress_2d(data.view(), CompressionMode::PointwiseError { pwe })
            .expect("compression should not fail");

        let mut decompressed = Array2::<f64>::zeros(data.dim());
        decompress_into_2d(compressed.as_slice(), decompressed.view_mut())
            .expect("decompression should not fail");

        let error = max_abs_error(&data, &decompressed);
        assert!(
            error <= pwe * (1.0 + PWE_SLACK),
            "point-wise error {error} exceeds tolerance {pwe}"
        );
    }

    #[test]
    fn decompress_2d_missing_header() {
        let mut decompressed = Array2::<f64>::zeros((4, 4));
        let result = decompress_into_2d(&[0_u8; 4], decompressed.view_mut());
        assert!(matches!(result, Err(Error::DecompressMissingHeader)));
    }

    #[test]
    fn decompress_3d_shape_mismatch() {
        let data = test_data_3d(32);

        let compressed = compress_3d(
            data.view(),
            CompressionMode::PointwiseError { pwe: 0.1 },
            (32, 32, 32),
        )
        .expect("compression should not fail");

        let mut decompressed = Array3::<f64>::zeros((32, 32, 16));
        let result = decompress_into_3d(compressed.as_slice(), decompressed.view_mut());
        assert!(matches!(result, Err(Error::DecompressShapeMismatch)));
    }
}