numcodecs_fourier_network/
modules.rs1use std::num::NonZeroUsize;
2
3use burn::{
4 config::Config,
5 module::{Module, Param},
6 nn::{BatchNorm, BatchNormConfig, Gelu, Linear, LinearConfig},
7 prelude::Backend,
8 record::{PrecisionSettings, Record},
9 tensor::{Float, Tensor},
10};
11
/// A hidden block of the network: batch normalisation, a GELU activation and
/// a linear layer, applied in that order by `Block::forward` (there is no skip
/// connection).
#[derive(Debug, Module)]
pub struct Block<B: Backend> {
    // NOTE(review): the `Module` derive presumably uses these field names for
    // the generated (and serialized) record, so treat them as part of the
    // stored-model format — confirm before renaming or reordering.
    /// Batch normalisation over the block's `fourier_features` channels; it is
    /// applied to the rank-2 input of `forward`, hence no spatial dimensions.
    bn2_1: BatchNorm<B, 0>,
    /// GELU activation (a unit value without parameters).
    gu2_2: Gelu,
    /// Square linear layer mapping `fourier_features` to `fourier_features`.
    ln2_3: Linear<B>,
}
18
19impl<B: Backend> Block<B> {
20 #[expect(clippy::let_and_return)]
21 pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
22 let x = self.bn2_1.forward(x);
23 let x = self.gu2_2.forward(x);
24 let x = self.ln2_3.forward(x);
25 x
26 }
27}
28
/// Configuration for a hidden [`Block`].
#[derive(Config, Debug)]
pub struct BlockConfig {
    /// Width of the block: the number of channels of its batch normalisation
    /// and the number of input and output features of its linear layer.
    pub fourier_features: NonZeroUsize,
}
33
34impl BlockConfig {
35 pub fn init<B: Backend>(&self, device: &B::Device) -> Block<B> {
36 Block {
37 bn2_1: BatchNormConfig::new(self.fourier_features.get()).init(device),
38 gu2_2: Gelu,
39 ln2_3: LinearConfig::new(self.fourier_features.get(), self.fourier_features.get())
40 .init(device),
41 }
42 }
43}
44
/// The network model: an input linear layer, a stack of hidden [`Block`]s,
/// then batch normalisation, a GELU activation and an output linear layer,
/// applied in that order by `Model::forward`.
#[derive(Debug, Module)]
pub struct Model<B: Backend> {
    // NOTE(review): the `Module` derive presumably uses these field names for
    // the generated (and serialized) record, so treat them as part of the
    // stored-model format — confirm before renaming or reordering.
    /// Input linear layer.
    ln1: Linear<B>,
    /// Hidden blocks, applied one after another.
    bl2: Vec<Block<B>>,
    /// Batch normalisation applied after the last hidden block.
    bn3: BatchNorm<B, 0>,
    /// GELU activation before the output layer.
    gu4: Gelu,
    /// Output linear layer.
    ln5: Linear<B>,
}
53
54impl<B: Backend> Model<B> {
55 #[expect(clippy::let_and_return)]
56 pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
57 let x = self.ln1.forward(x);
58
59 let mut x = x;
60 for block in &self.bl2 {
61 x = block.forward(x);
62 }
63
64 let x = self.bn3.forward(x);
65 let x = self.gu4.forward(x);
66 let x = self.ln5.forward(x);
67
68 x
69 }
70}
71
/// Configuration for the full [`Model`].
#[derive(Config, Debug)]
pub struct ModelConfig {
    /// Hidden width of the network. The input layer takes
    /// `2 * fourier_features` inputs (presumably a sine and a cosine per
    /// feature — confirm), and every hidden layer is `fourier_features` wide.
    pub fourier_features: NonZeroUsize,
    /// Block count. `ModelConfig::init` builds `num_blocks - 1` hidden
    /// [`Block`]s. NOTE(review): presumably the input layer counts as the
    /// first block — confirm before changing, as the count affects the
    /// stored model layout.
    pub num_blocks: NonZeroUsize,
}
77
78impl ModelConfig {
79 pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
80 let block = BlockConfig::new(self.fourier_features);
81
82 Model {
83 ln1: LinearConfig::new(self.fourier_features.get() * 2, self.fourier_features.get())
84 .init(device),
85 #[expect(clippy::useless_conversion)] bl2: (1..self.num_blocks.get())
87 .into_iter()
88 .map(|_| block.init(device))
89 .collect(),
90 bn3: BatchNormConfig::new(self.fourier_features.get()).init(device),
91 gu4: Gelu,
92 ln5: LinearConfig::new(self.fourier_features.get(), 1).init(device),
93 }
94 }
95}
96
/// A [`Model`] record bundled with additional tensors and the codec version,
/// converted to and from a serialisable `ModelExtraItem` through its
/// [`Record`] implementation.
pub struct ModelExtra<B: Backend> {
    /// Record (parameters and state) of the [`Model`].
    pub model: <Model<B> as Module<B>>::Record,
    /// Rank-2 parameter tensor. NOTE(review): presumably the transposed
    /// Fourier feature projection matrix — confirm against the encoder.
    pub b_t: Param<Tensor<B, 2, Float>>,
    /// Rank-1 parameter tensor. NOTE(review): presumably the mean used to
    /// standardise the data — confirm.
    pub mean: Param<Tensor<B, 1, Float>>,
    /// Rank-1 parameter tensor. NOTE(review): presumably the standard
    /// deviation used to standardise the data — confirm.
    pub stdv: Param<Tensor<B, 1, Float>>,
    /// Codec version stored with the record; it is copied unchanged through
    /// serialisation. NOTE(review): presumably checked for compatibility on load.
    pub version: crate::FourierNetworkCodecVersion,
}
104
105impl<B: Backend> Record<B> for ModelExtra<B> {
106 type Item<S: PrecisionSettings> = ModelExtraItem<B, S>;
107
108 fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
109 ModelExtraItem {
110 model: self.model.into_item(),
111 b_t: self.b_t.into_item(),
112 mean: self.mean.into_item(),
113 stdv: self.stdv.into_item(),
114 version: self.version,
115 }
116 }
117
118 fn from_item<S: PrecisionSettings>(item: Self::Item<S>, device: &B::Device) -> Self {
119 Self {
120 model: Record::<B>::from_item::<S>(item.model, device),
121 b_t: Record::<B>::from_item::<S>(item.b_t, device),
122 mean: Record::<B>::from_item::<S>(item.mean, device),
123 stdv: Record::<B>::from_item::<S>(item.stdv, device),
124 version: item.version,
125 }
126 }
127}
128
/// Serialisable item form of [`ModelExtra`] for precision settings `S`.
///
/// NOTE(review): field names and order are presumably part of the stored
/// format (positional serialisers depend on order) — keep them stable.
#[derive(serde::Serialize, serde::Deserialize)]
// `bound = ""` drops serde's inferred where-clauses on the type parameters
// `B` and `S`. The field types are `Record::Item<S>` projections, which are
// expected to be serialisable through the `Record` trait's own bounds.
#[serde(bound = "")]
pub struct ModelExtraItem<B: Backend, S: PrecisionSettings> {
    /// Item form of the [`Model`] record.
    model: <<Model<B> as Module<B>>::Record as Record<B>>::Item<S>,
    /// Item form of `ModelExtra::b_t`.
    b_t: <Param<Tensor<B, 2, Float>> as Record<B>>::Item<S>,
    /// Item form of `ModelExtra::mean`.
    mean: <Param<Tensor<B, 1, Float>> as Record<B>>::Item<S>,
    /// Item form of `ModelExtra::stdv`.
    stdv: <Param<Tensor<B, 1, Float>> as Record<B>>::Item<S>,
    /// Codec version, stored verbatim.
    version: crate::FourierNetworkCodecVersion,
}