@@ -35,26 +35,26 @@ So we would have:
35
35
model = FourierLayer(128, 128, 100, 16, Ο)
36
36
```
37
37
"""
38
- struct FourierLayer{F,Tc <: Complex{<:AbstractFloat} ,N,Tr <: AbstractFloat ,Bf,Bl}
38
+ struct FourierLayer{F,wf,wl ,Bf,Bl}
39
39
# F: Activation, Tc/Tr: Complex/Real eltype
40
- Wf:: Array{Tc,N}
41
- Wl:: Array{Tr,2}
40
+ Wf:: wf
41
+ Wl:: wl
42
42
grid:: Tuple
43
43
Ο:: F
44
44
Ξ»:: Tuple
45
45
bf:: Bf
46
46
bl:: Bl
47
47
# Constructor for the entire fourier layer
48
48
function FourierLayer (
49
- Wf:: Array{Tc,N} , Wl:: Array{Tr,2} ,
49
+ Wf:: wf , Wl:: wl ,
50
50
grid:: Tuple ,Ο:: F = identity,
51
51
Ξ»:: Tuple = (12 ), bf = true , bl = true ) where
52
- {F,Tc <: Complex{<:AbstractFloat} ,N,Tr <: AbstractFloat }
52
+ {F,wf,wl }
53
53
54
54
# create the biases with one singleton dimension for broadcasting
55
55
bf = Flux. create_bias (Wf, bf, size (Wf,2 ), grid... , 1 )
56
56
bl = Flux. create_bias (Wl, bl, size (Wl,1 ), grid... , 1 )
57
- new {F,Tc,N,Tr ,typeof(bf),typeof(bl)} (Wf, Wl, grid, Ο, Ξ», bf, bl)
57
+ new {F,wf,wl ,typeof(bf),typeof(bl)} (Wf, Wl, grid, Ο, Ξ», bf, bl)
58
58
end
59
59
end
60
60
@@ -131,20 +131,12 @@ this is implemented as a generated function =#
131
131
Ο = fast_act (a. Ο, x)
132
132
end
133
133
134
- #= Do a permutation
135
- DataLoader requires batch to be the last dim
136
- for the rest, it's more convenient to have it in the first one
137
- For this we need to generate the permutation tuple first
138
- experm evaluates to a tuple (N,1,2,...,N-1) =#
139
-
140
134
#= The linear path
141
135
x -> Wl
142
136
As an argument to the einsum macro we need a list of named grid dimensions
143
137
grids evaluates to a tuple of names of schema (grid_1, grid_2, ..., grid_N) =#
144
138
grids = [Symbol (" grid_$(i) " ) for i β 1 : N- 2 ]
145
- linear_mul = :(@ein π[out, $ (grids... ), batch] :=
146
- Wl[out, in] * x[in, $ (grids... ), batch])
147
- linear_bias = :(π .+ = bl)
139
+ linear_mul = :(π = Wl β‘ x .+ bl)
148
140
149
141
#= The convolution path
150
142
x -> π -> Wf -> iπ
@@ -162,18 +154,14 @@ this is implemented as a generated function =#
162
154
We need to permute back to match the shape of the linear path =#
163
155
fourier_inv = :(iπ = ifft (π, $ (fourier_dims)))
164
156
165
- #= Undo the initial permutation
166
- experm_inv evaluates to a tuple (2,3,...,N,1) =#
167
-
168
157
return Expr (
169
158
:block ,
170
159
params,
171
160
linear_mul,
172
- linear_bias,
173
161
fourier_mul,
174
- fourier_bias,
162
+ # fourier_bias,
175
163
fourier_inv,
176
- :(return Ο .(π + real (iπ)))
164
+ :(return Ο .(π + real . (iπ)))
177
165
)
178
166
end
179
167
0 commit comments