tthresh/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io]
2//! [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
3//!
4//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/tthresh-rs/ci.yml?branch=main
5//! [workflow]: https://github.com/juntyr/tthresh-rs/actions/workflows/ci.yml?query=branch%3Amain
6//!
7//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
8//! [repo]: https://github.com/juntyr/tthresh-rs
9//!
10//! [Latest Version]: https://img.shields.io/crates/v/tthresh
11//! [crates.io]: https://crates.io/crates/tthresh
12//!
13//! [Rust Doc Crate]: https://img.shields.io/docsrs/tthresh
14//! [docs.rs]: https://docs.rs/tthresh/
15//!
16//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
17//! [docs]: https://juntyr.github.io/tthresh-rs/tthresh
18//!
//! High-level bindings to the [tthresh] compressor.
20//!
21//! [tthresh]: https://github.com/rballester/tthresh
22
/// Error bound
///
/// Selects which error metric tthresh targets during compression, together
/// with the target value for that metric.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ErrorBound {
    /// Relative error (tthresh's `eps` target)
    Eps(f64),
    /// Root mean square error (tthresh's `rmse` target)
    RMSE(f64),
    /// Peak signal-to-noise ratio (tthresh's `psnr` target)
    PSNR(f64),
}
33
/// Buffer for typed decompressed data
///
/// Each variant owns the decompressed elements of one supported element type.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum Buffer {
    /// [`u8`] elements
    U8(Vec<u8>),
    /// [`u16`] elements
    U16(Vec<u16>),
    /// [`i32`] elements
    I32(Vec<i32>),
    /// [`f32`] elements
    F32(Vec<f32>),
    /// [`f64`] elements
    F64(Vec<f64>),
}
44
45/// Compress the `data` buffer using the `target` error bound.
46///
47/// # Errors
48///
49/// Errors with
50/// - [`Error::InsufficientDimensionality`] if the `data`'s `shape` is not at least
51///   three-dimensional
52/// - [`Error::InvalidShape`] if the `shape` does not match the `data` length
53/// - [`Error::ExcessiveSize`] if the shape cannot be converted into [`u32`]s
54/// - [`Error::NegativeErrorBound`] if the `target` error bound is negative
55pub fn compress<T: Element>(
56    data: &[T],
57    shape: &[usize],
58    target: ErrorBound,
59    verbose: bool,
60    debug: bool,
61) -> Result<Vec<u8>, Error> {
62    if shape.len() < 3 {
63        return Err(Error::InsufficientDimensionality);
64    }
65
66    if shape.iter().copied().product::<usize>() != data.len() {
67        return Err(Error::InvalidShape);
68    }
69
70    let shape = shape
71        .iter()
72        .copied()
73        .map(u32::try_from)
74        .collect::<Result<Vec<_>, _>>()
75        .map_err(|_| Error::ExcessiveSize)?;
76
77    let target_value = match target {
78        ErrorBound::Eps(v) | ErrorBound::RMSE(v) | ErrorBound::PSNR(v) => v,
79    };
80
81    if target_value < 0.0 {
82        return Err(Error::NegativeErrorBound);
83    }
84
85    let mut output = std::ptr::null_mut();
86    let mut output_size = 0;
87
88    #[allow(unsafe_code)] // FFI
89    unsafe {
90        tthresh_sys::compress_buffer(
91            data.as_ptr().cast::<std::ffi::c_char>(),
92            T::IO_TYPE,
93            shape.as_ptr(),
94            shape.len(),
95            std::ptr::from_mut(&mut output),
96            std::ptr::from_mut(&mut output_size),
97            match target {
98                ErrorBound::Eps(_) => tthresh_sys::Target_eps,
99                ErrorBound::RMSE(_) => tthresh_sys::Target_rmse,
100                ErrorBound::PSNR(_) => tthresh_sys::Target_psnr,
101            },
102            target_value,
103            Some(alloc),
104            verbose,
105            debug,
106        );
107    }
108
109    #[allow(unsafe_code)]
110    // Safety: the output was allocated in Rust using `alloc` with the correct
111    //         size and alignment
112    let compressed = unsafe { Vec::from_raw_parts(output, output_size, output_size) };
113
114    Ok(compressed)
115}
116
117/// Deompress the `compressed` bytes into a [`Buffer`] and shape.
118///
119/// # Errors
120///
121/// Errors with
122/// - [`Error::ExcessiveSize`] if the output shape cannot be converted from [`u32`]s
123/// - [`Error::CorruptedCompressedBytes`] if the `compressed` bytes are corrupted
124pub fn decompress(
125    compressed: &[u8],
126    verbose: bool,
127    debug: bool,
128) -> Result<(Buffer, Vec<usize>), Error> {
129    let mut shape = std::ptr::null_mut();
130    let mut shape_size = 0;
131
132    let mut output = std::ptr::null_mut();
133    let mut output_type = 0;
134    let mut output_length = 0;
135
136    #[allow(unsafe_code)] // FFI
137    let ok = unsafe {
138        tthresh_sys::decompress_buffer(
139            compressed.as_ptr(),
140            compressed.len(),
141            std::ptr::from_mut(&mut output),
142            std::ptr::from_mut(&mut output_type),
143            std::ptr::from_mut(&mut output_length),
144            std::ptr::from_mut(&mut shape),
145            std::ptr::from_mut(&mut shape_size),
146            Some(alloc),
147            verbose,
148            debug,
149        )
150    };
151
152    if !ok {
153        return Err(Error::CorruptedCompressedBytes);
154    }
155
156    #[allow(unsafe_code)]
157    // Safety: the shape was allocated in Rust using `alloc` with the correct
158    //         size and alignment
159    let shape = unsafe { Vec::from_raw_parts(shape, shape_size, shape_size) };
160
161    #[allow(unsafe_code)]
162    // Safety: the output was allocated in Rust using `alloc` with the correct
163    //         size and alignment
164    let decompressed = match output_type {
165        tthresh_sys::IOType_uchar_ => {
166            Buffer::U8(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
167        }
168        tthresh_sys::IOType_ushort_ => {
169            Buffer::U16(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
170        }
171        tthresh_sys::IOType_int_ => {
172            Buffer::I32(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
173        }
174        tthresh_sys::IOType_float_ => {
175            Buffer::F32(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
176        }
177        tthresh_sys::IOType_double_ => {
178            Buffer::F64(unsafe { Vec::from_raw_parts(output.cast(), output_length, output_length) })
179        }
180        #[allow(clippy::unreachable)]
181        _ => unreachable!("tthresh decompression returned an unknown output type"),
182    };
183
184    let shape = shape
185        .into_iter()
186        .map(usize::try_from)
187        .collect::<Result<Vec<_>, _>>()
188        .map_err(|_| Error::ExcessiveSize)?;
189
190    Ok((decompressed, shape))
191}
192
193#[derive(Debug, thiserror::Error)]
194/// Errors that can occur during compression and decompression with tthresh
195pub enum Error {
196    /// data must be at least three-dimensional
197    #[error("data must be at least three-dimensional")]
198    InsufficientDimensionality,
199    /// shape does not match the provided buffer
200    #[error("shape does not match the provided buffer")]
201    InvalidShape,
202    /// data shape sizes must fit within [0; 2^32 - 1]
203    #[error("data shape sizes must fit within [0; 2^32 - 1]")]
204    ExcessiveSize,
205    /// error bound must be non-negative
206    #[error("error bound must be non-negative")]
207    NegativeErrorBound,
208    /// compressed bytes have been corrupted
209    #[error("compressed bytes have been corrupted")]
210    CorruptedCompressedBytes,
211}
212
213/// Marker trait for element types that can be compressed with tthresh
214pub trait Element: sealed::Element {}
215
216mod sealed {
217    pub trait Element: Copy {
218        const IO_TYPE: tthresh_sys::IOType;
219    }
220}
221
222impl Element for u8 {}
223impl sealed::Element for u8 {
224    const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_uchar_;
225}
226
227impl Element for u16 {}
228impl sealed::Element for u16 {
229    const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_ushort_;
230}
231
232impl Element for i32 {}
233impl sealed::Element for i32 {
234    const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_int_;
235}
236
237impl Element for f32 {}
238impl sealed::Element for f32 {
239    const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_float_;
240}
241
242impl Element for f64 {}
243impl sealed::Element for f64 {
244    const IO_TYPE: tthresh_sys::IOType = tthresh_sys::IOType_double_;
245}
246
/// Allocation callback handed to tthresh by `compress` and `decompress`, whose
/// results those functions reclaim with [`Vec::from_raw_parts`].
///
/// Returns zero-initialised memory of `size` bytes aligned to `align`.
/// Zero-sized requests receive a non-null, suitably aligned dangling pointer
/// (with address `align`) instead of a real allocation.
///
/// Aborts if `size` and `align` do not form a valid [`std::alloc::Layout`]
/// (the panic cannot unwind out of an `extern "C"` function), or if the global
/// allocator fails, so tthresh is never handed a null pointer.
extern "C" fn alloc(size: usize, align: usize) -> *mut std::ffi::c_void {
    #[allow(clippy::unwrap_used)]
    let layout = std::alloc::Layout::from_size_align(size, align).unwrap();

    if layout.size() == 0 {
        // Safe, provenance-free pointer whose address is `align`: non-null and
        // aligned, as `Vec::from_raw_parts` requires for zero capacity.
        // FIXME: use std::ptr::without_provenance_mut with MSRV 1.84
        return std::ptr::null_mut::<u8>()
            .wrapping_add(layout.align())
            .cast();
    }

    #[allow(unsafe_code)]
    // Safety: layout is not zero-sized
    let ptr = unsafe { std::alloc::alloc_zeroed(layout) };

    // Previously a failed allocation (null) was returned straight to tthresh
    if ptr.is_null() {
        std::alloc::handle_alloc_error(layout);
    }

    ptr.cast()
}
264
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;

    /// Round-trips the bundled 64x64x64 `u8` sphere volume through tthresh
    /// with the given error bound, checking the element type and shape.
    fn compress_decompress(target: ErrorBound) {
        let sphere = std::fs::read("tthresh-sys/tthresh/data/3D_sphere_64_uchar.raw")
            .expect("input file should not be missing");
        let dims: [usize; 3] = [64, 64, 64];

        let bytes = compress(sphere.as_slice(), &dims, target, true, true)
            .expect("compression should not fail");

        let (buffer, roundtrip_shape) =
            decompress(bytes.as_slice(), true, true).expect("decompression should not fail");
        assert!(matches!(buffer, Buffer::U8(_)));
        assert_eq!(roundtrip_shape, dims);
    }

    /// Round-trips a single-element 1x1x1 volume with a zero RMSE bound.
    fn roundtrip_single<T: Element>(value: T) -> (Buffer, Vec<usize>) {
        let bytes = compress(&[value], &[1, 1, 1], ErrorBound::RMSE(0.0), true, true)
            .expect("compression should not fail");
        decompress(bytes.as_slice(), true, true).expect("decompression should not fail")
    }

    #[test]
    fn compress_decompress_eps() {
        compress_decompress(ErrorBound::Eps(0.5));
    }

    #[test]
    fn compress_decompress_rmse() {
        compress_decompress(ErrorBound::RMSE(0.1));
    }

    #[test]
    fn compress_decompress_psnr() {
        compress_decompress(ErrorBound::PSNR(30.0));
    }

    #[test]
    fn compress_decompress_u8() {
        let (buffer, dims) = roundtrip_single(42_u8);
        assert_eq!(buffer, Buffer::U8(vec![42]));
        assert_eq!(dims, &[1, 1, 1]);
    }

    #[test]
    fn compress_decompress_u16() {
        let (buffer, dims) = roundtrip_single(42_u16);
        assert_eq!(buffer, Buffer::U16(vec![42]));
        assert_eq!(dims, &[1, 1, 1]);
    }

    #[test]
    fn compress_decompress_i32() {
        let (buffer, dims) = roundtrip_single(42_i32);
        assert_eq!(buffer, Buffer::I32(vec![42]));
        assert_eq!(dims, &[1, 1, 1]);
    }

    #[test]
    fn compress_decompress_f32() {
        let (buffer, dims) = roundtrip_single(42.0_f32);
        assert_eq!(buffer, Buffer::F32(vec![42.0]));
        assert_eq!(dims, &[1, 1, 1]);
    }

    #[test]
    fn compress_decompress_f64() {
        let (buffer, dims) = roundtrip_single(42.0_f64);
        assert_eq!(buffer, Buffer::F64(vec![42.0]));
        assert_eq!(dims, &[1, 1, 1]);
    }

    #[test]
    fn decompress_empty_garbage() {
        assert!(matches!(
            decompress(&[0], true, true),
            Err(Error::CorruptedCompressedBytes)
        ));
    }

    #[test]
    fn decompress_full_garbage() {
        let garbage = [1_u8; 1024];
        assert!(matches!(
            decompress(&garbage, true, true),
            Err(Error::CorruptedCompressedBytes)
        ));
    }
}