Skip to main content

numcodecs_fourier_network/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.87.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-fourier-network
10//! [crates.io]: https://crates.io/crates/numcodecs-fourier-network
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-fourier-network
13//! [docs.rs]: https://docs.rs/numcodecs-fourier-network/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_fourier_network
17//!
18//! Fourier feature neural network codec implementation for the [`numcodecs`] API.
19
20#![allow(clippy::multiple_crate_versions)]
21
22use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};
23
24use burn::{
25    backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
26    module::{Module, Param},
27    nn::loss::{MseLoss, Reduction},
28    optim::{AdamConfig, GradientsParams, Optimizer},
29    prelude::Backend,
30    record::{
31        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
32        Record, Recorder, RecorderError,
33    },
34    tensor::{
35        Distribution, Element as BurnElement, Float, Tensor, TensorData, backend::AutodiffBackend,
36    },
37};
38use itertools::Itertools;
39use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
40use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
41use numcodecs::{
42    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
43    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
44};
45use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
46use serde::{Deserialize, Deserializer, Serialize, Serializer};
47use thiserror::Error;
48
49// FIXME: bytemuck 1.24 fails to compile on 1.87
50use ::bytemuck as _;
51
52// FIXME: burn-common -> cubecl-common brings in wasm-bindgen
53//        wasm-bindgen v0.2.115 has an unresolved import in wasm32-wasi
54use ::wasm_bindgen as _;
55
56#[cfg(test)]
57use ::serde_json as _;
58
59mod modules;
60
61use modules::{Model, ModelConfig, ModelExtra, ModelRecord};
62
63type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;
64
65#[derive(Clone, Serialize, Deserialize, JsonSchema)]
66#[serde(deny_unknown_fields)]
67/// Fourier network codec which trains and overfits a fourier feature neural
68/// network on encoding and predicts during decoding.
69///
70/// The approach is based on the papers by Tancik et al. 2020
71/// (<https://dl.acm.org/doi/abs/10.5555/3495724.3496356>)
72/// and by Huang and Hoefler 2020 (<https://arxiv.org/abs/2210.12538>).
73pub struct FourierNetworkCodec {
74    /// The number of Fourier features that the data coordinates are projected to
75    pub fourier_features: NonZeroUsize,
76    /// The standard deviation of the Fourier features
77    pub fourier_scale: Positive<f64>,
78    /// The number of blocks in the network
79    pub num_blocks: NonZeroUsize,
80    /// The learning rate for the `Adam` optimizer
81    pub learning_rate: Positive<f64>,
82    /// The number of epochs for which the network is trained
83    pub num_epochs: usize,
84    /// The optional mini-batch size used during training
85    ///
86    /// Setting the mini-batch size to `None` disables the use of batching,
87    /// i.e. the network is trained using one large batch that includes the
88    /// full data.
89    #[serde(deserialize_with = "deserialize_required_option")]
90    #[schemars(required, extend("type" = ["integer", "null"]))]
91    pub mini_batch_size: Option<NonZeroUsize>,
92    /// The seed for the random number generator used during encoding
93    pub seed: u64,
94    /// The codec's encoding format version. Do not provide this parameter explicitly.
95    #[serde(default, rename = "_version")]
96    pub version: FourierNetworkCodecVersion,
97}
98
99// using this wrapper function makes an Option<T> required
100fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
101    deserializer: D,
102) -> Result<Option<T>, D::Error> {
103    Option::<T>::deserialize(deserializer)
104}
105
106impl Codec for FourierNetworkCodec {
107    type Error = FourierNetworkCodecError;
108
109    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
110        match data {
111            AnyCowArray::F32(data) => Ok(AnyArray::U8(
112                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
113                    &NdArrayDevice::Cpu,
114                    data,
115                    self.fourier_features,
116                    self.fourier_scale,
117                    self.num_blocks,
118                    self.learning_rate,
119                    self.num_epochs,
120                    self.mini_batch_size,
121                    self.seed,
122                )?
123                .into_dyn(),
124            )),
125            AnyCowArray::F64(data) => Ok(AnyArray::U8(
126                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
127                    &NdArrayDevice::Cpu,
128                    data,
129                    self.fourier_features,
130                    self.fourier_scale,
131                    self.num_blocks,
132                    self.learning_rate,
133                    self.num_epochs,
134                    self.mini_batch_size,
135                    self.seed,
136                )?
137                .into_dyn(),
138            )),
139            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
140        }
141    }
142
143    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
144        Err(FourierNetworkCodecError::MissingDecodingOutput)
145    }
146
147    fn decode_into(
148        &self,
149        encoded: AnyArrayView,
150        decoded: AnyArrayViewMut,
151    ) -> Result<(), Self::Error> {
152        let AnyArrayView::U8(encoded) = encoded else {
153            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
154                dtype: encoded.dtype(),
155            });
156        };
157
158        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
159            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
160                shape: encoded.shape().to_vec(),
161            });
162        };
163
164        match decoded {
165            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
166                &NdArrayDevice::Cpu,
167                encoded,
168                decoded,
169                self.fourier_features,
170                self.num_blocks,
171            ),
172            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
173                &NdArrayDevice::Cpu,
174                encoded,
175                decoded,
176                self.fourier_features,
177                self.num_blocks,
178            ),
179            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
180        }
181    }
182}
183
184impl StaticCodec for FourierNetworkCodec {
185    const CODEC_ID: &'static str = "fourier-network.rs";
186
187    type Config<'de> = Self;
188
189    fn from_config(config: Self::Config<'_>) -> Self {
190        config
191    }
192
193    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
194        StaticCodecConfig::from(self)
195    }
196}
197
198#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
199#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
200/// Positive floating point number
201pub struct Positive<T: FloatTrait>(T);
202
203impl Serialize for Positive<f64> {
204    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
205        serializer.serialize_f64(self.0)
206    }
207}
208
209impl<'de> Deserialize<'de> for Positive<f64> {
210    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
211        let x = f64::deserialize(deserializer)?;
212
213        if x > 0.0 {
214            Ok(Self(x))
215        } else {
216            Err(serde::de::Error::invalid_value(
217                serde::de::Unexpected::Float(x),
218                &"a positive value",
219            ))
220        }
221    }
222}
223
224impl JsonSchema for Positive<f64> {
225    fn schema_name() -> Cow<'static, str> {
226        Cow::Borrowed("PositiveF64")
227    }
228
229    fn schema_id() -> Cow<'static, str> {
230        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
231    }
232
233    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
234        json_schema!({
235            "type": "number",
236            "exclusiveMinimum": 0.0
237        })
238    }
239}
240
241#[derive(Debug, Error)]
242/// Errors that may occur when applying the [`FourierNetworkCodec`].
243pub enum FourierNetworkCodecError {
244    /// [`FourierNetworkCodec`] does not support the dtype
245    #[error("FourierNetwork does not support the dtype {0}")]
246    UnsupportedDtype(AnyArrayDType),
247    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN) floating
248    /// point data
249    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
250    NonFiniteData,
251    /// [`FourierNetworkCodec`] failed during a neural network computation
252    #[error("FourierNetwork failed during a neural network computation")]
253    NeuralNetworkError {
254        /// The source of the error
255        #[from]
256        source: NeuralNetworkError,
257    },
258    /// [`FourierNetworkCodec`] must be provided the output array during decoding
259    #[error("FourierNetwork must be provided the output array during decoding")]
260    MissingDecodingOutput,
261    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
262    /// received an array of a different dtype
263    #[error(
264        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
265    )]
266    EncodedDataNotBytes {
267        /// The unexpected dtype of the encoded array
268        dtype: AnyArrayDType,
269    },
270    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
271    /// received an array of a different shape
272    #[error(
273        "FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
274    )]
275    EncodedDataNotOneDimensional {
276        /// The unexpected shape of the encoded array
277        shape: Vec<usize>,
278    },
279    /// [`FourierNetworkCodec`] cannot decode into the provided array
280    #[error("FourierNetwork cannot decode into the provided array")]
281    MismatchedDecodeIntoArray {
282        /// The source of the error
283        #[from]
284        source: AnyArrayAssignError,
285    },
286}
287
288#[derive(Debug, Error)]
289#[error(transparent)]
290/// Opaque error for when an error occurs in the neural network
291pub struct NeuralNetworkError(RecorderError);
292
293/// Floating point types.
294pub trait FloatExt:
295    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
296{
297    /// The precision of this floating point type
298    type Precision: PrecisionSettings;
299
300    /// Convert a usize to a floating point number
301    fn from_usize(x: usize) -> Self;
302}
303
304impl FloatExt for f32 {
305    type Precision = FullPrecisionSettings;
306
307    #[expect(clippy::cast_precision_loss)]
308    fn from_usize(x: usize) -> Self {
309        x as Self
310    }
311}
312
313impl FloatExt for f64 {
314    type Precision = DoublePrecisionSettings;
315
316    #[expect(clippy::cast_precision_loss)]
317    fn from_usize(x: usize) -> Self {
318        x as Self
319    }
320}
321
322#[expect(clippy::similar_names)] // train_xs and train_ys
323#[expect(clippy::missing_panics_doc)] // only panics on implementation bugs
324#[expect(clippy::too_many_arguments)] // FIXME
325/// Encodes the `data` by training a fourier feature neural network.
326///
327/// The `fourier_features` are randomly sampled from a normal distribution with
328/// zero mean and `fourier_scale` standard deviation.
329///
330/// The neural network consists of `num_blocks` blocks.
331///
332/// The network is trained for `num_epochs` using the `learning_rate`
333/// and mini-batches of `mini_batch_size` if mini-batching is enabled.
334///
335/// All random numbers are generated using the provided `seed`.
336///
337/// # Errors
338///
339/// Errors with
340/// - [`FourierNetworkCodecError::NonFiniteData`] if any data element is
341///   non-finite (infinite or NaN)
342/// - [`FourierNetworkCodecError::NeuralNetworkError`] if an error occurs during
343///   the neural network computation
344pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
345    device: &B::Device,
346    data: ArrayBase<S, D>,
347    fourier_features: NonZeroUsize,
348    fourier_scale: Positive<f64>,
349    num_blocks: NonZeroUsize,
350    learning_rate: Positive<f64>,
351    num_epochs: usize,
352    mini_batch_size: Option<NonZeroUsize>,
353    seed: u64,
354) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
355    let Some(mean) = data.mean() else {
356        return Ok(Array::from_vec(Vec::new()));
357    };
358    let stdv = data.std(T::ZERO);
359    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };
360
361    if !Zip::from(&data).all(|x| x.is_finite()) {
362        return Err(FourierNetworkCodecError::NonFiniteData);
363    }
364
365    B::seed(seed);
366
367    let b_t = Tensor::<B, 2, Float>::random(
368        [data.ndim(), fourier_features.get()],
369        Distribution::Normal(0.0, fourier_scale.0),
370        device,
371    );
372
373    let train_xs = flat_grid_like(&data, device);
374    let train_xs = fourier_mapping(train_xs, b_t.clone());
375
376    let train_ys_shape = [data.len(), 1];
377    let mut train_ys = data.into_owned();
378    train_ys.mapv_inplace(|x| (x - mean) / stdv);
379    #[expect(clippy::unwrap_used)] // reshape with one extra new axis cannot fail
380    let train_ys = train_ys
381        .into_shape_clone((train_ys_shape, Order::RowMajor))
382        .unwrap();
383    let train_ys = Tensor::from_data(
384        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
385        device,
386    );
387
388    let model = train(
389        device,
390        &train_xs,
391        &train_ys,
392        fourier_features,
393        num_blocks,
394        learning_rate,
395        num_epochs,
396        mini_batch_size,
397        stdv,
398    );
399
400    let extra = ModelExtra {
401        model: model.into_record(),
402        b_t: Param::from_tensor(b_t).set_require_grad(false),
403        mean: Param::from_tensor(Tensor::from_data(
404            TensorData::new(vec![mean], vec![1]),
405            device,
406        ))
407        .set_require_grad(false),
408        stdv: Param::from_tensor(Tensor::from_data(
409            TensorData::new(vec![stdv], vec![1]),
410            device,
411        ))
412        .set_require_grad(false),
413        version: StaticCodecVersion,
414    };
415
416    let recorder = BinBytesRecorder::<T::Precision>::new();
417    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;
418
419    Ok(Array::from_vec(encoded))
420}
421
422#[expect(clippy::missing_panics_doc)] // only panics on implementation bugs
423/// Decodes the `encoded` data into the `decoded` output array by making a
424/// prediction using the fourier feature neural network.
425///
426/// The network must have been trained during [`encode`] using the same number
427/// of `feature_features` and `num_blocks`.
428///
429/// # Errors
430///
431/// Errors with
432/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if the encoded
433///   array is empty but the decoded array is not
434/// - [`FourierNetworkCodecError::NeuralNetworkError`] if an error occurs during
435///   the neural network computation
436pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
437    device: &B::Device,
438    encoded: ArrayBase<S, Ix1>,
439    mut decoded: ArrayViewMut<T, D>,
440    fourier_features: NonZeroUsize,
441    num_blocks: NonZeroUsize,
442) -> Result<(), FourierNetworkCodecError> {
443    if encoded.is_empty() {
444        if decoded.is_empty() {
445            return Ok(());
446        }
447
448        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
449            source: AnyArrayAssignError::ShapeMismatch {
450                src: encoded.shape().to_vec(),
451                dst: decoded.shape().to_vec(),
452            },
453        });
454    }
455
456    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;
457
458    let recorder = BinBytesRecorder::<T::Precision>::new();
459    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;
460
461    let model = ModelConfig::new(fourier_features, num_blocks)
462        .init(device)
463        .load_record(record.model);
464    let b_t = record.b_t.into_value();
465    let mean = record.mean.into_value().into_scalar();
466    let stdv = record.stdv.into_value().into_scalar();
467
468    let test_xs = flat_grid_like(&decoded, device);
469    let test_xs = fourier_mapping(test_xs, b_t);
470
471    let prediction = model.forward(test_xs).into_data();
472    #[expect(clippy::unwrap_used)] // same generic type, check must succeed
473    let prediction = prediction.as_slice().unwrap();
474
475    #[expect(clippy::unwrap_used)] // prediction shape is flattened
476    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
477    decoded.mapv_inplace(|x| (x * stdv) + mean);
478
479    Ok(())
480}
481
482fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
483    a: &ArrayBase<S, D>,
484    device: &B::Device,
485) -> Tensor<B, 2, Float> {
486    let grid = a
487        .shape()
488        .iter()
489        .copied()
490        .map(|s| {
491            #[expect(clippy::useless_conversion)] // (0..s).into_iter()
492            (0..s)
493                .into_iter()
494                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
495        })
496        .multi_cartesian_product()
497        .flatten()
498        .collect::<Vec<_>>();
499
500    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
501}
502
503fn fourier_mapping<B: Backend>(
504    xs: Tensor<B, 2, Float>,
505    b_t: Tensor<B, 2, Float>,
506) -> Tensor<B, 2, Float> {
507    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);
508
509    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
510}
511
512#[expect(clippy::similar_names)] // train_xs and train_ys
513#[expect(clippy::too_many_arguments)] // FIXME
514fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
515    device: &B::Device,
516    train_xs: &Tensor<B, 2, Float>,
517    train_ys: &Tensor<B, 2, Float>,
518    fourier_features: NonZeroUsize,
519    num_blocks: NonZeroUsize,
520    learning_rate: Positive<f64>,
521    num_epochs: usize,
522    mini_batch_size: Option<NonZeroUsize>,
523    stdv: T,
524) -> Model<B> {
525    let num_samples = train_ys.shape().num_elements();
526    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));
527
528    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
529    let mut optim = AdamConfig::new().init();
530
531    let mut best_loss = T::infinity();
532    let mut best_epoch = 0;
533    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
534
535    for epoch in 1..=num_epochs {
536        #[expect(clippy::option_if_let_else)]
537        let (train_xs_batches, train_ys_batches) = match num_batches {
538            Some(num_batches) => {
539                let shuffle = Tensor::<B, 1, Float>::random(
540                    [num_samples],
541                    Distribution::Uniform(0.0, 1.0),
542                    device,
543                );
544                let shuffle_indices = shuffle.argsort(0);
545
546                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
547                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);
548
549                (
550                    train_xs_shuffled.chunk(num_batches, 0),
551                    train_ys_shuffled.chunk(num_batches, 0),
552                )
553            }
554            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
555        };
556
557        let mut loss_sum = T::ZERO;
558
559        let mut se_sum = T::ZERO;
560        let mut ae_sum = T::ZERO;
561        let mut l_inf = T::ZERO;
562
563        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
564            let prediction = model.forward(train_xs_batch);
565            let loss =
566                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);
567
568            let grads = GradientsParams::from_grads(loss.backward(), &model);
569            model = optim.step(learning_rate.0, model, grads);
570
571            loss_sum += loss.into_scalar();
572
573            let err = prediction - train_ys_batch;
574
575            se_sum += (err.clone() * err.clone()).sum().into_scalar();
576            ae_sum += err.clone().abs().sum().into_scalar();
577            l_inf = l_inf.max(err.abs().max().into_scalar());
578        }
579
580        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));
581
582        if loss_mean < best_loss {
583            best_loss = loss_mean;
584            best_epoch = epoch;
585            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
586        }
587
588        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
589        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
590        let l_inf = stdv * l_inf;
591
592        log::info!(
593            "[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}"
594        );
595    }
596
597    if best_epoch != num_epochs {
598        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));
599
600        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
601    }
602
603    model
604}
605
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encodes `data` as `f32` with a single Fourier feature and a single
    /// block, trained for ten epochs with seed 42.
    fn encode_f32<D: Dimension>(
        data: Array<f32, D>,
        mini_batch_size: Option<NonZeroUsize>,
    ) -> Array<u8, Ix1> {
        encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            data,
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            mini_batch_size,
            42,
        )
        .unwrap()
    }

    /// Decodes `encoded` into `decoded` with the network size used by
    /// `encode_f32`.
    fn decode_f32<D: Dimension>(encoded: Array<u8, Ix1>, decoded: ArrayViewMut<f32, D>) {
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded,
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn empty() {
        std::mem::drop(simple_logger::init());

        let encoded = encode_f32(Array::<f32, _>::zeros((0,)), None);
        assert!(encoded.is_empty());

        decode_f32(encoded, Array::<f32, _>::zeros((0,)).view_mut());
    }

    #[test]
    fn ones() {
        std::mem::drop(simple_logger::init());

        let encoded = encode_f32(Array::<f32, _>::zeros((1, 1, 1, 1)), None);

        decode_f32(encoded, Array::<f32, _>::zeros((1, 1, 1, 1)).view_mut());
    }

    #[test]
    fn r#const() {
        std::mem::drop(simple_logger::init());

        let encoded = encode_f32(Array::<f32, _>::from_elem((2, 1, 3), 42.0), None);

        decode_f32(encoded, Array::<f32, _>::zeros((2, 1, 3)).view_mut());
    }

    #[test]
    fn const_batched() {
        std::mem::drop(simple_logger::init());

        let encoded = encode_f32(
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            Some(NonZeroUsize::MIN.saturating_add(1)),
        );

        decode_f32(encoded, Array::<f32, _>::zeros((2, 1, 3)).view_mut());
    }

    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        let mini_batch_sizes = [
            None,                                         // no mini-batching
            Some(NonZeroUsize::MIN),                      // stochastic
            Some(NonZeroUsize::MIN.saturating_add(6)),    // mini-batched, remainder
            Some(NonZeroUsize::MIN.saturating_add(9)),    // mini-batched
            Some(NonZeroUsize::MIN.saturating_add(1000)), // mini-batched, truncated
        ];

        for mini_batch_size in mini_batch_sizes {
            let mut decoded = Array::<f64, _>::zeros(data.shape());

            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}