numcodecs_fourier_network/
modules.rs

1use std::num::NonZeroUsize;
2
3use burn::{
4    config::Config,
5    module::{Module, Param},
6    nn::{BatchNorm, BatchNormConfig, Gelu, Linear, LinearConfig},
7    prelude::Backend,
8    tensor::{Float, Tensor},
9};
10
11#[derive(Debug, Module)]
12pub struct Block<B: Backend> {
13    bn2_1: BatchNorm<B, 0>,
14    gu2_2: Gelu,
15    ln2_3: Linear<B>,
16}
17
18impl<B: Backend> Block<B> {
19    #[expect(clippy::let_and_return)]
20    pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
21        let x = self.bn2_1.forward(x);
22        let x = self.gu2_2.forward(x);
23        let x = self.ln2_3.forward(x);
24        x
25    }
26}
27
28#[derive(Config, Debug)]
29pub struct BlockConfig {
30    pub fourier_features: NonZeroUsize,
31}
32
33impl BlockConfig {
34    pub fn init<B: Backend>(&self, device: &B::Device) -> Block<B> {
35        Block {
36            bn2_1: BatchNormConfig::new(self.fourier_features.get()).init(device),
37            gu2_2: Gelu,
38            ln2_3: LinearConfig::new(self.fourier_features.get(), self.fourier_features.get())
39                .init(device),
40        }
41    }
42}
43
44#[derive(Debug, Module)]
45pub struct Model<B: Backend> {
46    ln1: Linear<B>,
47    bl2: Vec<Block<B>>,
48    bn3: BatchNorm<B, 0>,
49    gu4: Gelu,
50    ln5: Linear<B>,
51}
52
53impl<B: Backend> Model<B> {
54    #[expect(clippy::let_and_return)]
55    pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
56        let x = self.ln1.forward(x);
57
58        let mut x = x;
59        for block in &self.bl2 {
60            x = block.forward(x);
61        }
62
63        let x = self.bn3.forward(x);
64        let x = self.gu4.forward(x);
65        let x = self.ln5.forward(x);
66
67        x
68    }
69}
70
71#[derive(Config, Debug)]
72pub struct ModelConfig {
73    pub fourier_features: NonZeroUsize,
74    pub num_blocks: NonZeroUsize,
75}
76
77impl ModelConfig {
78    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
79        let block = BlockConfig::new(self.fourier_features);
80
81        Model {
82            ln1: LinearConfig::new(self.fourier_features.get() * 2, self.fourier_features.get())
83                .init(device),
84            #[expect(clippy::useless_conversion)] // (1..num_blocks).into_iter()
85            bl2: (1..self.num_blocks.get())
86                .into_iter()
87                .map(|_| block.init(device))
88                .collect(),
89            bn3: BatchNormConfig::new(self.fourier_features.get()).init(device),
90            gu4: Gelu,
91            ln5: LinearConfig::new(self.fourier_features.get(), 1).init(device),
92        }
93    }
94}
95
96#[derive(Debug, Module)]
97pub struct ModelExtra<B: Backend> {
98    pub model: Model<B>,
99    pub b_t: Param<Tensor<B, 2, Float>>,
100    pub mean: Param<Tensor<B, 1, Float>>,
101    pub stdv: Param<Tensor<B, 1, Float>>,
102}