Skip to content

Commit a97ce2e

Browse files
committed
get rid of unnecessary permutations
1 parent 8e69b17 commit a97ce2e

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

β€Žsrc/FourierLayer.jl

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,23 @@ model = FourierLayer(128, 128, 100, 16, Οƒ)
3737
"""
3838
struct FourierLayer{F,Tc<:Complex{<:AbstractFloat},N,Tr<:AbstractFloat,Bf,Bl}
3939
# F: Activation, Tc/Tr: Complex/Real eltype
40-
Wf::AbstractArray{Tc,N}
41-
Wl::AbstractMatrix{Tr}
40+
Wf::Array{Tc,N}
41+
Wl::Array{Tr,2}
4242
grid::Tuple
4343
Οƒ::F
4444
Ξ»::Tuple
4545
bf::Bf
4646
bl::Bl
4747
# Constructor for the entire fourier layer
4848
function FourierLayer(
49-
Wf::AbstractArray{Tc,N}, Wl::AbstractMatrix{Tr},
49+
Wf::Array{Tc,N}, Wl::Array{Tr,2},
5050
grid::Tuple,Οƒ::F = identity,
5151
Ξ»::Tuple = (12), bf = true, bl = true) where
5252
{F,Tc<:Complex{<:AbstractFloat},N,Tr<:AbstractFloat}
5353

5454
# create the biases with one singleton dimension for broadcasting
55-
bf = Flux.create_bias(Wf, bf, 1, size(Wf,2), grid...)
56-
bl = Flux.create_bias(Wl, bl, 1, size(Wl,1), grid...)
55+
bf = Flux.create_bias(Wf, bf, size(Wf,2), grid..., 1)
56+
bl = Flux.create_bias(Wl, bl, size(Wl,1), grid..., 1)
5757
new{F,Tc,N,Tr,typeof(bf),typeof(bl)}(Wf, Wl, grid, Οƒ, Ξ», bf, bl)
5858
end
5959
end
@@ -124,10 +124,10 @@ this is implemented as a generated function =#
124124
@generated function (a::FourierLayer)(x::AbstractArray{T,N}) where {T,N}
125125
#= Assign the parameters =#
126126
params = quote
127-
Wα΅© = a.Wf
128-
Wβ‚— = a.Wl
129-
bα΅© = a.bf
130-
bβ‚— = a.bl
127+
Wf = a.Wf
128+
Wl = a.Wl
129+
bf = a.bf
130+
bl = a.bl
131131
Οƒ = fast_act(a.Οƒ, x)
132132
end
133133

@@ -136,17 +136,15 @@ this is implemented as a generated function =#
136136
for the rest, it's more convenient to have it in the first one
137137
For this we need to generate the permutation tuple first
138138
experm evaluates to a tuple (N,1,2,...,N-1) =#
139-
experm = :(tuple(N,$:([k for k = 1:N-1]...)))
140-
permute = :(xβ‚š = permutedims(x, $experm))
141139

142140
#= The linear path
143141
x -> Wl
144142
As an argument to the einsum macro we need a list of named grid dimensions
145143
grids evaluates to a tuple of names of schema (grid_1, grid_2, ..., grid_N) =#
146144
grids = [Symbol("grid_$(i)") for i ∈ 1:N-2]
147-
linear_mul = :(@ein 𝔏[batch, out, $(grids...)] :=
148-
Wβ‚—[out, in] * xβ‚š[batch, in, $(grids...)])
149-
linear_bias = :(𝔏 .+ bβ‚—)
145+
linear_mul = :(@ein 𝔏[out, $(grids...), batch] :=
146+
Wl[out, in] * x[in, $(grids...), batch])
147+
linear_bias = :(𝔏 .+= bl)
150148

151149
#= The convolution path
152150
x -> 𝔉 -> Wf -> i𝔉
@@ -156,28 +154,26 @@ this is implemented as a generated function =#
156154
fourier_dims evaluates to a tuple of Ints with range 3:N since the grid dims
157155
are sequential up to the last dim of the input =#
158156
fourier_dims = :([n for n ∈ 3:N])
159-
fourier_mul = :(@ein 𝔉[batch, out, $(grids...)] :=
160-
Wα΅©[in, out, $(grids...)] * fft(xβ‚š, $(fourier_dims))[batch, in, $(grids...)])
161-
fourier_bias = :(𝔉 .+ bα΅©)
157+
fourier_mul = :(@ein 𝔉[out, $(grids...), batch] :=
158+
Wf[in, out, $(grids...)] * fft(x, $(fourier_dims))[in, $(grids...), batch])
159+
fourier_bias = :(𝔉 .+= bf)
162160

163161
#= Do the inverse transform
164162
We need to permute back to match the shape of the linear path =#
165163
fourier_inv = :(i𝔉 = ifft(𝔉, $(fourier_dims)))
166164

167165
#= Undo the initial permutation
168166
experm_inv evaluates to a tuple (2,3,...,N,1) =#
169-
experm_inv = :(tuple($:([k for k = 2:N]...),1))
170167

171168
return Expr(
172169
:block,
173170
params,
174-
permute,
175171
linear_mul,
176172
linear_bias,
177173
fourier_mul,
178174
fourier_bias,
179175
fourier_inv,
180-
:(return permutedims(Οƒ.(𝔏 + real(i𝔉)), $experm_inv))
176+
:(return Οƒ.(𝔏 + real(i𝔉)))
181177
)
182178
end
183179

0 commit comments

Comments
Β (0)