#![expect(clippy::multiple_crate_versions)]

use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};

use burn::{
    backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
    module::{Module, Param},
    nn::loss::{MseLoss, Reduction},
    optim::{AdamConfig, GradientsParams, Optimizer},
    prelude::Backend,
    record::{
        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
        Record, Recorder, RecorderError,
    },
    tensor::{
        Distribution, Element as BurnElement, Float, Tensor, TensorData, backend::AutodiffBackend,
    },
};
use itertools::Itertools;
use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
use numcodecs::{
    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;

#[cfg(test)]
use ::serde_json as _;

mod modules;

use modules::{Model, ModelConfig, ModelExtra, ModelRecord};

type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;

#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FourierNetworkCodec {
    /// The number of Fourier features that the data coordinates are projected onto
    pub fourier_features: NonZeroUsize,
    /// The scale of the Gaussian distribution from which the random Fourier
    /// projection matrix is sampled
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the neural network
    pub num_blocks: NonZeroUsize,
    /// The learning rate for the Adam optimizer during training
    pub learning_rate: Positive<f64>,
    /// The number of epochs for which the network is trained
    pub num_epochs: usize,
    /// The optional mini-batch size used during training
    ///
    /// Setting the mini-batch size to `None` disables batching, i.e. the
    /// network is trained using one large batch that contains the full data
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during training
    pub seed: u64,
    /// The codec's encoding format version; may be omitted from the configuration
    #[serde(default, rename = "_version")]
    pub version: FourierNetworkCodecVersion,
}
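
// For reference, the codec parameters are configured as JSON that maps onto
// the fields above; the values below are illustrative only. Unknown fields are
// rejected, `mini_batch_size` must be present (but may be `null`), and
// `_version` may be omitted:
//
// {
//     "fourier_features": 16,
//     "fourier_scale": 10.0,
//     "num_blocks": 2,
//     "learning_rate": 1e-4,
//     "num_epochs": 100,
//     "mini_batch_size": null,
//     "seed": 42
// }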

/// Passthrough deserializer that makes an `Option` field required: serde's
/// derive would otherwise allow the field to be omitted (falling back to
/// `None`), whereas routing it through this function keeps the key mandatory
/// while still accepting `null`.
fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Option<T>, D::Error> {
    Option::<T>::deserialize(deserializer)
}

impl Codec for FourierNetworkCodec {
    type Error = FourierNetworkCodecError;

    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match data {
            AnyCowArray::F32(data) => Ok(AnyArray::U8(
                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            AnyCowArray::F64(data) => Ok(AnyArray::U8(
                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        Err(FourierNetworkCodecError::MissingDecodingOutput)
    }

    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let AnyArrayView::U8(encoded) = encoded else {
            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
                dtype: encoded.dtype(),
            });
        };

        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        };

        match decoded {
            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
        }
    }
}

impl StaticCodec for FourierNetworkCodec {
    const CODEC_ID: &'static str = "fourier-network.rs";

    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    fn get_config(&self) -> StaticCodecConfig<Self> {
        StaticCodecConfig::from(self)
    }
}

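/// A floating-point value that is strictly greater than zero; for `f64` this
/// is validated on deserialization and expressed in the generated JSON schema
/// via `exclusiveMinimum`.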
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct Positive<T: FloatTrait>(T);

impl Serialize for Positive<f64> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_f64(self.0)
    }
}

impl<'de> Deserialize<'de> for Positive<f64> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let x = f64::deserialize(deserializer)?;

        if x > 0.0 {
            Ok(Self(x))
        } else {
            Err(serde::de::Error::invalid_value(
                serde::de::Unexpected::Float(x),
                &"a positive value",
            ))
        }
    }
}

impl JsonSchema for Positive<f64> {
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("PositiveF64")
    }

    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
    }

    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "exclusiveMinimum": 0.0
        })
    }
}

#[derive(Debug, Error)]
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN)
    /// floating point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed during a neural network computation
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The source of the error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] must be provided the output array during decoding
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received an array of a different dtype
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the provided array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received a byte array of a different shape
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the provided array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}

#[derive(Debug, Error)]
#[error(transparent)]
pub struct NeuralNetworkError(RecorderError);

/// Floating-point element types supported by the [`FourierNetworkCodec`],
/// which tie each element type to the burn [`PrecisionSettings`] used when the
/// trained network is recorded to and loaded from bytes.
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    /// The precision settings that match this element type.
    type Precision: PrecisionSettings;

    /// Converts `x` into this floating-point type; the conversion may lose
    /// precision for large values.
    fn from_usize(x: usize) -> Self;
}

impl FloatExt for f32 {
    type Precision = FullPrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}

impl FloatExt for f64 {
    type Precision = DoublePrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}

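/// Encodes the `data` by training a fourier-feature neural network on the
/// `device` to overfit it, and records the trained network, the random
/// projection matrix, and the standardization statistics as a flat byte array.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::NonFiniteData`] if the data contains
///   non-finite (infinite or NaN) values
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if recording the trained
///   network fails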
#[expect(clippy::similar_names)]
#[expect(clippy::missing_panics_doc)]
#[expect(clippy::too_many_arguments)]
pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    data: ArrayBase<S, D>,
    fourier_features: NonZeroUsize,
    fourier_scale: Positive<f64>,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    seed: u64,
) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
    let Some(mean) = data.mean() else {
        return Ok(Array::from_vec(Vec::new()));
    };
    let stdv = data.std(T::ZERO);
    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };

    if !Zip::from(&data).all(|x| x.is_finite()) {
        return Err(FourierNetworkCodecError::NonFiniteData);
    }

    B::seed(seed);

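    // Sample the random Fourier-feature projection matrix `b_t` of shape
    // [data.ndim(), fourier_features]; its entries are drawn from a Gaussian
    // whose spread is controlled by `fourier_scale`.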
    let b_t = Tensor::<B, 2, Float>::random(
        [data.ndim(), fourier_features.get()],
        Distribution::Normal(0.0, fourier_scale.0),
        device,
    );

    let train_xs = flat_grid_like(&data, device);
    let train_xs = fourier_mapping(train_xs, b_t.clone());

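    // Standardize the training targets to zero mean and unit variance,
    // y = (x - mean) / stdv, and reshape them into an [N, 1] column;
    // `decode_into` later applies the inverse transform x = y * stdv + mean.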
    let train_ys_shape = [data.len(), 1];
    let mut train_ys = data.into_owned();
    train_ys.mapv_inplace(|x| (x - mean) / stdv);
    #[expect(clippy::unwrap_used)]
    let train_ys = train_ys
        .into_shape_clone((train_ys_shape, Order::RowMajor))
        .unwrap();
    let train_ys = Tensor::from_data(
        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
        device,
    );

    let model = train(
        device,
        &train_xs,
        &train_ys,
        fourier_features,
        num_blocks,
        learning_rate,
        num_epochs,
        mini_batch_size,
        stdv,
    );

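    // Record everything that `decode_into` needs to reconstruct the data: the
    // trained model parameters, the projection matrix `b_t`, the `mean` and
    // `stdv` used for standardization, and the codec version.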
    let extra = ModelExtra {
        model: model.into_record(),
        b_t: Param::from_tensor(b_t).set_require_grad(false),
        mean: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![mean], vec![1]),
            device,
        ))
        .set_require_grad(false),
        stdv: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![stdv], vec![1]),
            device,
        ))
        .set_require_grad(false),
        version: StaticCodecVersion,
    };

    let recorder = BinBytesRecorder::<T::Precision>::new();
    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;

    Ok(Array::from_vec(encoded))
}

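/// Decodes the `encoded` bytes, which must have been produced by [`encode`]
/// with the same `fourier_features` and `num_blocks`, into the `decoded`
/// output array by evaluating the recorded network on the output's coordinate
/// grid on the `device`.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if an empty byte
///   array is decoded into a non-empty output array
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if loading the recorded
///   network fails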
#[expect(clippy::missing_panics_doc)]
pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
    device: &B::Device,
    encoded: ArrayBase<S, Ix1>,
    mut decoded: ArrayViewMut<T, D>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
) -> Result<(), FourierNetworkCodecError> {
    if encoded.is_empty() {
        if decoded.is_empty() {
            return Ok(());
        }

        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::ShapeMismatch {
                src: encoded.shape().to_vec(),
                dst: decoded.shape().to_vec(),
            },
        });
    }

    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;

    let recorder = BinBytesRecorder::<T::Precision>::new();
    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;

    let model = ModelConfig::new(fourier_features, num_blocks)
        .init(device)
        .load_record(record.model);
    let b_t = record.b_t.into_value();
    let mean = record.mean.into_value().into_scalar();
    let stdv = record.stdv.into_value().into_scalar();

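    // Re-create the normalized coordinate grid for the output array, apply the
    // stored Fourier mapping, evaluate the network, and undo the
    // standardization that was applied during encoding.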
    let test_xs = flat_grid_like(&decoded, device);
    let test_xs = fourier_mapping(test_xs, b_t);

    let prediction = model.forward(test_xs).into_data();
    #[expect(clippy::unwrap_used)]
    let prediction = prediction.as_slice().unwrap();

    #[expect(clippy::unwrap_used)]
    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
    decoded.mapv_inplace(|x| (x * stdv) + mean);

    Ok(())
}

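/// Builds the flattened coordinate grid for the array `a`: one row per element
/// of `a`, where each row holds the element's index along every axis,
/// normalized to `[0, 1)` by dividing by that axis' length.
///
/// For example, an array of shape `[2, 3]` yields a `[6, 2]` tensor with the
/// rows `[0, 0]`, `[0, 1/3]`, `[0, 2/3]`, `[1/2, 0]`, `[1/2, 1/3]`,
/// `[1/2, 2/3]`.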
fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
    a: &ArrayBase<S, D>,
    device: &B::Device,
) -> Tensor<B, 2, Float> {
    let grid = a
        .shape()
        .iter()
        .copied()
        .map(|s| {
            #[expect(clippy::useless_conversion)]
            (0..s)
                .into_iter()
                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
        })
        .multi_cartesian_product()
        .flatten()
        .collect::<Vec<_>>();

    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
}

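/// Maps the `[N, d]` coordinates `xs` through the random Fourier features
/// defined by the `[d, F]` projection matrix `b_t`:
///
/// `gamma(xs) = concat[sin(2*pi * xs * b_t), cos(2*pi * xs * b_t)]`
///
/// producing an `[N, 2*F]` feature tensor.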
fn fourier_mapping<B: Backend>(
    xs: Tensor<B, 2, Float>,
    b_t: Tensor<B, 2, Float>,
) -> Tensor<B, 2, Float> {
    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);

    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
}

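/// Trains a fresh [`Model`] on the Fourier-mapped coordinates `train_xs` and
/// the standardized targets `train_ys` with the Adam optimizer and an MSE
/// loss, optionally in shuffled mini-batches, and returns the model restored
/// to the epoch with the lowest mean training loss.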
#[expect(clippy::similar_names)]
#[expect(clippy::too_many_arguments)]
fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    train_xs: &Tensor<B, 2, Float>,
    train_ys: &Tensor<B, 2, Float>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    stdv: T,
) -> Model<B> {
    let num_samples = train_ys.shape().num_elements();
    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));

    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
    let mut optim = AdamConfig::new().init();

    let mut best_loss = T::infinity();
    let mut best_epoch = 0;
    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();

    for epoch in 1..=num_epochs {
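        // With mini-batching enabled, reshuffle the samples every epoch by
        // argsorting a vector of uniformly random keys and split them into
        // `num_batches` chunks; otherwise, train on the full data as a single
        // batch.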
        #[expect(clippy::option_if_let_else)]
        let (train_xs_batches, train_ys_batches) = match num_batches {
            Some(num_batches) => {
                let shuffle = Tensor::<B, 1, Float>::random(
                    [num_samples],
                    Distribution::Uniform(0.0, 1.0),
                    device,
                );
                let shuffle_indices = shuffle.argsort(0);

                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);

                (
                    train_xs_shuffled.chunk(num_batches, 0),
                    train_ys_shuffled.chunk(num_batches, 0),
                )
            }
            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
        };

        let mut loss_sum = T::ZERO;

        let mut se_sum = T::ZERO;
        let mut ae_sum = T::ZERO;
        let mut l_inf = T::ZERO;

        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches) {
            let prediction = model.forward(train_xs_batch);
            let loss =
                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);

            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(learning_rate.0, model, grads);

            loss_sum += loss.into_scalar();

            let err = prediction - train_ys_batch;

            se_sum += (err.clone() * err.clone()).sum().into_scalar();
            ae_sum += err.clone().abs().sum().into_scalar();
            l_inf = l_inf.max(err.abs().max().into_scalar());
        }

        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));

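        // Keep a checkpoint of the parameters from the epoch with the lowest
        // mean training loss; it is restored after the loop if a later epoch
        // did not improve on it.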
        if loss_mean < best_loss {
            best_loss = loss_mean;
            best_epoch = epoch;
            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
        }

        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
        let l_inf = stdv * l_inf;

        log::info!(
            "[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}"
        );
    }

    if best_epoch != num_epochs {
        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));

        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
    }

    model
}

#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn empty() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((0,)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        assert!(encoded.is_empty());
        let mut decoded = Array::<f32, _>::zeros((0,));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn ones() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((1, 1, 1, 1)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn r#const() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn const_batched() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            Some(NonZeroUsize::MIN.saturating_add(1)),
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        for mini_batch_size in [
            None,
            Some(NonZeroUsize::MIN),
            Some(NonZeroUsize::MIN.saturating_add(6)),
            Some(NonZeroUsize::MIN.saturating_add(9)),
            Some(NonZeroUsize::MIN.saturating_add(1000)),
        ] {
            let mut decoded = Array::<f64, _>::zeros(data.shape());
            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}