@@ -37,23 +37,23 @@ model = FourierLayer(128, 128, 100, 16, Ο)
37
37
"""
38
38
struct FourierLayer{F,Tc<:Complex{<:AbstractFloat},N,Tr<:AbstractFloat,Bf,Bl}
    # Type parameters
    #   F      : type of the activation function
    #   Tc, Tr : complex / real element types of the two weight arrays
    #   N      : number of dimensions of the complex weight array `Wf`
    #   Bf, Bl : types of the bias objects returned by `Flux.create_bias`
    Wf::Array{Tc,N}     # complex-valued weights
    Wl::Array{Tr,2}     # real-valued weight matrix
    grid::Tuple         # grid extents; splatted into the bias shapes below
    σ::F                # activation function
    λ::Tuple            # stored unchanged; presumably the retained modes — TODO confirm
    bf::Bf              # bias created from `Wf`
    bl::Bl              # bias created from `Wl`

    # Constructor for the entire fourier layer.
    #
    # `bf` / `bl` are handed to `Flux.create_bias` as the bias specification
    # and the results are stored in the fields of the same name.
    function FourierLayer(
            Wf::Array{Tc,N}, Wl::Array{Tr,2},
            grid::Tuple, σ::F = identity,
            # NOTE: `(12)` is just the integer 12 and does not satisfy `::Tuple`,
            # so relying on the default raised a MethodError. A one-element
            # tuple needs the trailing comma.
            λ::Tuple = (12,), bf = true, bl = true) where
            {F,Tc<:Complex{<:AbstractFloat},N,Tr<:AbstractFloat}

        # Create the biases with one trailing singleton dimension for
        # broadcasting: shape is (channels, grid..., 1).
        bf = Flux.create_bias(Wf, bf, size(Wf, 2), grid..., 1)
        bl = Flux.create_bias(Wl, bl, size(Wl, 1), grid..., 1)
        return new{F,Tc,N,Tr,typeof(bf),typeof(bl)}(Wf, Wl, grid, σ, λ, bf, bl)
    end
end
@@ -124,10 +124,10 @@ this is implemented as a generated function =#
124
124
@generated function (a:: FourierLayer )(x:: AbstractArray{T,N} ) where {T,N}
125
125
#= Assign the parameters =#
126
126
params = quote
127
- Wα΅© = a. Wf
128
- Wβ = a. Wl
129
- bα΅© = a. bf
130
- bβ = a. bl
127
+ Wf = a. Wf
128
+ Wl = a. Wl
129
+ bf = a. bf
130
+ bl = a. bl
131
131
Ο = fast_act (a. Ο, x)
132
132
end
133
133
@@ -136,17 +136,15 @@ this is implemented as a generated function =#
136
136
for the rest, it's more convenient to have it in the first one
137
137
For this we need to generate the permutation tuple first
138
138
experm evaluates to a tuple (N,1,2,...,N-1) =#
139
- experm = :(tuple (N,$ :([k for k = 1 : N- 1 ]. .. )))
140
- permute = :(xβ = permutedims (x, $ experm))
141
139
142
140
#= The linear path
143
141
x -> Wl
144
142
As an argument to the einsum macro we need a list of named grid dimensions
145
143
grids evaluates to a tuple of names of schema (grid_1, grid_2, ..., grid_N) =#
146
144
grids = [Symbol (" grid_$(i) " ) for i β 1 : N- 2 ]
147
- linear_mul = :(@ein π[batch, out, $ (grids... )] :=
148
- Wβ [out, in] * xβ[batch, in, $ (grids... )])
149
- linear_bias = :(π .+ bβ )
145
+ linear_mul = :(@ein π[out, $ (grids... ), batch ] :=
146
+ Wl [out, in] * x[ in, $ (grids... ), batch ])
147
+ linear_bias = :(π .+ = bl )
150
148
151
149
#= The convolution path
152
150
x -> π -> Wf -> iπ
@@ -156,28 +154,26 @@ this is implemented as a generated function =#
156
154
fourier_dims evaluates to a tuple of Ints with range 3:N since the grid dims
157
155
are sequential up to the last dim of the input =#
158
156
fourier_dims = :([n for n β 3 : N])
159
- fourier_mul = :(@ein π[batch, out, $ (grids... )] :=
160
- Wα΅© [in, out, $ (grids... )] * fft (xβ , $ (fourier_dims))[batch, in, $ (grids... )])
161
- fourier_bias = :(π .+ bα΅© )
157
+ fourier_mul = :(@ein π[out, $ (grids... ), batch ] :=
158
+ Wf [in, out, $ (grids... )] * fft (x , $ (fourier_dims))[in, $ (grids... ), batch ])
159
+ fourier_bias = :(π .+ = bf )
162
160
163
161
#= Do the inverse transform
164
162
We need to permute back to match the shape of the linear path =#
165
163
fourier_inv = :(iπ = ifft (π, $ (fourier_dims)))
166
164
167
165
#= Undo the initial permutation
168
166
experm_inv evaluates to a tuple (2,3,...,N,1) =#
169
- experm_inv = :(tuple ($ :([k for k = 2 : N]. .. ),1 ))
170
167
171
168
return Expr (
172
169
:block ,
173
170
params,
174
- permute,
175
171
linear_mul,
176
172
linear_bias,
177
173
fourier_mul,
178
174
fourier_bias,
179
175
fourier_inv,
180
- :(return permutedims ( Ο .(π + real (iπ)), $ experm_inv ))
176
+ :(return Ο .(π + real (iπ)))
181
177
)
182
178
end
183
179
0 commit comments