numcodecs_fourier_network/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.82.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-fourier-network
10//! [crates.io]: https://crates.io/crates/numcodecs-fourier-network
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-fourier-network
13//! [docs.rs]: https://docs.rs/numcodecs-fourier-network/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_fourier_network
17//!
18//! Fourier feature neural network codec implementation for the [`numcodecs`] API.
19
20#![expect(clippy::multiple_crate_versions)]
21
22use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};
23
24use burn::{
25    backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
26    module::{Module, Param},
27    nn::loss::{MseLoss, Reduction},
28    optim::{AdamConfig, GradientsParams, Optimizer},
29    prelude::Backend,
30    record::{
31        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
32        Record, Recorder, RecorderError,
33    },
34    tensor::{
35        backend::AutodiffBackend, Distribution, Element as BurnElement, Float, Tensor, TensorData,
36    },
37};
38use itertools::Itertools;
39use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
40use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
41use numcodecs::{
42    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
43    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
44};
45use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
46use serde::{Deserialize, Deserializer, Serialize, Serializer};
47use thiserror::Error;
48
49#[cfg(test)]
50use ::serde_json as _;
51
52// FIXME: see https://github.com/tracel-ai/burn/issues/2876
53use ::{bincode as _, bincode_derive as _};
54
55// FIXME: 1.9.0 has MSRV 1.84, which our MSRV 1.82 doesn't support
56use ::bytemuck_derive as _;
57
58mod modules;
59
60use modules::{Model, ModelConfig, ModelExtra, ModelRecord};
61
/// Version marker of this codec's encoding format, stored in the codec
/// configuration under the `_version` key.
type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;
63
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
/// Fourier network codec which trains and overfits a fourier feature neural
/// network on encoding and predicts during decoding.
///
/// Only `f32` and `f64` data is supported. Decoding requires the output
/// array to be provided, since the encoded data does not store the shape of
/// the original data.
///
/// The approach is based on the papers by Tancik et al. 2020
/// (<https://dl.acm.org/doi/abs/10.5555/3495724.3496356>)
/// and by Huang and Hoefler 2022 (<https://arxiv.org/abs/2210.12538>).
pub struct FourierNetworkCodec {
    /// The number of Fourier features that the data coordinates are projected to
    pub fourier_features: NonZeroUsize,
    /// The standard deviation of the Fourier features
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the network
    pub num_blocks: NonZeroUsize,
    /// The learning rate for the `Adam` optimizer
    pub learning_rate: Positive<f64>,
    /// The number of epochs for which the network is trained
    pub num_epochs: usize,
    /// The optional mini-batch size used during training
    ///
    /// Setting the mini-batch size to `None` disables the use of batching,
    /// i.e. the network is trained using one large batch that includes the
    /// full data.
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during encoding
    pub seed: u64,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default, rename = "_version")]
    pub version: FourierNetworkCodecVersion,
}
97
98// using this wrapper function makes an Option<T> required
99fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
100    deserializer: D,
101) -> Result<Option<T>, D::Error> {
102    Option::<T>::deserialize(deserializer)
103}
104
105impl Codec for FourierNetworkCodec {
106    type Error = FourierNetworkCodecError;
107
108    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
109        match data {
110            AnyCowArray::F32(data) => Ok(AnyArray::U8(
111                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
112                    &NdArrayDevice::Cpu,
113                    data,
114                    self.fourier_features,
115                    self.fourier_scale,
116                    self.num_blocks,
117                    self.learning_rate,
118                    self.num_epochs,
119                    self.mini_batch_size,
120                    self.seed,
121                )?
122                .into_dyn(),
123            )),
124            AnyCowArray::F64(data) => Ok(AnyArray::U8(
125                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
126                    &NdArrayDevice::Cpu,
127                    data,
128                    self.fourier_features,
129                    self.fourier_scale,
130                    self.num_blocks,
131                    self.learning_rate,
132                    self.num_epochs,
133                    self.mini_batch_size,
134                    self.seed,
135                )?
136                .into_dyn(),
137            )),
138            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
139        }
140    }
141
142    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
143        Err(FourierNetworkCodecError::MissingDecodingOutput)
144    }
145
146    fn decode_into(
147        &self,
148        encoded: AnyArrayView,
149        decoded: AnyArrayViewMut,
150    ) -> Result<(), Self::Error> {
151        let AnyArrayView::U8(encoded) = encoded else {
152            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
153                dtype: encoded.dtype(),
154            });
155        };
156
157        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
158            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
159                shape: encoded.shape().to_vec(),
160            });
161        };
162
163        match decoded {
164            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
165                &NdArrayDevice::Cpu,
166                encoded,
167                decoded,
168                self.fourier_features,
169                self.num_blocks,
170            ),
171            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
172                &NdArrayDevice::Cpu,
173                encoded,
174                decoded,
175                self.fourier_features,
176                self.num_blocks,
177            ),
178            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
179        }
180    }
181}
182
183impl StaticCodec for FourierNetworkCodec {
184    const CODEC_ID: &'static str = "fourier-network.rs";
185
186    type Config<'de> = Self;
187
188    fn from_config(config: Self::Config<'_>) -> Self {
189        config
190    }
191
192    fn get_config(&self) -> StaticCodecConfig<Self> {
193        StaticCodecConfig::from(self)
194    }
195}
196
#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
/// Positive floating point number
///
/// The wrapped value is private; for `f64`, deserialization only accepts
/// values that are strictly greater than zero.
pub struct Positive<T: FloatTrait>(T);
201
202impl Serialize for Positive<f64> {
203    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
204        serializer.serialize_f64(self.0)
205    }
206}
207
208impl<'de> Deserialize<'de> for Positive<f64> {
209    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
210        let x = f64::deserialize(deserializer)?;
211
212        if x > 0.0 {
213            Ok(Self(x))
214        } else {
215            Err(serde::de::Error::invalid_value(
216                serde::de::Unexpected::Float(x),
217                &"a positive value",
218            ))
219        }
220    }
221}
222
223impl JsonSchema for Positive<f64> {
224    fn schema_name() -> Cow<'static, str> {
225        Cow::Borrowed("PositiveF64")
226    }
227
228    fn schema_id() -> Cow<'static, str> {
229        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
230    }
231
232    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
233        json_schema!({
234            "type": "number",
235            "exclusiveMinimum": 0.0
236        })
237    }
238}
239
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`FourierNetworkCodec`].
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    ///
    /// Only `f32` and `f64` data can be encoded or decoded into.
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN) floating
    /// point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed during a neural network computation
    ///
    /// This includes failing to record the trained model during encoding and
    /// failing to load the recorded model during decoding.
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The source of the error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] must be provided the output array during decoding
    ///
    /// Returned by [`Codec::decode`]; use [`Codec::decode_into`] instead.
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received an array of a different dtype
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received an array of a different shape
    #[error("FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    ///
    /// Raised, for instance, when the encoded data is empty but the output
    /// array is not.
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
284
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when an error occurs in the neural network
///
/// Wraps the recorder error raised while recording or loading the model; the
/// inner error is private and only forwarded through `Display` and `source`.
pub struct NeuralNetworkError(RecorderError);
289
/// Floating point types.
///
/// Implemented for `f32` and `f64`, the element types that the codec can
/// encode and decode.
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    /// The precision of this floating point type
    ///
    /// Selects the precision settings with which the model is recorded and
    /// loaded.
    type Precision: PrecisionSettings;

    /// Convert a usize to a floating point number
    ///
    /// The conversion is infallible but may round. Callers in this file spell
    /// it as `<T as FloatExt>::from_usize`, presumably to disambiguate it
    /// from the method of the same name on the `FromPrimitive` supertrait.
    fn from_usize(x: usize) -> Self;
}
300
impl FloatExt for f32 {
    // single-precision models are recorded with the full precision settings
    type Precision = FullPrecisionSettings;

    // the cast may round, which the lint expectation acknowledges
    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
309
impl FloatExt for f64 {
    // double-precision models are recorded with the double precision settings
    type Precision = DoublePrecisionSettings;

    // the cast may round, which the lint expectation acknowledges
    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
318
319#[expect(clippy::similar_names)] // train_xs and train_ys
320#[expect(clippy::missing_panics_doc)] // only panics on implementation bugs
321#[expect(clippy::too_many_arguments)] // FIXME
322/// Encodes the `data` by training a fourier feature neural network.
323///
324/// The `fourier_features` are randomly sampled from a normal distribution with
325/// zero mean and `fourier_scale` standard deviation.
326///
327/// The neural network consists of `num_blocks` blocks.
328///
329/// The network is trained for `num_epochs` using the `learning_rate`
330/// and mini-batches of `mini_batch_size` if mini-batching is enabled.
331///
332/// All random numbers are generated using the provided `seed`.
333///
334/// # Errors
335///
336/// Errors with
337/// - [`FourierNetworkCodecError::NonFiniteData`] if any data element is
338///   non-finite (infinite or NaN)
339/// - [`FourierNetworkCodecError::NeuralNetworkError`] if an error occurs during
340///   the neural network computation
341pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
342    device: &B::Device,
343    data: ArrayBase<S, D>,
344    fourier_features: NonZeroUsize,
345    fourier_scale: Positive<f64>,
346    num_blocks: NonZeroUsize,
347    learning_rate: Positive<f64>,
348    num_epochs: usize,
349    mini_batch_size: Option<NonZeroUsize>,
350    seed: u64,
351) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
352    let Some(mean) = data.mean() else {
353        return Ok(Array::from_vec(Vec::new()));
354    };
355    let stdv = data.std(T::ZERO);
356    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };
357
358    if !Zip::from(&data).all(|x| x.is_finite()) {
359        return Err(FourierNetworkCodecError::NonFiniteData);
360    }
361
362    B::seed(seed);
363
364    let b_t = Tensor::<B, 2, Float>::random(
365        [data.ndim(), fourier_features.get()],
366        Distribution::Normal(0.0, fourier_scale.0),
367        device,
368    );
369
370    let train_xs = flat_grid_like(&data, device);
371    let train_xs = fourier_mapping(train_xs, b_t.clone());
372
373    let train_ys_shape = [data.len(), 1];
374    let mut train_ys = data.into_owned();
375    train_ys.mapv_inplace(|x| (x - mean) / stdv);
376    #[expect(clippy::unwrap_used)] // reshape with one extra new axis cannot fail
377    let train_ys = train_ys
378        .into_shape_clone((train_ys_shape, Order::RowMajor))
379        .unwrap();
380    let train_ys = Tensor::from_data(
381        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
382        device,
383    );
384
385    let model = train(
386        device,
387        &train_xs,
388        &train_ys,
389        fourier_features,
390        num_blocks,
391        learning_rate,
392        num_epochs,
393        mini_batch_size,
394        stdv,
395    );
396
397    let extra = ModelExtra {
398        model: model.into_record(),
399        b_t: Param::from_tensor(b_t).set_require_grad(false),
400        mean: Param::from_tensor(Tensor::from_data(
401            TensorData::new(vec![mean], vec![1]),
402            device,
403        ))
404        .set_require_grad(false),
405        stdv: Param::from_tensor(Tensor::from_data(
406            TensorData::new(vec![stdv], vec![1]),
407            device,
408        ))
409        .set_require_grad(false),
410        version: StaticCodecVersion,
411    };
412
413    let recorder = BinBytesRecorder::<T::Precision>::new();
414    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;
415
416    Ok(Array::from_vec(encoded))
417}
418
419#[expect(clippy::missing_panics_doc)] // only panics on implementation bugs
420/// Decodes the `encoded` data into the `decoded` output array by making a
421/// prediction using the fourier feature neural network.
422///
423/// The network must have been trained during [`encode`] using the same number
424/// of `feature_features` and `num_blocks`.
425///
426/// # Errors
427///
428/// Errors with
429/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if the encoded
430///   array is empty but the decoded array is not
431/// - [`FourierNetworkCodecError::NeuralNetworkError`] if an error occurs during
432///   the neural network computation
433pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
434    device: &B::Device,
435    encoded: ArrayBase<S, Ix1>,
436    mut decoded: ArrayViewMut<T, D>,
437    fourier_features: NonZeroUsize,
438    num_blocks: NonZeroUsize,
439) -> Result<(), FourierNetworkCodecError> {
440    if encoded.is_empty() {
441        if decoded.is_empty() {
442            return Ok(());
443        }
444
445        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
446            source: AnyArrayAssignError::ShapeMismatch {
447                src: encoded.shape().to_vec(),
448                dst: decoded.shape().to_vec(),
449            },
450        });
451    }
452
453    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;
454
455    let recorder = BinBytesRecorder::<T::Precision>::new();
456    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;
457
458    let model = ModelConfig::new(fourier_features, num_blocks)
459        .init(device)
460        .load_record(record.model);
461    let b_t = record.b_t.into_value();
462    let mean = record.mean.into_value().into_scalar();
463    let stdv = record.stdv.into_value().into_scalar();
464
465    let test_xs = flat_grid_like(&decoded, device);
466    let test_xs = fourier_mapping(test_xs, b_t);
467
468    let prediction = model.forward(test_xs).into_data();
469    #[expect(clippy::unwrap_used)] // same generic type, check must succeed
470    let prediction = prediction.as_slice().unwrap();
471
472    #[expect(clippy::unwrap_used)] // prediction shape is flattened
473    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
474    decoded.mapv_inplace(|x| (x * stdv) + mean);
475
476    Ok(())
477}
478
479fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
480    a: &ArrayBase<S, D>,
481    device: &B::Device,
482) -> Tensor<B, 2, Float> {
483    let grid = a
484        .shape()
485        .iter()
486        .copied()
487        .map(|s| {
488            #[expect(clippy::useless_conversion)] // (0..s).into_iter()
489            (0..s)
490                .into_iter()
491                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
492        })
493        .multi_cartesian_product()
494        .flatten()
495        .collect::<Vec<_>>();
496
497    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
498}
499
500fn fourier_mapping<B: Backend>(
501    xs: Tensor<B, 2, Float>,
502    b_t: Tensor<B, 2, Float>,
503) -> Tensor<B, 2, Float> {
504    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);
505
506    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
507}
508
#[expect(clippy::similar_names)] // train_xs and train_ys
#[expect(clippy::too_many_arguments)] // FIXME
/// Trains a freshly initialised model on `train_xs` and `train_ys` for
/// `num_epochs` epochs with the Adam optimizer and a mean squared error loss.
///
/// If `mini_batch_size` is given, the samples are reshuffled and split into
/// batches in every epoch; otherwise each epoch is one step on the full data.
///
/// Returns the model from the epoch with the lowest mean loss. `stdv` is only
/// used to scale the error metrics that are logged after every epoch.
fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    train_xs: &Tensor<B, 2, Float>,
    train_ys: &Tensor<B, 2, Float>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    stdv: T,
) -> Model<B> {
    let num_samples = train_ys.shape().num_elements();
    // number of batches per epoch, the last batch may be smaller
    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));

    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
    let mut optim = AdamConfig::new().init();

    // track the best epoch, starting from the untrained model as epoch zero
    let mut best_loss = T::infinity();
    let mut best_epoch = 0;
    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();

    for epoch in 1..=num_epochs {
        #[expect(clippy::option_if_let_else)]
        let (train_xs_batches, train_ys_batches) = match num_batches {
            Some(num_batches) => {
                // sorting uniform random numbers yields a random permutation
                let shuffle = Tensor::<B, 1, Float>::random(
                    [num_samples],
                    Distribution::Uniform(0.0, 1.0),
                    device,
                );
                let shuffle_indices = shuffle.argsort(0);

                // inputs and targets are shuffled with the same permutation
                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);

                (
                    train_xs_shuffled.chunk(num_batches, 0),
                    train_ys_shuffled.chunk(num_batches, 0),
                )
            }
            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
        };

        let mut loss_sum = T::ZERO;

        // accumulators for the logged error metrics
        let mut se_sum = T::ZERO;
        let mut ae_sum = T::ZERO;
        let mut l_inf = T::ZERO;

        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
            let prediction = model.forward(train_xs_batch);
            let loss =
                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);

            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(learning_rate.0, model, grads);

            loss_sum += loss.into_scalar();

            // the errors stem from the prediction made before this step
            let err = prediction - train_ys_batch;

            se_sum += (err.clone() * err.clone()).sum().into_scalar();
            ae_sum += err.clone().abs().sum().into_scalar();
            l_inf = l_inf.max(err.abs().max().into_scalar());
        }

        // mean of the per-batch losses, a single batch without mini-batching
        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));

        if loss_mean < best_loss {
            best_loss = loss_mean;
            best_epoch = epoch;
            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
        }

        // scaled by stdv, presumably to report the errors in the units of the
        //  original data before normalisation
        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
        let l_inf = stdv * l_inf;

        log::info!("[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}");
    }

    // the model already holds the best weights if the last epoch was the best
    if best_epoch != num_epochs {
        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));

        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
    }

    model
}
600
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encodes `data` with the smallest possible network, i.e. one Fourier
    /// feature and one block, trained for ten epochs.
    fn encode_minimal<D: Dimension>(
        data: Array<f32, D>,
        mini_batch_size: Option<NonZeroUsize>,
    ) -> Array<u8, Ix1> {
        std::mem::drop(simple_logger::init());

        encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            data,
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            mini_batch_size,
            42,
        )
        .unwrap()
    }

    /// Decodes data that was encoded with [`encode_minimal`] into `decoded`.
    fn decode_minimal<D: Dimension>(encoded: Array<u8, Ix1>, decoded: ArrayViewMut<f32, D>) {
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded,
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn empty() {
        let encoded = encode_minimal(Array::<f32, _>::zeros((0,)), None);
        assert!(encoded.is_empty());

        let mut decoded = Array::<f32, _>::zeros((0,));
        decode_minimal(encoded, decoded.view_mut());
    }

    #[test]
    fn ones() {
        let encoded = encode_minimal(Array::<f32, _>::zeros((1, 1, 1, 1)), None);

        let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
        decode_minimal(encoded, decoded.view_mut());
    }

    #[test]
    fn r#const() {
        let encoded = encode_minimal(Array::<f32, _>::from_elem((2, 1, 3), 42.0), None);

        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_minimal(encoded, decoded.view_mut());
    }

    #[test]
    fn const_batched() {
        let encoded = encode_minimal(
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::new(2),
        );

        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_minimal(encoded, decoded.view_mut());
    }

    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let num_blocks = NonZeroUsize::new(2).unwrap();

        let mini_batch_sizes = [
            None,                     // no mini-batching
            NonZeroUsize::new(1),     // stochastic
            NonZeroUsize::new(7),     // mini-batched, remainder
            NonZeroUsize::new(10),    // mini-batched
            NonZeroUsize::new(1001),  // mini-batched, truncated
        ];

        for mini_batch_size in mini_batch_sizes {
            let mut decoded = Array::<f64, _>::zeros(data.shape());

            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                Positive(10.0),
                num_blocks,
                Positive(1e-4),
                100,
                mini_batch_size,
                42,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}