// numcodecs_fourier_network/modules.rs
1use std::num::NonZeroUsize;
2
3use burn::{
4 config::Config,
5 module::{Module, Param},
6 nn::{BatchNorm, BatchNormConfig, Gelu, Linear, LinearConfig},
7 prelude::Backend,
8 tensor::{Float, Tensor},
9};
10
/// One hidden stage of the Fourier network: batch norm, then GELU, then a
/// linear layer (see `Block::forward` for the exact order).
///
/// The numeric suffixes in the field names appear to encode the layer's
/// position inside the model (stage 2, sub-layers 1..3) — NOTE(review):
/// field names are part of the derived `Module` record, so they must not be
/// renamed or reordered without migrating saved weights.
#[derive(Debug, Module)]
pub struct Block<B: Backend> {
    // `D = 0`: no spatial dimensions, matching the rank-2 input of
    // `forward` — presumably normalises each feature column over the batch.
    bn2_1: BatchNorm<B, 0>,
    gu2_2: Gelu,
    // Square layer (`fourier_features -> fourier_features`) when built by
    // `BlockConfig::init`.
    ln2_3: Linear<B>,
}
17
18impl<B: Backend> Block<B> {
19 #[expect(clippy::let_and_return)]
20 pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
21 let x = self.bn2_1.forward(x);
22 let x = self.gu2_2.forward(x);
23 let x = self.ln2_3.forward(x);
24 x
25 }
26}
27
/// Configuration for a single hidden [`Block`].
#[derive(Config, Debug)]
pub struct BlockConfig {
    /// Width of every layer in the block (batch-norm channels and both the
    /// input and output size of the linear layer).
    pub fourier_features: NonZeroUsize,
}
32
33impl BlockConfig {
34 pub fn init<B: Backend>(&self, device: &B::Device) -> Block<B> {
35 Block {
36 bn2_1: BatchNormConfig::new(self.fourier_features.get()).init(device),
37 gu2_2: Gelu,
38 ln2_3: LinearConfig::new(self.fourier_features.get(), self.fourier_features.get())
39 .init(device),
40 }
41 }
42}
43
/// The Fourier-feature network: an input projection, a stack of hidden
/// [`Block`]s, and an output head.
///
/// Field suffixes follow the order in which `Model::forward` applies the
/// layers (1 = input projection, 2 = hidden blocks, 3..5 = output head).
/// NOTE(review): field names are part of the derived `Module` record and must
/// stay stable for saved weights to load.
#[derive(Debug, Module)]
pub struct Model<B: Backend> {
    // Input projection; `2 * fourier_features -> fourier_features` when built
    // by `ModelConfig::init`.
    ln1: Linear<B>,
    // Hidden blocks, applied in order.
    bl2: Vec<Block<B>>,
    // Output head: normalise, activate, then project (to one column when built
    // by `ModelConfig::init`).
    bn3: BatchNorm<B, 0>,
    gu4: Gelu,
    ln5: Linear<B>,
}
52
53impl<B: Backend> Model<B> {
54 #[expect(clippy::let_and_return)]
55 pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
56 let x = self.ln1.forward(x);
57
58 let mut x = x;
59 for block in &self.bl2 {
60 x = block.forward(x);
61 }
62
63 let x = self.bn3.forward(x);
64 let x = self.gu4.forward(x);
65 let x = self.ln5.forward(x);
66
67 x
68 }
69}
70
/// Configuration for a [`Model`].
#[derive(Config, Debug)]
pub struct ModelConfig {
    /// Hidden width of the network; the input projection takes twice this
    /// many features.
    pub fourier_features: NonZeroUsize,
    /// Number of blocks, counting the output head: `ModelConfig::init` builds
    /// `num_blocks - 1` hidden blocks (so `1` means no hidden blocks).
    pub num_blocks: NonZeroUsize,
}
76
77impl ModelConfig {
78 pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
79 let block = BlockConfig::new(self.fourier_features);
80
81 Model {
82 ln1: LinearConfig::new(self.fourier_features.get() * 2, self.fourier_features.get())
83 .init(device),
84 #[expect(clippy::useless_conversion)] bl2: (1..self.num_blocks.get())
86 .into_iter()
87 .map(|_| block.init(device))
88 .collect(),
89 bn3: BatchNormConfig::new(self.fourier_features.get()).init(device),
90 gu4: Gelu,
91 ln5: LinearConfig::new(self.fourier_features.get(), 1).init(device),
92 }
93 }
94}
95
/// A [`Model`] bundled with extra learned/stored tensors that travel with it
/// as part of the same `Module` record.
///
/// NOTE(review): the meanings below are inferred from the field names only —
/// confirm against the code that constructs and consumes this struct.
#[derive(Debug, Module)]
pub struct ModelExtra<B: Backend> {
    /// The network itself.
    pub model: Model<B>,
    /// Rank-2 tensor; presumably the (transposed) Fourier feature projection
    /// matrix `B` used to encode inputs — TODO confirm.
    pub b_t: Param<Tensor<B, 2, Float>>,
    /// Rank-1 tensor; presumably a mean used for (de)normalisation — TODO
    /// confirm.
    pub mean: Param<Tensor<B, 1, Float>>,
    /// Rank-1 tensor; presumably a standard deviation used for
    /// (de)normalisation — TODO confirm.
    pub stdv: Param<Tensor<B, 1, Float>>,
}