#![expect(clippy::multiple_crate_versions)]
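//! Implementation of the `fourier-network.rs` codec for the [`numcodecs`]
//! API: encoding overfits a small fourier-feature neural network (built with
//! [`burn`]) to the data and stores the trained weights as bytes, while
//! decoding reconstructs the data from the network's predictions.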

use std::{borrow::Cow, num::NonZeroUsize, ops::AddAssign};

use burn::{
    backend::{Autodiff, NdArray, ndarray::NdArrayDevice},
    module::{Module, Param},
    nn::loss::{MseLoss, Reduction},
    optim::{AdamConfig, GradientsParams, Optimizer},
    prelude::Backend,
    record::{
        BinBytesRecorder, DoublePrecisionSettings, FullPrecisionSettings, PrecisionSettings,
        Record, Recorder, RecorderError,
    },
    tensor::{
        Distribution, Element as BurnElement, Float, Tensor, TensorData, backend::AutodiffBackend,
    },
};
use itertools::Itertools;
use ndarray::{Array, ArrayBase, ArrayView, ArrayViewMut, Data, Dimension, Ix1, Order, Zip};
use num_traits::{ConstOne, ConstZero, Float as FloatTrait, FromPrimitive};
use numcodecs::{
    AnyArray, AnyArrayAssignError, AnyArrayDType, AnyArrayView, AnyArrayViewMut, AnyCowArray,
    Codec, StaticCodec, StaticCodecConfig, StaticCodecVersion,
};
use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;

use ::bytemuck as _;

#[cfg(test)]
use ::serde_json as _;

mod modules;

use modules::{Model, ModelConfig, ModelExtra, ModelRecord};

type FourierNetworkCodecVersion = StaticCodecVersion<0, 1, 0>;

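/// Fourier network codec which trains and overfits a fourier-feature neural
/// network on the data during encoding and predicts the data during decoding.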
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct FourierNetworkCodec {
    /// The number of fourier features that the data coordinates are projected to
    pub fourier_features: NonZeroUsize,
    /// The standard deviation of the normal distribution from which the random
    /// fourier projection matrix is sampled
    pub fourier_scale: Positive<f64>,
    /// The number of blocks in the network
    pub num_blocks: NonZeroUsize,
    /// The learning rate used during the network training
    pub learning_rate: Positive<f64>,
    /// The number of epochs for which the network is trained
    pub num_epochs: usize,
    /// The optional mini-batch size used during training
    ///
    /// The field must be provided explicitly but may be set to `null`, in
    /// which case the network is trained on the full data in a single batch
    #[serde(deserialize_with = "deserialize_required_option")]
    #[schemars(required, extend("type" = ["integer", "null"]))]
    pub mini_batch_size: Option<NonZeroUsize>,
    /// The seed for the random number generator used during encoding
    pub seed: u64,
    /// The codec's version; serialized as `_version` and defaulting to the
    /// current version when omitted
    #[serde(default, rename = "_version")]
    pub version: FourierNetworkCodecVersion,
}

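/// Deserializes an `Option` field so that it must be provided explicitly but
/// may be set to `null` (mapping to `None`).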
fn deserialize_required_option<'de, T: serde::Deserialize<'de>, D: serde::Deserializer<'de>>(
    deserializer: D,
) -> Result<Option<T>, D::Error> {
    Option::<T>::deserialize(deserializer)
}

impl Codec for FourierNetworkCodec {
    type Error = FourierNetworkCodecError;

    fn encode(&self, data: AnyCowArray) -> Result<AnyArray, Self::Error> {
        match data {
            AnyCowArray::F32(data) => Ok(AnyArray::U8(
                encode::<f32, _, _, Autodiff<NdArray<f32>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            AnyCowArray::F64(data) => Ok(AnyArray::U8(
                encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                    &NdArrayDevice::Cpu,
                    data,
                    self.fourier_features,
                    self.fourier_scale,
                    self.num_blocks,
                    self.learning_rate,
                    self.num_epochs,
                    self.mini_batch_size,
                    self.seed,
                )?
                .into_dyn(),
            )),
            encoded => Err(FourierNetworkCodecError::UnsupportedDtype(encoded.dtype())),
        }
    }

    fn decode(&self, _encoded: AnyCowArray) -> Result<AnyArray, Self::Error> {
        Err(FourierNetworkCodecError::MissingDecodingOutput)
    }

    fn decode_into(
        &self,
        encoded: AnyArrayView,
        decoded: AnyArrayViewMut,
    ) -> Result<(), Self::Error> {
        let AnyArrayView::U8(encoded) = encoded else {
            return Err(FourierNetworkCodecError::EncodedDataNotBytes {
                dtype: encoded.dtype(),
            });
        };

        let Ok(encoded): Result<ArrayBase<_, Ix1>, _> = encoded.view().into_dimensionality() else {
            return Err(FourierNetworkCodecError::EncodedDataNotOneDimensional {
                shape: encoded.shape().to_vec(),
            });
        };

        match decoded {
            AnyArrayViewMut::F32(decoded) => decode_into::<f32, _, _, NdArray<f32>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            AnyArrayViewMut::F64(decoded) => decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded,
                self.fourier_features,
                self.num_blocks,
            ),
            decoded => Err(FourierNetworkCodecError::UnsupportedDtype(decoded.dtype())),
        }
    }
}

impl StaticCodec for FourierNetworkCodec {
    const CODEC_ID: &'static str = "fourier-network.rs";

    type Config<'de> = Self;

    fn from_config(config: Self::Config<'_>) -> Self {
        config
    }

    fn get_config(&self) -> StaticCodecConfig<'_, Self> {
        StaticCodecConfig::from(self)
    }
}

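/// Floating-point number that must be strictly positive (> 0).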
#[expect(clippy::derive_partial_eq_without_eq)]
#[derive(Copy, Clone, PartialEq, PartialOrd, Hash)]
pub struct Positive<T: FloatTrait>(T);

impl Serialize for Positive<f64> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_f64(self.0)
    }
}

impl<'de> Deserialize<'de> for Positive<f64> {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let x = f64::deserialize(deserializer)?;

        if x > 0.0 {
            Ok(Self(x))
        } else {
            Err(serde::de::Error::invalid_value(
                serde::de::Unexpected::Float(x),
                &"a positive value",
            ))
        }
    }
}

impl JsonSchema for Positive<f64> {
    fn schema_name() -> Cow<'static, str> {
        Cow::Borrowed("PositiveF64")
    }

    fn schema_id() -> Cow<'static, str> {
        Cow::Borrowed(concat!(module_path!(), "::", "Positive<f64>"))
    }

    fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
        json_schema!({
            "type": "number",
            "exclusiveMinimum": 0.0
        })
    }
}

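/// Errors that may occur when applying the [`FourierNetworkCodec`].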
#[derive(Debug, Error)]
pub enum FourierNetworkCodecError {
    /// [`FourierNetworkCodec`] does not support the dtype
    #[error("FourierNetwork does not support the dtype {0}")]
    UnsupportedDtype(AnyArrayDType),
    /// [`FourierNetworkCodec`] does not support non-finite (infinite or NaN)
    /// floating point data
    #[error("FourierNetwork does not support non-finite (infinite or NaN) floating point data")]
    NonFiniteData,
    /// [`FourierNetworkCodec`] failed during a neural network computation
    #[error("FourierNetwork failed during a neural network computation")]
    NeuralNetworkError {
        /// The source of the error
        #[from]
        source: NeuralNetworkError,
    },
    /// [`FourierNetworkCodec`] must be provided the output array during decoding
    #[error("FourierNetwork must be provided the output array during decoding")]
    MissingDecodingOutput,
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received an array of a different dtype
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received an array of dtype {dtype}"
    )]
    EncodedDataNotBytes {
        /// The unexpected dtype of the encoded array
        dtype: AnyArrayDType,
    },
    /// [`FourierNetworkCodec`] can only decode one-dimensional byte arrays but
    /// received a byte array of a different shape
    #[error(
        "FourierNetwork can only decode one-dimensional byte arrays but received a byte array of shape {shape:?}"
    )]
    EncodedDataNotOneDimensional {
        /// The unexpected shape of the encoded array
        shape: Vec<usize>,
    },
    /// [`FourierNetworkCodec`] cannot decode into the provided array
    #[error("FourierNetwork cannot decode into the provided array")]
    MismatchedDecodeIntoArray {
        /// The source of the error
        #[from]
        source: AnyArrayAssignError,
    },
}

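/// Opaque error type for failures inside the neural network computation,
/// wrapping the underlying [`RecorderError`].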
#[derive(Debug, Error)]
#[error(transparent)]
pub struct NeuralNetworkError(RecorderError);

/// Floating point types that the fourier network codec can train on, together
/// with their matching [`burn`] precision settings.
pub trait FloatExt:
    AddAssign + BurnElement + ConstOne + ConstZero + FloatTrait + FromPrimitive
{
    /// The precision settings used when recording and loading model weights
    /// of this float type
    type Precision: PrecisionSettings;

    /// Converts a `usize` into this float type, potentially losing precision
    fn from_usize(x: usize) -> Self;
}

impl FloatExt for f32 {
    type Precision = FullPrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}

impl FloatExt for f64 {
    type Precision = DoublePrecisionSettings;

    #[expect(clippy::cast_precision_loss)]
    fn from_usize(x: usize) -> Self {
        x as Self
    }
}

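/// Encodes the `data` by training a fourier-feature neural network on it and
/// returning the trained model weights as a one-dimensional byte array.
///
/// The data is standardized with its mean and standard deviation, the random
/// fourier projection matrix is sampled from a normal distribution with the
/// `fourier_scale` standard deviation, and the model is trained for
/// `num_epochs` epochs with the given `learning_rate`, optionally in
/// mini-batches of `mini_batch_size`. Encoding an empty array produces an
/// empty byte array.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::NonFiniteData`] if the data contains
///   non-finite (infinite or NaN) values
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if recording the
///   trained model weights fails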
#[expect(clippy::similar_names)]
#[expect(clippy::missing_panics_doc)]
#[expect(clippy::too_many_arguments)]
pub fn encode<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    data: ArrayBase<S, D>,
    fourier_features: NonZeroUsize,
    fourier_scale: Positive<f64>,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    seed: u64,
) -> Result<Array<u8, Ix1>, FourierNetworkCodecError> {
    let Some(mean) = data.mean() else {
        return Ok(Array::from_vec(Vec::new()));
    };
    let stdv = data.std(T::ZERO);
    let stdv = if stdv == T::ZERO { T::ONE } else { stdv };

    if !Zip::from(&data).all(|x| x.is_finite()) {
        return Err(FourierNetworkCodecError::NonFiniteData);
    }

    B::seed(seed);

    let b_t = Tensor::<B, 2, Float>::random(
        [data.ndim(), fourier_features.get()],
        Distribution::Normal(0.0, fourier_scale.0),
        device,
    );

    let train_xs = flat_grid_like(&data, device);
    let train_xs = fourier_mapping(train_xs, b_t.clone());

    let train_ys_shape = [data.len(), 1];
    let mut train_ys = data.into_owned();
    train_ys.mapv_inplace(|x| (x - mean) / stdv);
    #[expect(clippy::unwrap_used)]
    let train_ys = train_ys
        .into_shape_clone((train_ys_shape, Order::RowMajor))
        .unwrap();
    let train_ys = Tensor::from_data(
        TensorData::new(train_ys.into_raw_vec_and_offset().0, train_ys_shape),
        device,
    );

    let model = train(
        device,
        &train_xs,
        &train_ys,
        fourier_features,
        num_blocks,
        learning_rate,
        num_epochs,
        mini_batch_size,
        stdv,
    );

    let extra = ModelExtra {
        model: model.into_record(),
        b_t: Param::from_tensor(b_t).set_require_grad(false),
        mean: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![mean], vec![1]),
            device,
        ))
        .set_require_grad(false),
        stdv: Param::from_tensor(Tensor::from_data(
            TensorData::new(vec![stdv], vec![1]),
            device,
        ))
        .set_require_grad(false),
        version: StaticCodecVersion,
    };

    let recorder = BinBytesRecorder::<T::Precision>::new();
    let encoded = recorder.record(extra, ()).map_err(NeuralNetworkError)?;

    Ok(Array::from_vec(encoded))
}

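/// Decodes the `encoded` model weights, rebuilds the fourier-feature network
/// with the given `fourier_features` and `num_blocks`, and fills the `decoded`
/// output array with the network's predictions, undoing the standardization
/// applied during encoding. An empty encoded array is only accepted for an
/// empty output array.
///
/// # Errors
///
/// Errors with
/// - [`FourierNetworkCodecError::MismatchedDecodeIntoArray`] if the encoded
///   array is empty but the output array is not
/// - [`FourierNetworkCodecError::NeuralNetworkError`] if loading the model
///   weights fails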
#[expect(clippy::missing_panics_doc)]
pub fn decode_into<T: FloatExt, S: Data<Elem = u8>, D: Dimension, B: Backend<FloatElem = T>>(
    device: &B::Device,
    encoded: ArrayBase<S, Ix1>,
    mut decoded: ArrayViewMut<T, D>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
) -> Result<(), FourierNetworkCodecError> {
    if encoded.is_empty() {
        if decoded.is_empty() {
            return Ok(());
        }

        return Err(FourierNetworkCodecError::MismatchedDecodeIntoArray {
            source: AnyArrayAssignError::ShapeMismatch {
                src: encoded.shape().to_vec(),
                dst: decoded.shape().to_vec(),
            },
        });
    }

    let encoded = encoded.into_owned().into_raw_vec_and_offset().0;

    let recorder = BinBytesRecorder::<T::Precision>::new();
    let record: ModelExtra<B> = recorder.load(encoded, device).map_err(NeuralNetworkError)?;

    let model = ModelConfig::new(fourier_features, num_blocks)
        .init(device)
        .load_record(record.model);
    let b_t = record.b_t.into_value();
    let mean = record.mean.into_value().into_scalar();
    let stdv = record.stdv.into_value().into_scalar();

    let test_xs = flat_grid_like(&decoded, device);
    let test_xs = fourier_mapping(test_xs, b_t);

    let prediction = model.forward(test_xs).into_data();
    #[expect(clippy::unwrap_used)]
    let prediction = prediction.as_slice().unwrap();

    #[expect(clippy::unwrap_used)]
    decoded.assign(&ArrayView::from_shape(decoded.shape(), prediction).unwrap());
    decoded.mapv_inplace(|x| (x * stdv) + mean);

    Ok(())
}

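/// Builds the flattened coordinate grid for the array `a`: one row per array
/// element, where each dimension's index is normalized to `[0, 1)`.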
fn flat_grid_like<T: FloatExt, S: Data<Elem = T>, D: Dimension, B: Backend<FloatElem = T>>(
    a: &ArrayBase<S, D>,
    device: &B::Device,
) -> Tensor<B, 2, Float> {
    let grid = a
        .shape()
        .iter()
        .copied()
        .map(|s| {
            #[expect(clippy::useless_conversion)]
            (0..s)
                .into_iter()
                .map(move |x| <T as FloatExt>::from_usize(x) / <T as FloatExt>::from_usize(s))
        })
        .multi_cartesian_product()
        .flatten()
        .collect::<Vec<_>>();

    Tensor::from_data(TensorData::new(grid, [a.len(), a.ndim()]), device)
}

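/// Projects the coordinates `xs` with the random projection matrix `b_t` and
/// maps them to fourier features `[sin(2*pi*xs*b_t), cos(2*pi*xs*b_t)]`.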
fn fourier_mapping<B: Backend>(
    xs: Tensor<B, 2, Float>,
    b_t: Tensor<B, 2, Float>,
) -> Tensor<B, 2, Float> {
    let xs_proj = xs.mul_scalar(core::f64::consts::TAU).matmul(b_t);

    Tensor::cat(vec![xs_proj.clone().sin(), xs_proj.cos()], 1)
}

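/// Trains the fourier-feature network on the pre-computed fourier features
/// `train_xs` and the standardized targets `train_ys` using the Adam optimizer
/// and an MSE loss, optionally shuffling the samples into mini-batches each
/// epoch. Per-epoch metrics are logged in units of the original data (rescaled
/// by `stdv`), and the model state from the epoch with the lowest mean loss is
/// returned.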
#[expect(clippy::similar_names)]
#[expect(clippy::too_many_arguments)]
fn train<T: FloatExt, B: AutodiffBackend<FloatElem = T>>(
    device: &B::Device,
    train_xs: &Tensor<B, 2, Float>,
    train_ys: &Tensor<B, 2, Float>,
    fourier_features: NonZeroUsize,
    num_blocks: NonZeroUsize,
    learning_rate: Positive<f64>,
    num_epochs: usize,
    mini_batch_size: Option<NonZeroUsize>,
    stdv: T,
) -> Model<B> {
    let num_samples = train_ys.shape().num_elements();
    let num_batches = mini_batch_size.map(|b| num_samples.div_ceil(b.get()));

    let mut model = ModelConfig::new(fourier_features, num_blocks).init(device);
    let mut optim = AdamConfig::new().init();

    let mut best_loss = T::infinity();
    let mut best_epoch = 0;
    let mut best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();

    for epoch in 1..=num_epochs {
        #[expect(clippy::option_if_let_else)]
        let (train_xs_batches, train_ys_batches) = match num_batches {
            Some(num_batches) => {
                let shuffle = Tensor::<B, 1, Float>::random(
                    [num_samples],
                    Distribution::Uniform(0.0, 1.0),
                    device,
                );
                let shuffle_indices = shuffle.argsort(0);

                let train_xs_shuffled = train_xs.clone().select(0, shuffle_indices.clone());
                let train_ys_shuffled = train_ys.clone().select(0, shuffle_indices);

                (
                    train_xs_shuffled.chunk(num_batches, 0),
                    train_ys_shuffled.chunk(num_batches, 0),
                )
            }
            None => (vec![train_xs.clone()], vec![train_ys.clone()]),
        };

        let mut loss_sum = T::ZERO;

        let mut se_sum = T::ZERO;
        let mut ae_sum = T::ZERO;
        let mut l_inf = T::ZERO;

        for (train_xs_batch, train_ys_batch) in train_xs_batches.into_iter().zip(train_ys_batches)
        {
            let prediction = model.forward(train_xs_batch);
            let loss =
                MseLoss::new().forward(prediction.clone(), train_ys_batch.clone(), Reduction::Mean);

            let grads = GradientsParams::from_grads(loss.backward(), &model);
            model = optim.step(learning_rate.0, model, grads);

            loss_sum += loss.into_scalar();

            let err = prediction - train_ys_batch;

            se_sum += (err.clone() * err.clone()).sum().into_scalar();
            ae_sum += err.clone().abs().sum().into_scalar();
            l_inf = l_inf.max(err.abs().max().into_scalar());
        }

        let loss_mean = loss_sum / <T as FloatExt>::from_usize(num_batches.unwrap_or(1));

        if loss_mean < best_loss {
            best_loss = loss_mean;
            best_epoch = epoch;
            best_model_checkpoint = model.clone().into_record().into_item::<T::Precision>();
        }

        let rmse = stdv * (se_sum / <T as FloatExt>::from_usize(num_samples)).sqrt();
        let mae = stdv * ae_sum / <T as FloatExt>::from_usize(num_samples);
        let l_inf = stdv * l_inf;

        log::info!(
            "[{epoch}/{num_epochs}]: loss={loss_mean:0.3} MAE={mae:0.3} RMSE={rmse:0.3} Linf={l_inf:0.3}"
        );
    }

    if best_epoch != num_epochs {
        model = model.load_record(ModelRecord::from_item(best_model_checkpoint, device));

        log::info!("restored from epoch {best_epoch} with lowest loss={best_loss:0.3}");
    }

    model
}

#[cfg(test)]
#[expect(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn empty() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((0,)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        assert!(encoded.is_empty());
        let mut decoded = Array::<f32, _>::zeros((0,));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn ones() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::zeros((1, 1, 1, 1)),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((1, 1, 1, 1));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn r#const() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            None,
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn const_batched() {
        std::mem::drop(simple_logger::init());

        let encoded = encode::<f32, _, _, Autodiff<NdArray<f32>>>(
            &NdArrayDevice::Cpu,
            Array::<f32, _>::from_elem((2, 1, 3), 42.0),
            NonZeroUsize::MIN,
            Positive(1.0),
            NonZeroUsize::MIN,
            Positive(1e-4),
            10,
            Some(NonZeroUsize::MIN.saturating_add(1)),
            42,
        )
        .unwrap();
        let mut decoded = Array::<f32, _>::zeros((2, 1, 3));
        decode_into::<f32, _, _, NdArray<f32>>(
            &NdArrayDevice::Cpu,
            encoded,
            decoded.view_mut(),
            NonZeroUsize::MIN,
            NonZeroUsize::MIN,
        )
        .unwrap();
    }

    #[test]
    fn linspace() {
        std::mem::drop(simple_logger::init());

        let data = Array::linspace(0.0_f64, 100.0_f64, 100);

        let fourier_features = NonZeroUsize::new(16).unwrap();
        let fourier_scale = Positive(10.0);
        let num_blocks = NonZeroUsize::new(2).unwrap();
        let learning_rate = Positive(1e-4);
        let num_epochs = 100;
        let seed = 42;

        for mini_batch_size in [
            None,
            Some(NonZeroUsize::MIN),
            Some(NonZeroUsize::MIN.saturating_add(6)),
            Some(NonZeroUsize::MIN.saturating_add(9)),
            Some(NonZeroUsize::MIN.saturating_add(1000)),
        ] {
            let mut decoded = Array::<f64, _>::zeros(data.shape());
            let encoded = encode::<f64, _, _, Autodiff<NdArray<f64>>>(
                &NdArrayDevice::Cpu,
                data.view(),
                fourier_features,
                fourier_scale,
                num_blocks,
                learning_rate,
                num_epochs,
                mini_batch_size,
                seed,
            )
            .unwrap();

            decode_into::<f64, _, _, NdArray<f64>>(
                &NdArrayDevice::Cpu,
                encoded,
                decoded.view_mut(),
                fourier_features,
                num_blocks,
            )
            .unwrap();
        }
    }
}