1#![allow(clippy::multiple_crate_versions)]
21
22use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};
23
24use burn::{
25 backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
26 module::{Module, Param},
27 nn::loss::{MseLoss, Reduction},
28 optim::{AdamConfig, GradientsParams, Optimizer},
29 prelude::Backend,
30 record::{
31 BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
32 Record, Recorder, RecorderError,
33 },
34 tensor::{
35 Distribution, Element as BurnElement, Float, Tensor, TensorData, backend::AutodiffBackend,
36 },
37};
38use itertools::Itertools;
39use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
40use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
41use numcodecs::{
42 AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
43 Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
44};
45use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
46use serde::{Deserialize, Deserializer, Serialize, Serializer};
47use thiserror::Error;
48
49use ::bytemuck as _;
51
52use ::wasm_bindgen as _;
55
56#[cfg(test)]
57use ::serde_json as _;
58
59mod modules;
60
61use modules::{Model, ModelConfig, ModelExtra, ModelRecord};
62
/// Semantic version marker for the [`FourierNetworkCodec`]'s format.
type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;
64
/// Codec that overfits a fourier-feature neural network on the data during
/// encoding; decoding evaluates the stored network to reconstruct the data.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FourierNetworkCodec {
    /// The number of fourier features that the data coordinates are mapped to
    pub fourier_features: NonZeroUsize,
    /// The standard deviation of the Gaussian distribution from which the
    /// fourier-feature projection matrix is sampled
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the model
    pub num_blocks: NonZeroUsize,
    /// The learning rate passed to the Adam optimizer during training
    pub learning_rate: Positive<f64>,
    /// The number of training epochs
    pub num_epochs: usize,
    /// The optional mini-batch size; `None` trains on all samples at once.
    /// The field must be given explicitly (even as `null`), see
    /// `deserialize_required_option`.
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during encoding
    pub seed: u64,
    /// The codec's format version; filled in by default and renamed to
    /// `_version` in the serialized config
    #[serde(default, rename = "_version")]
    pub version: FourierNetworkCodecVersion,
}
98
99fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
101 deserializer: D,
102) -> Result<Option<T>, D::Error> {
103 Option::<T>::deserialize(deserializer)
104}
105
impl Codec for FourierNetworkCodec {
    type Error = FourierNetworkCodecError;

    /// Encodes f32 or f64 data by training a network on it (on the CPU
    /// ndarray backend) and returning the serialized model as bytes.
    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match data {
            AnyCowArray::F32(data) => Ok(AnyArray::U8(
                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            AnyCowArray::F64(data) => Ok(AnyArray::U8(
                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            // only floating-point data can be encoded
            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    /// Standalone decoding is unsupported: the model must be evaluated on the
    /// output array's coordinate grid, so only `decode_into` can decode.
    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        Err(FourierNetworkCodecError::MissingDecodingOutput)
    }

    /// Decodes the byte-encoded model into the provided output array.
    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        // the encoded model must be a byte array ...
        let AnyArrayView::U8(encoded) = encoded else {
            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
                dtype: encoded.dtype(),
            });
        };

        // ... and it must be one-dimensional
        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        };

        match decoded {
            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            // only floating-point output arrays are supported
            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
        }
    }
}
183
impl StaticCodec for FourierNetworkCodec {
    const CODEC_ID: &'static str = "fourier-network.rs";

    // the codec is its own configuration type
    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::from(self)
    }
}
197
/// Floating-point number wrapper that guarantees the value is strictly
/// positive (enforced by the `Deserialize` impl).
#[expect(clippy::derive_partial_eq_without_eq)] // floats cannot implement Eq
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct Positive<T: FloatTrait>(T);
202
impl Serialize for Positive<f64> {
    /// Serializes the wrapped value as a plain `f64`.
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_f64(self.0)
    }
}
208
209impl<'de> Deserialize<'de> for Positive<f64> {
210 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
211 let x = f64::deserialize(deserializer)?;
212
213 if x > 0.0 {
214 Ok(Self(x))
215 } else {
216 Err(serde::de::Error::invalid_value(
217 serde::de::Unexpected::Float(x),
218 &"a positive value",
219 ))
220 }
221 }
222}
223
impl JsonSchema for Positive<f64> {
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("PositiveF64")
    }

    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
    }

    /// A JSON number with an exclusive minimum of zero, matching the
    /// validation performed by the `Deserialize` impl.
    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "exclusiveMinimum": 0.0
        })
    }
}
240
/// Errors that may occur while applying the [`FourierNetworkCodec`].
#[derive(Debug, Error)]
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN)
    /// floating point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed during a neural network computation
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The underlying neural network error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] must be provided the output array during
    /// decoding, i.e. only `decode_into` is supported
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode byte arrays
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The underlying assignment error
        #[from]
        source: AnyArrayAssignError,
    },
}
287
/// Opaque wrapper around a [`RecorderError`] raised while recording or
/// loading the neural network model.
#[derive(Debug, Error)]
#[error(transparent)]
pub struct NeuralNetworkError(RecorderError);
292
/// Floating-point extension trait bundling the numeric capabilities this
/// codec requires of its element type (implemented for `f32` and `f64`).
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    /// The recorder precision used to (de)serialize model weights for this
    /// element type
    type Precision: PrecisionSettings;

    /// Converts `x` to this float type (precision may be lost)
    fn from_usize(x: usize) -> Self;
}
303
impl FloatExt for f32 {
    // f32 weights are recorded at full (single) precision
    type Precision = FullPrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
312
impl FloatExt for f64 {
    // f64 weights are recorded at double precision
    type Precision = DoublePrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}
321
/// Trains a fourier-feature network to overfit `data` and returns the
/// serialized model, together with the normalization statistics needed for
/// decoding, as a byte array. Empty data encodes to an empty byte array.
///
/// # Errors
///
/// Errors with [`FourierNetworkCodecError::NonFiniteData`] if the data
/// contains non-finite (infinite or NaN) values, and with
/// [`FourierNetworkCodecError::NeuralNetworkError`] if recording the trained
/// model fails.
#[expect(clippy::similar_names)]
#[expect(clippy::missing_panics_doc)]
#[expect(clippy::too_many_arguments)]
pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    data: ArrayBase<S, D>,
    fourier_features: NonZeroUsize,
    fourier_scale: Positive<f64>,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    seed: u64,
) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
    // an empty input has no mean and is encoded as an empty byte array
    let Some(mean) = data.mean() else {
        return Ok(Array::from_vec(Vec::new()));
    };
    // population standard deviation (ddof = 0); constant data would produce a
    // zero stdv, which is replaced by one so the normalization stays finite
    let stdv = data.std(T::ZERO);
    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };

    if !Zip::from(&data).all(|x| x.is_finite()) {
        return Err(FourierNetworkCodecError::NonFiniteData);
    }

    // seed the backend so that encoding is deterministic
    B::seed(seed);

    // random Gaussian projection matrix of shape [ndim, fourier_features]
    // for the fourier feature mapping
    let b_t = Tensor::<B, 2, Float>::random(
        [data.ndim(), fourier_features.get()],
        Distribution::Normal(0.0, fourier_scale.0),
        device,
    );

    // training inputs: the fourier-mapped, normalized coordinate grid
    let train_xs = flat_grid_like(&data, device);
    let train_xs = fourier_mapping(train_xs, b_t.clone());

    // training targets: the data itself, normalized to zero mean / unit stdv
    // and reshaped into a [len, 1] column
    let train_ys_shape = [data.len(), 1];
    let mut train_ys = data.into_owned();
    train_ys.mapv_inplace(|x| (x - mean) / stdv);
    #[expect(clippy::unwrap_used)] // reshaping [len] data to [len, 1] cannot fail
    let train_ys = train_ys
        .into_shape_clone((train_ys_shape, Order::RowMajor))
        .unwrap();
    let train_ys = Tensor::from_data(
        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
        device,
    );

    let model = train(
        device,
        &train_xs,
        &train_ys,
        fourier_features,
        num_blocks,
        learning_rate,
        num_epochs,
        mini_batch_size,
        stdv,
    );

    // bundle the trained model with everything decoding needs: the projection
    // matrix and the normalization statistics (as non-trainable parameters)
    let extra = ModelExtra {
        model: model.into_record(),
        b_t: Param::from_tensor(b_t).set_require_grad(false),
        mean: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![mean], vec![1]),
            device,
        ))
        .set_require_grad(false),
        stdv: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![stdv], vec![1]),
            device,
        ))
        .set_require_grad(false),
        version: StaticCodecVersion,
    };

    // serialize the full record into bytes
    let recorder = BinBytesRecorder::<T::Precision>::new();
    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;

    Ok(Array::from_vec(encoded))
}
421
/// Decodes the byte-serialized model `encoded` into the provided `decoded`
/// output array by evaluating the model on the output's coordinate grid and
/// un-normalizing the prediction.
///
/// `fourier_features` and `num_blocks` must match the values used during
/// encoding so that the record loads into a model of the same architecture.
///
/// # Errors
///
/// Errors with [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if the
/// encoded data is empty but the output array is not, and with
/// [`FourierNetworkCodecError::NeuralNetworkError`] if loading the model
/// record fails.
#[expect(clippy::missing_panics_doc)]
pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
    device: &B::Device,
    encoded: ArrayBase<S, Ix1>,
    mut decoded: ArrayViewMut<T, D>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
) -> Result<(), FourierNetworkCodecError> {
    // an empty encoding is only valid for an empty output (see `encode`)
    if encoded.is_empty() {
        if decoded.is_empty() {
            return Ok(());
        }

        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::ShapeMismatch {
                src: encoded.shape().to_vec(),
                dst: decoded.shape().to_vec(),
            },
        });
    }

    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;

    // deserialize the model record together with the projection matrix and
    // the normalization statistics stored by `encode`
    let recorder = BinBytesRecorder::<T::Precision>::new();
    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;

    let model = ModelConfig::new(fourier_features, num_blocks)
        .init(device)
        .load_record(record.model);
    let b_t = record.b_t.into_value();
    let mean = record.mean.into_value().into_scalar();
    let stdv = record.stdv.into_value().into_scalar();

    // evaluate the model on the fourier-mapped coordinate grid of the output
    let test_xs = flat_grid_like(&decoded, device);
    let test_xs = fourier_mapping(test_xs, b_t);

    let prediction = model.forward(test_xs).into_data();
    #[expect(clippy::unwrap_used)] // the prediction has element type T
    let prediction = prediction.as_slice().unwrap();

    // write the prediction into the output, then un-normalize it back to the
    // original data scale
    #[expect(clippy::unwrap_used)] // prediction length matches by construction
    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
    decoded.mapv_inplace(|x| (x * stdv) + mean);

    Ok(())
}
481
482fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
483 a: &ArrayBase<S, D>,
484 device: &B::Device,
485) -> Tensor<B, 2, Float> {
486 let grid = a
487 .shape()
488 .iter()
489 .copied()
490 .map(|s| {
491 #[expect(clippy::useless_conversion)] (0..s)
493 .into_iter()
494 .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
495 })
496 .multi_cartesian_product()
497 .flatten()
498 .collect::<Vec<_>>();
499
500 Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
501}
502
503fn fourier_mapping<B: Backend>(
504 xs: Tensor<B, 2, Float>,
505 b_t: Tensor<B, 2, Float>,
506) -> Tensor<B, 2, Float> {
507 let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);
508
509 Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
510}
511
/// Trains a fresh model on the given inputs and targets with the Adam
/// optimizer, and returns the checkpoint from the epoch with the lowest mean
/// training loss.
#[expect(clippy::similar_names)]
#[expect(clippy::too_many_arguments)]
fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    train_xs: &Tensor<B, 2, Float>,
    train_ys: &Tensor<B, 2, Float>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    stdv: T,
) -> Model<B> {
    let num_samples = train_ys.shape().num_elements();
    // number of mini-batches per epoch, or None for full-batch training
    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));

    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
    let mut optim = AdamConfig::new().init();

    // track the lowest-loss checkpoint so it can be restored after training
    let mut best_loss = T::infinity();
    let mut best_epoch = 0;
    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();

    for epoch in 1..=num_epochs {
        #[expect(clippy::option_if_let_else)]
        let (train_xs_batches, train_ys_batches) = match num_batches {
            Some(num_batches) => {
                // shuffle the samples each epoch by sorting random keys:
                // argsort of uniform noise yields a random permutation
                let shuffle = Tensor::<B, 1, Float>::random(
                    [num_samples],
                    Distribution::Uniform(0.0, 1.0),
                    device,
                );
                let shuffle_indices = shuffle.argsort(0);

                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);

                (
                    train_xs_shuffled.chunk(num_batches, 0),
                    train_ys_shuffled.chunk(num_batches, 0),
                )
            }
            // full-batch training uses all samples in a single batch
            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
        };

        let mut loss_sum = T::ZERO;

        // error statistics, accumulated for logging only
        let mut se_sum = T::ZERO;
        let mut ae_sum = T::ZERO;
        let mut l_inf = T::ZERO;

        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
            let prediction = model.forward(train_xs_batch);
            let loss =
                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);

            // backpropagate and take one optimizer step per mini-batch
            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(learning_rate.0, model, grads);

            loss_sum += loss.into_scalar();

            let err = prediction - train_ys_batch;

            se_sum += (err.clone() * err.clone()).sum().into_scalar();
            ae_sum += err.clone().abs().sum().into_scalar();
            l_inf = l_inf.max(err.abs().max().into_scalar());
        }

        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));

        if loss_mean < best_loss {
            best_loss = loss_mean;
            best_epoch = epoch;
            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
        }

        // scale the error metrics by stdv to report them on the data scale
        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
        let l_inf = stdv * l_inf;

        log::info!(
            "[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}"
        );
    }

    // restore the best checkpoint unless the final epoch was already the best
    if best_epoch != num_epochs {
        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));

        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
    }

    model
}
605
#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Empty data must round-trip through an empty encoding.
    #[test]
    fn empty() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((0,)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        assert!(encoded.is_empty());
        let mut decoded = Array::<f32, _>::zeros((0,));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// A single-element four-dimensional array must round-trip.
    #[test]
    fn ones() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((1, 1, 1, 1)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// Constant data (zero standard deviation) must round-trip.
    #[test]
    fn r#const() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// Constant data must also round-trip with mini-batch training.
    #[test]
    fn const_batched() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            Some(NonZeroUsize::MIN.saturating_add(1)),
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    /// Linearly spaced data must round-trip for a variety of mini-batch
    /// sizes, including one larger than the number of samples.
    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        for mini_batch_size in [
            None,
            Some(NonZeroUsize::MIN),
            Some(NonZeroUsize::MIN.saturating_add(6)),
            Some(NonZeroUsize::MIN.saturating_add(9)),
            Some(NonZeroUsize::MIN.saturating_add(1000)),
        ] {
            let mut decoded = Array::<f64, _>::zeros(data.shape());
            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}