numcodecs_fourier_network/
modules.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
use std::num::NonZeroUsize;

use burn::{
    config::Config,
    module::{Module, Param},
    nn::{BatchNorm, BatchNormConfig, Gelu, Linear, LinearConfig},
    prelude::Backend,
    tensor::{Float, Tensor},
};

/// A hidden block of the network.
///
/// [`Block::forward`] applies, in field order, batch normalisation, a GELU
/// activation, and a linear layer.
///
/// NOTE(review): field names likely double as record keys when weights are
/// saved/loaded through burn's `Module` records — rename with care.
#[derive(Debug, Module)]
pub struct Block<B: Backend> {
    /// Batch normalisation over the block's input features.
    bn2_1: BatchNorm<B, 0>,
    /// GELU activation applied after the batch normalisation.
    gu2_2: Gelu,
    /// Linear layer producing the block's output (built as
    /// `fourier_features -> fourier_features` by `BlockConfig::init`).
    ln2_3: Linear<B>,
}

impl<B: Backend> Block<B> {
    /// Runs the block on a 2D input tensor.
    ///
    /// The input is batch-normalised, passed through GELU, and finally mapped
    /// by the block's linear layer, whose result is returned.
    pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
        let normalised = self.bn2_1.forward(x);
        let activated = self.gu2_2.forward(normalised);
        self.ln2_3.forward(activated)
    }
}

/// Configuration for a [`Block`].
#[derive(Config, Debug)]
pub struct BlockConfig {
    /// Width of the block: used as the batch-norm feature count and as both
    /// the input and output size of the block's linear layer.
    pub fourier_features: NonZeroUsize,
}

impl BlockConfig {
    /// Builds a [`Block`] on `device`.
    ///
    /// Every layer in the block is `fourier_features` wide: the batch norm
    /// normalises that many features and the linear layer maps
    /// `fourier_features -> fourier_features`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Block<B> {
        let width = self.fourier_features.get();

        // Layers are constructed in field order (batch norm first, then the
        // linear layer) so any random initialisation on `device` happens in
        // a stable sequence.
        let bn2_1 = BatchNormConfig::new(width).init(device);
        let ln2_3 = LinearConfig::new(width, width).init(device);

        Block {
            bn2_1,
            gu2_2: Gelu,
            ln2_3,
        }
    }
}

/// The full network.
///
/// [`Model::forward`] applies `ln1`, then every block in `bl2` in order, then
/// the tail `bn3` -> `gu4` -> `ln5`. The tail has the same batch-norm / GELU /
/// linear layout as a [`Block`], but its linear layer produces one output.
///
/// NOTE(review): field names likely double as record keys when weights are
/// saved/loaded through burn's `Module` records — rename with care.
#[derive(Debug, Module)]
pub struct Model<B: Backend> {
    /// Input projection; `ModelConfig::init` builds it as
    /// `2 * fourier_features -> fourier_features`.
    ln1: Linear<B>,
    /// Hidden blocks; `ModelConfig::init` creates `num_blocks - 1` of them.
    bl2: Vec<Block<B>>,
    /// Batch normalisation of the output tail.
    bn3: BatchNorm<B, 0>,
    /// GELU activation of the output tail.
    gu4: Gelu,
    /// Output projection; `ModelConfig::init` builds it as
    /// `fourier_features -> 1`.
    ln5: Linear<B>,
}

impl<B: Backend> Model<B> {
    /// Runs the network on a 2D input tensor and returns a 2D output tensor.
    ///
    /// The input passes through the input projection `ln1`, then through each
    /// hidden block in order, and finally through the tail (batch norm, GELU,
    /// `ln5`). For a model built by `ModelConfig::init`, the input's last
    /// dimension must be `2 * fourier_features` and the output's last
    /// dimension is 1.
    pub fn forward(&self, x: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
        let projected = self.ln1.forward(x);

        // Thread the activations through every hidden block, in order.
        let hidden = self
            .bl2
            .iter()
            .fold(projected, |acc, block| block.forward(acc));

        self.ln5.forward(self.gu4.forward(self.bn3.forward(hidden)))
    }
}

/// Configuration for a [`Model`].
#[derive(Config, Debug)]
pub struct ModelConfig {
    /// Hidden width of the network; the model's input projection expects
    /// twice this many features.
    pub fourier_features: NonZeroUsize,
    /// Block count: `ModelConfig::init` creates `num_blocks - 1` hidden
    /// [`Block`]s. NOTE(review): the output tail (same batch-norm / GELU /
    /// linear layout as a block) presumably counts as the remaining one.
    pub num_blocks: NonZeroUsize,
}

impl ModelConfig {
    /// Builds a [`Model`] on `device`.
    ///
    /// The model consists of:
    /// - an input projection `2 * fourier_features -> fourier_features`;
    /// - `num_blocks - 1` hidden [`Block`]s of width `fourier_features`;
    /// - an output tail (batch norm, GELU, `fourier_features -> 1` linear).
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        let width = self.fourier_features.get();
        // `num_blocks` is a `NonZeroUsize`, so this subtraction cannot
        // underflow.
        let hidden_blocks = self.num_blocks.get() - 1;
        let block_config = BlockConfig::new(self.fourier_features);

        // Fields are initialised in declaration order; keep it that way so
        // layer construction (including any random initialisation on
        // `device`) happens in a stable sequence.
        Model {
            ln1: LinearConfig::new(2 * width, width).init(device),
            bl2: std::iter::repeat_with(|| block_config.init(device))
                .take(hidden_blocks)
                .collect(),
            bn3: BatchNormConfig::new(width).init(device),
            gu4: Gelu,
            ln5: LinearConfig::new(width, 1).init(device),
        }
    }
}

/// A [`Model`] bundled with extra tensor parameters stored alongside it.
#[derive(Debug, Module)]
pub struct ModelExtra<B: Backend> {
    /// The network itself.
    pub model: Model<B>,
    /// 2D parameter tensor. NOTE(review): the name suggests a transposed
    /// Fourier-feature projection matrix (`B^T`) — confirm against its use.
    pub b_t: Param<Tensor<B, 2, Float>>,
    /// 1D parameter tensor. NOTE(review): presumably a mean used to
    /// (de)normalise data — confirm against its use.
    pub mean: Param<Tensor<B, 1, Float>>,
    /// 1D parameter tensor. NOTE(review): presumably a standard deviation
    /// paired with `mean` — confirm against its use.
    pub stdv: Param<Tensor<B, 1, Float>>,
}