1#![expect(clippy::multiple_crate_versions)]
21
22use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};
23
24use burn::{
25    backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
26    module::{Module, Param},
27    nn::loss::{MseLoss, Reduction},
28    optim::{AdamConfig, GradientsParams, Optimizer},
29    prelude::Backend,
30    record::{
31        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
32        Record, Recorder, RecorderError,
33    },
34    tensor::{
35        Distribution, Element as BurnElement, Float, Tensor, TensorData, backend::AutodiffBackend,
36    },
37};
38use itertools::Itertools;
39use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
40use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
41use numcodecs::{
42    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
43    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
44};
45use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
46use serde::{Deserialize, Deserializer, Serialize, Serializer};
47use thiserror::Error;
48
49use ::bytemuck as _;
51
52#[cfg(test)]
53use ::serde_json as _;
54
55mod modules;
56
57use modules::{Model, ModelConfig, ModelExtra, ModelRecord};
58
59type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;
60
/// Fourier feature neural network codec: on encoding, a small network is
/// trained to predict the data values from their (normalised) grid
/// coordinates and the trained model is stored; decoding evaluates the model.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FourierNetworkCodec {
    /// The number of Fourier features that the data coordinates are projected to
    pub fourier_features: NonZeroUsize,
    /// The standard deviation of the normal distribution from which the random
    /// Fourier projection matrix is sampled
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the network model
    pub num_blocks: NonZeroUsize,
    /// The learning rate used by the optimiser during training
    pub learning_rate: Positive<f64>,
    /// The number of training epochs
    pub num_epochs: usize,
    /// The optional mini-batch size used during training; `None` trains on
    /// the full batch. The field must be provided explicitly (possibly as
    /// `null`) — it is not silently defaulted when missing.
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during encoding
    pub seed: u64,
    /// The codec's encoding format version. Do not provide this parameter explicitly.
    #[serde(default, rename = "_version")]
    pub version: FourierNetworkCodecVersion,
}
94
95fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
97    deserializer: D,
98) -> Result<Option<T>, D::Error> {
99    Option::<T>::deserialize(deserializer)
100}
101
impl Codec for FourierNetworkCodec {
    type Error = FourierNetworkCodecError;

    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        // only f32/f64 data is supported; training always runs on the CPU
        // `NdArray` backend, wrapped in `Autodiff` for backpropagation
        match data {
            AnyCowArray::F32(data) => Ok(AnyArray::U8(
                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            AnyCowArray::F64(data) => Ok(AnyArray::U8(
                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        // decoding needs the shape and dtype of the output array, which only
        // `decode_into` provides, so plain `decode` always fails
        Err(FourierNetworkCodecError::MissingDecodingOutput)
    }

    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        // the encoded data must be a one-dimensional byte array ...
        let AnyArrayView::U8(encoded) = encoded else {
            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
                dtype: encoded.dtype(),
            });
        };

        // ... which is checked in two steps: dtype above, dimensionality here
        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        };

        // decoding, like encoding, only supports f32/f64 on the CPU backend
        // (no autodiff needed since no training happens here)
        match decoded {
            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
        }
    }
}
179
impl StaticCodec for FourierNetworkCodec {
    const CODEC_ID: &'static str = "fourier-network.rs";

    // the codec is its own (de)serialisable configuration
    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::from(self)
    }
}
193
/// A positive floating point number (newtype enforcing `x > 0`)
// `Eq` cannot be derived since `T` is a float, hence the expectation
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct Positive<T: FloatTrait>(T);
198
impl Serialize for Positive<f64> {
    /// Serialises transparently as a plain `f64`, without a wrapper
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_f64(self.0)
    }
}
204
205impl<'de> Deserialize<'de> for Positive<f64> {
206    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
207        let x = f64::deserialize(deserializer)?;
208
209        if x > 0.0 {
210            Ok(Self(x))
211        } else {
212            Err(serde::de::Error::invalid_value(
213                serde::de::Unexpected::Float(x),
214                &"a positive value",
215            ))
216        }
217    }
218}
219
impl JsonSchema for Positive<f64> {
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("PositiveF64")
    }

    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
    }

    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        // matches the `Deserialize` impl, which rejects non-positive values
        json_schema!({
            "type": "number",
            "exclusiveMinimum": 0.0
        })
    }
}
236
/// Errors that may occur when applying the [`FourierNetworkCodec`].
#[derive(Debug, Error)]
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] only supports finite floating point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed during a neural network computation
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The underlying neural network error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] cannot decode without the output array
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode byte arrays
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}
283
/// Opaque error wrapping a failure to record or load the neural network
/// model bytes.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct NeuralNetworkError(RecorderError);
288
/// Floating point types supported by the codec's neural network.
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    /// The burn precision settings used to (de)serialise model records of
    /// this float type
    type Precision: PrecisionSettings;

    /// Converts from a `usize`; may lose precision for large values
    fn from_usize(x: usize) -> Self;
}
299
impl FloatExt for f32 {
    // f32 models are stored with full (single) precision
    type Precision = FullPrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
308
impl FloatExt for f64 {
    // f64 models are stored with double precision
    type Precision = DoublePrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
317
/// Encodes the `data` by training a Fourier feature neural network on it and
/// serialising the trained model, together with the random projection matrix
/// and the data's mean and standard deviation, into a byte array.
///
/// Empty input data encodes to an empty byte array.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::NonFiniteData`] if `data` contains any
///   non-finite (infinite or NaN) element
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if serialising the
///   trained model fails
#[expect(clippy::similar_names)]
#[expect(clippy::missing_panics_doc)]
#[expect(clippy::too_many_arguments)]
pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    data: ArrayBase<S, D>,
    fourier_features: NonZeroUsize,
    fourier_scale: Positive<f64>,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    seed: u64,
) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
    // `mean()` is `None` only for empty data, which encodes to empty bytes
    let Some(mean) = data.mean() else {
        return Ok(Array::from_vec(Vec::new()));
    };
    // guard against zero stdv (constant data) to avoid dividing by zero in
    // the standardisation below
    let stdv = data.std(T::ZERO);
    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };

    if !Zip::from(&data).all(|x| x.is_finite()) {
        return Err(FourierNetworkCodecError::NonFiniteData);
    }

    // seed the backend so that training is reproducible
    B::seed(seed);

    // the random (pre-transposed) Fourier feature projection matrix
    let b_t = Tensor::<B, 2, Float>::random(
        [data.ndim(), fourier_features.get()],
        Distribution::Normal(0.0, fourier_scale.0),
        device,
    );

    // training inputs: the Fourier-mapped normalised grid coordinates
    let train_xs = flat_grid_like(&data, device);
    let train_xs = fourier_mapping(train_xs, b_t.clone());

    // training targets: the standardised data values, one sample per row
    let train_ys_shape = [data.len(), 1];
    let mut train_ys = data.into_owned();
    train_ys.mapv_inplace(|x| (x - mean) / stdv);
    // reshaping to [len, 1] preserves the element count, so it cannot fail
    #[expect(clippy::unwrap_used)]
    let train_ys = train_ys
        .into_shape_clone((train_ys_shape, Order::RowMajor))
        .unwrap();
    let train_ys = Tensor::from_data(
        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
        device,
    );

    let model = train(
        device,
        &train_xs,
        &train_ys,
        fourier_features,
        num_blocks,
        learning_rate,
        num_epochs,
        mini_batch_size,
        stdv,
    );

    // bundle the trained model with everything required for decoding:
    // the projection matrix and the (de)standardisation statistics
    let extra = ModelExtra {
        model: model.into_record(),
        b_t: Param::from_tensor(b_t).set_require_grad(false),
        mean: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![mean], vec![1]),
            device,
        ))
        .set_require_grad(false),
        stdv: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![stdv], vec![1]),
            device,
        ))
        .set_require_grad(false),
        version: StaticCodecVersion,
    };

    // serialise the model record into a flat byte buffer
    let recorder = BinBytesRecorder::<T::Precision>::new();
    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;

    Ok(Array::from_vec(encoded))
}
417
/// Decodes the `encoded` model bytes into the provided `decoded` output array
/// by evaluating the deserialised model on the output's coordinate grid and
/// de-standardising the predictions.
///
/// `fourier_features` and `num_blocks` should match the values used during
/// encoding, since the loaded record is applied to a model built from them.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if `encoded` is
///   empty but `decoded` is not
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if loading the model
///   record from the bytes fails
#[expect(clippy::missing_panics_doc)]
pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
    device: &B::Device,
    encoded: ArrayBase<S, Ix1>,
    mut decoded: ArrayViewMut<T, D>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
) -> Result<(), FourierNetworkCodecError> {
    // empty bytes are produced only for empty data (see `encode`), so they
    // can only be decoded into an empty output array
    if encoded.is_empty() {
        if decoded.is_empty() {
            return Ok(());
        }

        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::ShapeMismatch {
                src: encoded.shape().to_vec(),
                dst: decoded.shape().to_vec(),
            },
        });
    }

    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;

    // deserialise the trained model together with `b_t`, `mean`, and `stdv`
    let recorder = BinBytesRecorder::<T::Precision>::new();
    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;

    let model = ModelConfig::new(fourier_features, num_blocks)
        .init(device)
        .load_record(record.model);
    let b_t = record.b_t.into_value();
    let mean = record.mean.into_value().into_scalar();
    let stdv = record.stdv.into_value().into_scalar();

    // evaluate the model on the Fourier-mapped coordinate grid of the output
    let test_xs = flat_grid_like(&decoded, device);
    let test_xs = fourier_mapping(test_xs, b_t);

    let prediction = model.forward(test_xs).into_data();
    // NOTE(review): assumes the prediction tensor data is contiguous and of
    // the matching element type — the reviewed unwrap documents this
    #[expect(clippy::unwrap_used)]
    let prediction = prediction.as_slice().unwrap();

    // the grid has one row per output element, so the prediction is expected
    // to have exactly `decoded.len()` elements and the reshape succeeds
    #[expect(clippy::unwrap_used)]
    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
    // undo the standardisation that was applied during encoding
    decoded.mapv_inplace(|x| (x * stdv) + mean);

    Ok(())
}
477
478fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
479    a: &ArrayBase<S, D>,
480    device: &B::Device,
481) -> Tensor<B, 2, Float> {
482    let grid = a
483        .shape()
484        .iter()
485        .copied()
486        .map(|s| {
487            #[expect(clippy::useless_conversion)] (0..s)
489                .into_iter()
490                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
491        })
492        .multi_cartesian_product()
493        .flatten()
494        .collect::<Vec<_>>();
495
496    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
497}
498
499fn fourier_mapping<B: Backend>(
500    xs: Tensor<B, 2, Float>,
501    b_t: Tensor<B, 2, Float>,
502) -> Tensor<B, 2, Float> {
503    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);
504
505    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
506}
507
/// Trains a fresh model on the `train_xs` inputs and `train_ys` targets and
/// returns the model checkpoint with the lowest per-epoch mean loss.
///
/// `stdv` is only used to report the error metrics in the data's original
/// units; it does not affect the optimisation itself.
#[expect(clippy::similar_names)]
#[expect(clippy::too_many_arguments)]
fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    train_xs: &Tensor<B, 2, Float>,
    train_ys: &Tensor<B, 2, Float>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    stdv: T,
) -> Model<B> {
    let num_samples = train_ys.shape().num_elements();
    // `None` means full-batch training, i.e. a single batch per epoch
    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));

    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
    let mut optim = AdamConfig::new().init();

    // track the best (lowest mean loss) model checkpoint across all epochs
    let mut best_loss = T::infinity();
    let mut best_epoch = 0;
    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();

    for epoch in 1..=num_epochs {
        #[expect(clippy::option_if_let_else)]
        let (train_xs_batches, train_ys_batches) = match num_batches {
            Some(num_batches) => {
                // shuffle the samples each epoch by sorting random keys ...
                let shuffle = Tensor::<B, 1, Float>::random(
                    [num_samples],
                    Distribution::Uniform(0.0, 1.0),
                    device,
                );
                let shuffle_indices = shuffle.argsort(0);

                // ... and applying the same permutation to inputs and targets
                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);

                (
                    train_xs_shuffled.chunk(num_batches, 0),
                    train_ys_shuffled.chunk(num_batches, 0),
                )
            }
            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
        };

        let mut loss_sum = T::ZERO;

        // accumulators for the squared / absolute / maximum error metrics
        let mut se_sum = T::ZERO;
        let mut ae_sum = T::ZERO;
        let mut l_inf = T::ZERO;

        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
            let prediction = model.forward(train_xs_batch);
            let loss =
                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);

            // backpropagate and take one optimiser step per (mini) batch
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(learning_rate.0, model, grads);

            loss_sum += loss.into_scalar();

            let err = prediction - train_ys_batch;

            se_sum += (err.clone() * err.clone()).sum().into_scalar();
            ae_sum += err.clone().abs().sum().into_scalar();
            l_inf = l_inf.max(err.abs().max().into_scalar());
        }

        // mean of the per-batch mean losses for this epoch
        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));

        if loss_mean < best_loss {
            best_loss = loss_mean;
            best_epoch = epoch;
            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
        }

        // scale by `stdv` to report the metrics in the data's original units
        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
        let l_inf = stdv * l_inf;

        log::info!(
            "[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}"
        );
    }

    // restore the best checkpoint unless it is the one from the final epoch
    if best_epoch != num_epochs {
        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));

        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
    }

    model
}
601
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// An empty array must round-trip through empty encoded bytes.
    #[test]
    fn empty() {
        // ignore the error if the logger is already initialised
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((0,)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        assert!(encoded.is_empty());
        let mut decoded = Array::<f32, _>::zeros((0,));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// A single-element four-dimensional array round-trips.
    #[test]
    fn ones() {
        // ignore the error if the logger is already initialised
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((1, 1, 1, 1)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// Constant data exercises the zero-standard-deviation guard in `encode`.
    #[test]
    fn r#const() {
        // ignore the error if the logger is already initialised
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// Constant data as above, but trained with mini batches of size two.
    #[test]
    fn const_batched() {
        // ignore the error if the logger is already initialised
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            Some(NonZeroUsize::MIN.saturating_add(1)),
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// A linear ramp of 100 samples round-trips for a variety of mini-batch
    /// size configurations.
    #[test]
    fn linspace() {
        // ignore the error if the logger is already initialised
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        for mini_batch_size in [
            None,                                         // full-batch training
            Some(NonZeroUsize::MIN),                      // batch size 1
            Some(NonZeroUsize::MIN.saturating_add(6)),    // uneven batches of 7
            Some(NonZeroUsize::MIN.saturating_add(9)),    // even batches of 10
            Some(NonZeroUsize::MIN.saturating_add(1000)), // one oversized batch
        ] {
            let mut decoded = Array::<f64, _>::zeros(data.shape());
            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}