numcodecs_fourier_network/
lib.rs

1//! [![CI Status]][workflow] [![MSRV]][repo] [![Latest Version]][crates.io] [![Rust Doc Crate]][docs.rs] [![Rust Doc Main]][docs]
2//!
3//! [CI Status]: https://img.shields.io/github/actions/workflow/status/juntyr/numcodecs-rs/ci.yml?branch=main
4//! [workflow]: https://github.com/juntyr/numcodecs-rs/actions/workflows/ci.yml?query=branch%3Amain
5//!
6//! [MSRV]: https://img.shields.io/badge/MSRV-1.85.0-blue
7//! [repo]: https://github.com/juntyr/numcodecs-rs
8//!
9//! [Latest Version]: https://img.shields.io/crates/v/numcodecs-fourier-network
10//! [crates.io]: https://crates.io/crates/numcodecs-fourier-network
11//!
12//! [Rust Doc Crate]: https://img.shields.io/docsrs/numcodecs-fourier-network
13//! [docs.rs]: https://docs.rs/numcodecs-fourier-network/
14//!
15//! [Rust Doc Main]: https://img.shields.io/badge/docs-main-blue
16//! [docs]: https://juntyr.github.io/numcodecs-rs/numcodecs_fourier_network
17//!
18//! Fourier feature neural network codec implementation for the [`numcodecs`] API.
19
20#![expect(clippy::multiple_crate_versions)]
21
22use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};
23
24use burn::{
25    backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
26    module::{Module, Param},
27    nn::loss::{MseLoss, Reduction},
28    optim::{AdamConfig, GradientsParams, Optimizer},
29    prelude::Backend,
30    record::{
31        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
32        Record, Recorder, RecorderError,
33    },
34    tensor::{
35        Distribution, Element as BurnElement, Float, Tensor, TensorData, backend::AutodiffBackend,
36    },
37};
38use itertools::Itertools;
39use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
40use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
41use numcodecs::{
42    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
43    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
44};
45use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
46use serde::{Deserialize, Deserializer, Serialize, Serializer};
47use thiserror::Error;
48
49#[cfg(test)]
50use ::serde_json as _;
51
52mod modules;
53
54use modules::{Model, ModelConfig, ModelExtra, ModelRecord};
55
/// Encoding format version of the [`FourierNetworkCodec`], currently `0.1.0`.
type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;
57
58#[derive(Clone, Serialize, Deserialize, JsonSchema)]
59#[serde(deny_unknown_fields)]
60/// Fourier network codec which trains and overfits a fourier feature neural
61/// network on encoding and predicts during decoding.
62///
63/// The approach is based on the papers by Tancik et al. 2020
64/// (<https://dl.acm.org/doi/abs/10.5555/3495724.3496356>)
65/// and by Huang and Hoefler 2020 (<https://arxiv.org/abs/2210.12538>).
66pub struct FourierNetworkCodec {
67    /// The number of Fourier features that the data coordinates are projected to
68    pub fourier_features: NonZeroUsize,
69    /// The standard deviation of the Fourier features
70    pub fourier_scale: Positive<f64>,
71    /// The number of blocks in the network
72    pub num_blocks: NonZeroUsize,
73    /// The learning rate for the `Adam` optimizer
74    pub learning_rate: Positive<f64>,
75    /// The number of epochs for which the network is trained
76    pub num_epochs: usize,
77    /// The optional mini-batch size used during training
78    ///
79    /// Setting the mini-batch size to `None` disables the use of batching,
80    /// i.e. the network is trained using one large batch that includes the
81    /// full data.
82    #[serde(deserialize_with = "deserialize_required_option")]
83    #[schemars(required, extend("type" = ["integer", "null"]))]
84    pub mini_batch_size: Option<NonZeroUsize>,
85    /// The seed for the random number generator used during encoding
86    pub seed: u64,
87    /// The codec's encoding format version. Do not provide this parameter explicitly.
88    #[serde(default, rename = "_version")]
89    pub version: FourierNetworkCodecVersion,
90}
91
92// using this wrapper function makes an Option<T> required
93fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
94    deserializer: D,
95) -> Result<Option<T>, D::Error> {
96    Option::<T>::deserialize(deserializer)
97}
98
99impl Codec for FourierNetworkCodec {
100    type Error = FourierNetworkCodecError;
101
102    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
103        match data {
104            AnyCowArray::F32(data) => Ok(AnyArray::U8(
105                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
106                    &NdArrayDevice::Cpu,
107                    data,
108                    self.fourier_features,
109                    self.fourier_scale,
110                    self.num_blocks,
111                    self.learning_rate,
112                    self.num_epochs,
113                    self.mini_batch_size,
114                    self.seed,
115                )?
116                .into_dyn(),
117            )),
118            AnyCowArray::F64(data) => Ok(AnyArray::U8(
119                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
120                    &NdArrayDevice::Cpu,
121                    data,
122                    self.fourier_features,
123                    self.fourier_scale,
124                    self.num_blocks,
125                    self.learning_rate,
126                    self.num_epochs,
127                    self.mini_batch_size,
128                    self.seed,
129                )?
130                .into_dyn(),
131            )),
132            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
133        }
134    }
135
136    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
137        Err(FourierNetworkCodecError::MissingDecodingOutput)
138    }
139
140    fn decode_into(
141        &self,
142        encoded: AnyArrayView,
143        decoded: AnyArrayViewMut,
144    ) -> Result<(), Self::Error> {
145        let AnyArrayView::U8(encoded) = encoded else {
146            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
147                dtype: encoded.dtype(),
148            });
149        };
150
151        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
152            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
153                shape: encoded.shape().to_vec(),
154            });
155        };
156
157        match decoded {
158            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
159                &NdArrayDevice::Cpu,
160                encoded,
161                decoded,
162                self.fourier_features,
163                self.num_blocks,
164            ),
165            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
166                &NdArrayDevice::Cpu,
167                encoded,
168                decoded,
169                self.fourier_features,
170                self.num_blocks,
171            ),
172            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
173        }
174    }
175}
176
177impl StaticCodec for FourierNetworkCodec {
178    const CODEC_ID: &'static str = "fourier-network.rs";
179
180    type Config<'de> = Self;
181
182    fn from_config(config: Self::Config<'_>) -> Self {
183        config
184    }
185
186    fn get_config(&self) -> StaticCodecConfig<Self> {
187        StaticCodecConfig::from(self)
188    }
189}
190
191#[expect(clippy::derive_partial_eq_without_eq)] // floats are not Eq
192#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
193/// Positive floating point number
194pub struct Positive<T: FloatTrait>(T);
195
196impl Serialize for Positive<f64> {
197    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
198        serializer.serialize_f64(self.0)
199    }
200}
201
202impl<'de> Deserialize<'de> for Positive<f64> {
203    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
204        let x = f64::deserialize(deserializer)?;
205
206        if x > 0.0 {
207            Ok(Self(x))
208        } else {
209            Err(serde::de::Error::invalid_value(
210                serde::de::Unexpected::Float(x),
211                &"a positive value",
212            ))
213        }
214    }
215}
216
217impl JsonSchema for Positive<f64> {
218    fn schema_name() -> Cow<'static, str> {
219        Cow::Borrowed("PositiveF64")
220    }
221
222    fn schema_id() -> Cow<'static, str> {
223        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
224    }
225
226    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
227        json_schema!({
228            "type": "number",
229            "exclusiveMinimum": 0.0
230        })
231    }
232}
233
#[derive(Debug, Error)]
/// Errors that may occur when applying the [`FourierNetworkCodec`].
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    ///
    /// Only `f32` and `f64` arrays can be encoded and decoded into.
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN) floating
    /// point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed during a neural network computation
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The source of the error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] must be provided the output array during decoding
    ///
    /// Returned by [`Codec::decode`], since the encoding does not store the
    /// shape of the data; use [`Codec::decode_into`] instead.
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received an array of a different dtype
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received an array of a different shape
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    ///
    /// For instance, an empty encoding can only be decoded into an empty array.
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
280
#[derive(Debug, Error)]
#[error(transparent)]
/// Opaque error for when an error occurs in the neural network
///
/// Currently wraps failures of the burn recorder that serializes the trained
/// network during encoding and deserializes it during decoding. Its message and
/// source are forwarded transparently.
pub struct NeuralNetworkError(RecorderError);
285
286/// Floating point types.
287pub trait FloatExt:
288    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
289{
290    /// The precision of this floating point type
291    type Precision: PrecisionSettings;
292
293    /// Convert a usize to a floating point number
294    fn from_usize(x: usize) -> Self;
295}
296
297impl FloatExt for f32 {
298    type Precision = FullPrecisionSettings;
299
300    #[expect(clippy::cast_precision_loss)]
301    fn from_usize(x: usize) -> Self {
302        x as Self
303    }
304}
305
306impl FloatExt for f64 {
307    type Precision = DoublePrecisionSettings;
308
309    #[expect(clippy::cast_precision_loss)]
310    fn from_usize(x: usize) -> Self {
311        x as Self
312    }
313}
314
315#[expect(clippy::similar_names)] // train_xs and train_ys
316#[expect(clippy::missing_panics_doc)] // only panics on implementation bugs
317#[expect(clippy::too_many_arguments)] // FIXME
318/// Encodes the `data` by training a fourier feature neural network.
319///
320/// The `fourier_features` are randomly sampled from a normal distribution with
321/// zero mean and `fourier_scale` standard deviation.
322///
323/// The neural network consists of `num_blocks` blocks.
324///
325/// The network is trained for `num_epochs` using the `learning_rate`
326/// and mini-batches of `mini_batch_size` if mini-batching is enabled.
327///
328/// All random numbers are generated using the provided `seed`.
329///
330/// # Errors
331///
332/// Errors with
333/// - [`FourierNetworkCodecError::NonFiniteData`] if any data element is
334///   non-finite (infinite or NaN)
335/// - [`FourierNetworkCodecError::NeuralNetworkError`] if an error occurs during
336///   the neural network computation
337pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
338    device: &B::Device,
339    data: ArrayBase<S, D>,
340    fourier_features: NonZeroUsize,
341    fourier_scale: Positive<f64>,
342    num_blocks: NonZeroUsize,
343    learning_rate: Positive<f64>,
344    num_epochs: usize,
345    mini_batch_size: Option<NonZeroUsize>,
346    seed: u64,
347) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
348    let Some(mean) = data.mean() else {
349        return Ok(Array::from_vec(Vec::new()));
350    };
351    let stdv = data.std(T::ZERO);
352    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };
353
354    if !Zip::from(&data).all(|x| x.is_finite()) {
355        return Err(FourierNetworkCodecError::NonFiniteData);
356    }
357
358    B::seed(seed);
359
360    let b_t = Tensor::<B, 2, Float>::random(
361        [data.ndim(), fourier_features.get()],
362        Distribution::Normal(0.0, fourier_scale.0),
363        device,
364    );
365
366    let train_xs = flat_grid_like(&data, device);
367    let train_xs = fourier_mapping(train_xs, b_t.clone());
368
369    let train_ys_shape = [data.len(), 1];
370    let mut train_ys = data.into_owned();
371    train_ys.mapv_inplace(|x| (x - mean) / stdv);
372    #[expect(clippy::unwrap_used)] // reshape with one extra new axis cannot fail
373    let train_ys = train_ys
374        .into_shape_clone((train_ys_shape, Order::RowMajor))
375        .unwrap();
376    let train_ys = Tensor::from_data(
377        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
378        device,
379    );
380
381    let model = train(
382        device,
383        &train_xs,
384        &train_ys,
385        fourier_features,
386        num_blocks,
387        learning_rate,
388        num_epochs,
389        mini_batch_size,
390        stdv,
391    );
392
393    let extra = ModelExtra {
394        model: model.into_record(),
395        b_t: Param::from_tensor(b_t).set_require_grad(false),
396        mean: Param::from_tensor(Tensor::from_data(
397            TensorData::new(vec![mean], vec![1]),
398            device,
399        ))
400        .set_require_grad(false),
401        stdv: Param::from_tensor(Tensor::from_data(
402            TensorData::new(vec![stdv], vec![1]),
403            device,
404        ))
405        .set_require_grad(false),
406        version: StaticCodecVersion,
407    };
408
409    let recorder = BinBytesRecorder::<T::Precision>::new();
410    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;
411
412    Ok(Array::from_vec(encoded))
413}
414
415#[expect(clippy::missing_panics_doc)] // only panics on implementation bugs
416/// Decodes the `encoded` data into the `decoded` output array by making a
417/// prediction using the fourier feature neural network.
418///
419/// The network must have been trained during [`encode`] using the same number
420/// of `feature_features` and `num_blocks`.
421///
422/// # Errors
423///
424/// Errors with
425/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if the encoded
426///   array is empty but the decoded array is not
427/// - [`FourierNetworkCodecError::NeuralNetworkError`] if an error occurs during
428///   the neural network computation
429pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
430    device: &B::Device,
431    encoded: ArrayBase<S, Ix1>,
432    mut decoded: ArrayViewMut<T, D>,
433    fourier_features: NonZeroUsize,
434    num_blocks: NonZeroUsize,
435) -> Result<(), FourierNetworkCodecError> {
436    if encoded.is_empty() {
437        if decoded.is_empty() {
438            return Ok(());
439        }
440
441        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
442            source: AnyArrayAssignError::ShapeMismatch {
443                src: encoded.shape().to_vec(),
444                dst: decoded.shape().to_vec(),
445            },
446        });
447    }
448
449    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;
450
451    let recorder = BinBytesRecorder::<T::Precision>::new();
452    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;
453
454    let model = ModelConfig::new(fourier_features, num_blocks)
455        .init(device)
456        .load_record(record.model);
457    let b_t = record.b_t.into_value();
458    let mean = record.mean.into_value().into_scalar();
459    let stdv = record.stdv.into_value().into_scalar();
460
461    let test_xs = flat_grid_like(&decoded, device);
462    let test_xs = fourier_mapping(test_xs, b_t);
463
464    let prediction = model.forward(test_xs).into_data();
465    #[expect(clippy::unwrap_used)] // same generic type, check must succeed
466    let prediction = prediction.as_slice().unwrap();
467
468    #[expect(clippy::unwrap_used)] // prediction shape is flattened
469    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
470    decoded.mapv_inplace(|x| (x * stdv) + mean);
471
472    Ok(())
473}
474
475fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
476    a: &ArrayBase<S, D>,
477    device: &B::Device,
478) -> Tensor<B, 2, Float> {
479    let grid = a
480        .shape()
481        .iter()
482        .copied()
483        .map(|s| {
484            #[expect(clippy::useless_conversion)] // (0..s).into_iter()
485            (0..s)
486                .into_iter()
487                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
488        })
489        .multi_cartesian_product()
490        .flatten()
491        .collect::<Vec<_>>();
492
493    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
494}
495
496fn fourier_mapping<B: Backend>(
497    xs: Tensor<B, 2, Float>,
498    b_t: Tensor<B, 2, Float>,
499) -> Tensor<B, 2, Float> {
500    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);
501
502    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
503}
504
505#[expect(clippy::similar_names)] // train_xs and train_ys
506#[expect(clippy::too_many_arguments)] // FIXME
507fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
508    device: &B::Device,
509    train_xs: &Tensor<B, 2, Float>,
510    train_ys: &Tensor<B, 2, Float>,
511    fourier_features: NonZeroUsize,
512    num_blocks: NonZeroUsize,
513    learning_rate: Positive<f64>,
514    num_epochs: usize,
515    mini_batch_size: Option<NonZeroUsize>,
516    stdv: T,
517) -> Model<B> {
518    let num_samples = train_ys.shape().num_elements();
519    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));
520
521    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
522    let mut optim = AdamConfig::new().init();
523
524    let mut best_loss = T::infinity();
525    let mut best_epoch = 0;
526    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
527
528    for epoch in 1..=num_epochs {
529        #[expect(clippy::option_if_let_else)]
530        let (train_xs_batches, train_ys_batches) = match num_batches {
531            Some(num_batches) => {
532                let shuffle = Tensor::<B, 1, Float>::random(
533                    [num_samples],
534                    Distribution::Uniform(0.0, 1.0),
535                    device,
536                );
537                let shuffle_indices = shuffle.argsort(0);
538
539                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
540                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);
541
542                (
543                    train_xs_shuffled.chunk(num_batches, 0),
544                    train_ys_shuffled.chunk(num_batches, 0),
545                )
546            }
547            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
548        };
549
550        let mut loss_sum = T::ZERO;
551
552        let mut se_sum = T::ZERO;
553        let mut ae_sum = T::ZERO;
554        let mut l_inf = T::ZERO;
555
556        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
557            let prediction = model.forward(train_xs_batch);
558            let loss =
559                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);
560
561            let grads = GradientsParams::from_grads(loss.backward(), &model);
562            model = optim.step(learning_rate.0, model, grads);
563
564            loss_sum += loss.into_scalar();
565
566            let err = prediction - train_ys_batch;
567
568            se_sum += (err.clone() * err.clone()).sum().into_scalar();
569            ae_sum += err.clone().abs().sum().into_scalar();
570            l_inf = l_inf.max(err.abs().max().into_scalar());
571        }
572
573        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));
574
575        if loss_mean < best_loss {
576            best_loss = loss_mean;
577            best_epoch = epoch;
578            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
579        }
580
581        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
582        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
583        let l_inf = stdv * l_inf;
584
585        log::info!(
586            "[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}"
587        );
588    }
589
590    if best_epoch != num_epochs {
591        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));
592
593        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
594    }
595
596    model
597}
598
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Encodes `data` with a minimal network (one Fourier feature, one block,
    /// ten epochs, seed 42). It then decodes into a zero-filled array of the
    /// same shape and returns the encoded bytes.
    fn roundtrip_small_f32<D: Dimension>(
        data: Array<f32, D>,
        mini_batch_size: Option<NonZeroUsize>,
    ) -> Array<u8, Ix1> {
        std::mem::drop(simple_logger::init());

        let shape = data.raw_dim();

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            data,
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            mini_batch_size,
            42,
        )
        .unwrap();

        let mut decoded = Array::<f32, D>::zeros(shape);
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded.view(),
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();

        encoded
    }

    #[test]
    fn empty() {
        let encoded = roundtrip_small_f32(Array::<f32, _>::zeros((0,)), None);
        assert!(encoded.is_empty());
    }

    #[test]
    fn ones() {
        roundtrip_small_f32(Array::<f32, _>::zeros((1, 1, 1, 1)), None);
    }

    #[test]
    fn r#const() {
        roundtrip_small_f32(Array::<f32, _>::from_elem((2, 1, 3), 42.0), None);
    }

    #[test]
    fn const_batched() {
        roundtrip_small_f32(
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            Some(NonZeroUsize::MIN.saturating_add(1)),
        );
    }

    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        let mini_batch_sizes = [
            None,                    // no mini-batching
            Some(NonZeroUsize::MIN), // stochastic
            NonZeroUsize::new(7),    // mini-batched, remainder
            NonZeroUsize::new(10),   // mini-batched
            NonZeroUsize::new(1001), // mini-batched, truncated
        ];

        for mini_batch_size in mini_batch_sizes {
            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            let mut decoded = Array::<f64, _>::zeros(data.shape());
            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}