#![expect(clippy::multiple_crate_versions)]
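//! Fourier feature neural network codec implementation for the [`numcodecs`] API.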

use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};

use burn::{
    backend::{ndarray::NdArrayDevice, Autodiff, NdArray},
    module::{Module, Param},
    nn::loss::{MseLoss, Reduction},
    optim::{AdamConfig, GradientsParams, Optimizer},
    prelude::Backend,
    record::{
        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
        Record, Recorder, RecorderError,
    },
    tensor::{
        backend::AutodiffBackend, Distribution, Element as BurnElement, Float, Tensor, TensorData,
    },
};
use itertools::Itertools;
use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
use numcodecs::{
    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
    Codec, StaticCodec, StaticCodecConfig,
};
use schemars::{json_schema, JsonSchema, Schema, SchemaGenerator};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;

#[cfg(test)]
use ::serde_json as _;

use ::{bincode as _, bincode_derive as _};

mod modules;

use modules::{Model, ModelConfig, ModelExtra, ModelRecord};

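/// Fourier feature neural network codec.
///
/// Encoding trains a network to overfit the (standardized) data as a function
/// of the normalized grid coordinates of each data element and stores the
/// trained parameters as bytes; decoding evaluates the stored network on the
/// same coordinate grid to reconstruct an approximation of the data. The codec
/// is thus lossy and requires the output array to be provided during decoding
/// (see [`FourierNetworkCodecError::MissingDecodingOutput`]).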
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FourierNetworkCodec {
    /// The number of Fourier features that the data coordinates are projected onto
    pub fourier_features: NonZeroUsize,
    /// The standard deviation of the normal distribution from which the
    /// Fourier projection matrix is sampled
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the network
    pub num_blocks: NonZeroUsize,
    /// The learning rate during training
    pub learning_rate: Positive<f64>,
    /// The number of epochs for which the network is trained
    pub num_epochs: usize,
    /// The optional mini-batch size used during training; `None` trains on the
    /// full data in a single batch
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during training
    pub seed: u64,
}

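/// Deserializes an `Option` field while still requiring it to be present:
/// with a custom `deserialize_with` function, serde no longer defaults a
/// missing field to `None`, so `mini_batch_size` must be provided explicitly,
/// possibly as `null`.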
fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Option<T>, D::Error> {
    Option::<T>::deserialize(deserializer)
}

impl Codec for FourierNetworkCodec {
    type Error = FourierNetworkCodecError;

    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match data {
            AnyCowArray::F32(data) => Ok(AnyArray::U8(
                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            AnyCowArray::F64(data) => Ok(AnyArray::U8(
                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        Err(FourierNetworkCodecError::MissingDecodingOutput)
    }

    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let AnyArrayView::U8(encoded) = encoded else {
            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
                dtype: encoded.dtype(),
            });
        };

        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        };

        match decoded {
            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
        }
    }
}

impl StaticCodec for FourierNetworkCodec {
    const CODEC_ID: &'static str = "fourier-network";

    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}

/// Floating-point number `x` that is guaranteed to be strictly positive, i.e. `0 < x`.
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct Positive<T: FloatTrait>(T);

impl Serialize for Positive<f64> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_f64(self.0)
    }
}

impl<'de> Deserialize<'de> for Positive<f64> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let x = f64::deserialize(deserializer)?;

        if x > 0.0 {
            Ok(Self(x))
        } else {
            Err(serde::de::Error::invalid_value(
                serde::de::Unexpected::Float(x),
                &"a positive value",
            ))
        }
    }
}

impl JsonSchema for Positive<f64> {
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("PositiveF64")
    }

    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
    }

    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "exclusiveMinimum": 0.0
        })
    }
}

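/// Errors that may occur when applying the [`FourierNetworkCodec`].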
#[derive(Debug, Error)]
pub enum FourierNetworkCodecError {
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        #[from]
        source: NeuralNetworkError,
    },
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        dtype: AnyArrayDType,
    },
    #[error("FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}")]
    EncodedDataNotOneDimensional {
        shape: Vec<usize>,
    },
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        #[from]
        source: AnyArrayAssignError,
    },
}

#[derive(Debug, Error)]
#[error(transparent)]
pub struct NeuralNetworkError(RecorderError);

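/// Floating-point element type that can be used both for the data and for the
/// network parameters, combining the `burn` element trait, the `num_traits`
/// float traits, and the precision settings used to (de)serialize the model.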
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    type Precision: PrecisionSettings;

    fn from_usize(x: usize) -> Self;
}

impl FloatExt for f32 {
    type Precision = FullPrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}

impl FloatExt for f64 {
    type Precision = DoublePrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}

/// Trains a fourier feature network to overfit the `data` and serializes the
/// trained model parameters into a one-dimensional byte array.
///
/// The data is standardized to zero mean and unit standard deviation before
/// training; the mean and standard deviation are stored alongside the model so
/// that decoding can undo the normalization. Empty input produces an empty
/// byte array.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::NonFiniteData`] if the data contains
///   non-finite (infinite or NaN) values
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if recording the trained
///   model fails
#[expect(clippy::similar_names)]
#[expect(clippy::missing_panics_doc)]
#[expect(clippy::too_many_arguments)]
pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    data: ArrayBase<S, D>,
    fourier_features: NonZeroUsize,
    fourier_scale: Positive<f64>,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    seed: u64,
) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
    let Some(mean) = data.mean() else {
        return Ok(Array::from_vec(Vec::new()));
    };
    let stdv = data.std(T::ZERO);
    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };

    if !Zip::from(&data).all(|x| x.is_finite()) {
        return Err(FourierNetworkCodecError::NonFiniteData);
    }

    B::seed(seed);

    let b_t = Tensor::<B, 2, Float>::random(
        [data.ndim(), fourier_features.get()],
        Distribution::Normal(0.0, fourier_scale.0),
        device,
    );

    let train_xs = flat_grid_like(&data, device);
    let train_xs = fourier_mapping(train_xs, b_t.clone());

    let train_ys_shape = [data.len(), 1];
    let mut train_ys = data.into_owned();
    train_ys.mapv_inplace(|x| (x - mean) / stdv);
    #[expect(clippy::unwrap_used)]
    let train_ys = train_ys
        .into_shape_clone((train_ys_shape, Order::RowMajor))
        .unwrap();
    let train_ys = Tensor::from_data(
        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
        device,
    );

    let model = train(
        device,
        &train_xs,
        &train_ys,
        fourier_features,
        num_blocks,
        learning_rate,
        num_epochs,
        mini_batch_size,
        stdv,
    );

    let extra = ModelExtra {
        model,
        b_t: Param::from_tensor(b_t).set_require_grad(false),
        mean: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![mean], vec![1]),
            device,
        ))
        .set_require_grad(false),
        stdv: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![stdv], vec![1]),
            device,
        ))
        .set_require_grad(false),
    };

    let recorder = BinBytesRecorder::<T::Precision>::new();
    let encoded = recorder
        .record(extra.into_record(), ())
        .map_err(NeuralNetworkError)?;

    Ok(Array::from_vec(encoded))
}

/// Loads a trained model from the `encoded` bytes and evaluates it on the
/// coordinate grid of the `decoded` array, writing the de-standardized
/// predictions into `decoded`.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if `encoded` is
///   empty but `decoded` is not
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if loading the model
///   fails
#[expect(clippy::missing_panics_doc)]
pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
    device: &B::Device,
    encoded: ArrayBase<S, Ix1>,
    mut decoded: ArrayViewMut<T, D>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
) -> Result<(), FourierNetworkCodecError> {
    if encoded.is_empty() {
        if decoded.is_empty() {
            return Ok(());
        }

        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::ShapeMismatch {
                src: encoded.shape().to_vec(),
                dst: decoded.shape().to_vec(),
            },
        });
    }

    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;

    let recorder = BinBytesRecorder::<T::Precision>::new();
    let record = recorder.load(encoded, device).map_err(NeuralNetworkError)?;

    let extra = ModelExtra::<B> {
        model: ModelConfig::new(fourier_features, num_blocks).init(device),
        b_t: Param::from_tensor(Tensor::zeros(
            [decoded.ndim(), fourier_features.get()],
            device,
        ))
        .set_require_grad(false),
        mean: Param::from_tensor(Tensor::zeros([1], device)).set_require_grad(false),
        stdv: Param::from_tensor(Tensor::ones([1], device)).set_require_grad(false),
    }
    .load_record(record);

    let model = extra.model;
    let b_t = extra.b_t.into_value();
    let mean = extra.mean.into_value().into_scalar();
    let stdv = extra.stdv.into_value().into_scalar();

    let test_xs = flat_grid_like(&decoded, device);
    let test_xs = fourier_mapping(test_xs, b_t);

    let prediction = model.forward(test_xs).into_data();
    #[expect(clippy::unwrap_used)]
    let prediction = prediction.as_slice().unwrap();

    #[expect(clippy::unwrap_used)]
    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
    decoded.mapv_inplace(|x| (x * stdv) + mean);

    Ok(())
}

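/// Constructs the flattened coordinate grid for the array `a`: each element's
/// index is normalized per axis to `[0, 1)` (index divided by the axis
/// length), yielding a tensor of shape `[a.len(), a.ndim()]` with one row of
/// coordinates per element.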
fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
    a: &ArrayBase<S, D>,
    device: &B::Device,
) -> Tensor<B, 2, Float> {
    let grid = a
        .shape()
        .iter()
        .copied()
        .map(|s| {
            #[expect(clippy::useless_conversion)]
            (0..s)
                .into_iter()
                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
        })
        .multi_cartesian_product()
        .flatten()
        .collect::<Vec<_>>();

    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
}

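/// Maps the coordinates `xs` into fourier feature space using the projection
/// matrix `b_t`, returning `[sin(tau * xs * b_t), cos(tau * xs * b_t)]`
/// concatenated along the feature dimension.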
fn fourier_mapping<B: Backend>(
    xs: Tensor<B, 2, Float>,
    b_t: Tensor<B, 2, Float>,
) -> Tensor<B, 2, Float> {
    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);

    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
}

/// Trains the fourier feature network on the mapped coordinates `train_xs`
/// and the standardized targets `train_ys` using the Adam optimizer and an
/// MSE loss, optionally in randomly shuffled mini-batches. Per-epoch loss and
/// error statistics are logged, and the model state from the epoch with the
/// lowest mean loss is restored before returning.
#[expect(clippy::similar_names)]
#[expect(clippy::too_many_arguments)]
fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    train_xs: &Tensor<B, 2, Float>,
    train_ys: &Tensor<B, 2, Float>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    stdv: T,
) -> Model<B> {
    let num_samples = train_ys.shape().num_elements();
    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));

    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
    let mut optim = AdamConfig::new().init();

    let mut best_loss = T::infinity();
    let mut best_epoch = 0;
    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();

    for epoch in 1..=num_epochs {
        #[expect(clippy::option_if_let_else)]
        let (train_xs_batches, train_ys_batches) = match num_batches {
            Some(num_batches) => {
                let shuffle = Tensor::<B, 1, Float>::random(
                    [num_samples],
                    Distribution::Uniform(0.0, 1.0),
                    device,
                );
                let shuffle_indices = shuffle.argsort(0);

                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);

                (
                    train_xs_shuffled.chunk(num_batches, 0),
                    train_ys_shuffled.chunk(num_batches, 0),
                )
            }
            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
        };

        let mut loss_sum = T::ZERO;

        let mut se_sum = T::ZERO;
        let mut ae_sum = T::ZERO;
        let mut l_inf = T::ZERO;

        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
            let prediction = model.forward(train_xs_batch);
            let loss =
                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);

            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(learning_rate.0, model, grads);

            loss_sum += loss.into_scalar();

            let err = prediction - train_ys_batch;

            se_sum += (err.clone() * err.clone()).sum().into_scalar();
            ae_sum += err.clone().abs().sum().into_scalar();
            l_inf = l_inf.max(err.abs().max().into_scalar());
        }

        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));

        if loss_mean < best_loss {
            best_loss = loss_mean;
            best_epoch = epoch;
            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
        }

        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
        let l_inf = stdv * l_inf;

        log::info!("[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}");
    }

    if best_epoch != num_epochs {
        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));

        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
    }

    model
}

#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn empty() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((0,)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        assert!(encoded.is_empty());
        let mut decoded = Array::<f32, _>::zeros((0,));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn ones() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((1, 1, 1, 1)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn r#const() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn const_batched() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            Some(NonZeroUsize::MIN.saturating_add(1)),
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        for mini_batch_size in [
            None,
            Some(NonZeroUsize::MIN),
            Some(NonZeroUsize::MIN.saturating_add(6)),
            Some(NonZeroUsize::MIN.saturating_add(9)),
            Some(NonZeroUsize::MIN.saturating_add(1000)),
        ] {
            let mut decoded = Array::<f64, _>::zeros(data.shape());
            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}