Commit 2e9d4f0

use TensorCore's boxdot in the linear path
1 parent a97ce2e commit 2e9d4f0
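
For context, a minimal sketch (not part of the commit) of the contraction that TensorCore's boxdot operator ⊡ performs in the new linear path, mirroring the expression 𝔏 = Wl ⊡ x .+ bl introduced in src/FourierLayer.jl below. All array sizes are illustrative only.

    using TensorCore   # provides ⊡ (boxdot)

    # illustrative sizes only
    Wl = rand(Float32, 8, 4)             # (out, in) linear weights
    x  = rand(Float32, 4, 16, 16, 10)    # (in, grid_1, grid_2, batch) input
    bl = zeros(Float32, 8, 16, 16, 1)    # bias with a trailing singleton batch dim

    # ⊡ contracts the last dimension of Wl with the first dimension of x,
    # so Wl ⊡ x has size (out, grid_1, grid_2, batch)
    𝔏 = Wl ⊡ x .+ bl
    size(𝔏)                              # (8, 16, 16, 10)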

3 files changed: 11 additions & 22 deletions


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -10,13 +10,13 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
 
 [compat]
 CUDA = "3"
 FFTW = "1"
 Flux = "0.12"
 NNlib = "0.8"
-OMEinsum = "0.6"
 julia = "1.6"
 
 [extras]

src/FourierLayer.jl

Lines changed: 9 additions & 21 deletions
@@ -35,26 +35,26 @@ So we would have:
 model = FourierLayer(128, 128, 100, 16, σ)
 ```
 """
-struct FourierLayer{F,Tc<:Complex{<:AbstractFloat},N,Tr<:AbstractFloat,Bf,Bl}
+struct FourierLayer{F,wf,wl,Bf,Bl}
     # F: Activation, Tc/Tr: Complex/Real eltype
-    Wf::Array{Tc,N}
-    Wl::Array{Tr,2}
+    Wf::wf
+    Wl::wl
     grid::Tuple
     σ::F
     λ::Tuple
     bf::Bf
     bl::Bl
     # Constructor for the entire fourier layer
     function FourierLayer(
-        Wf::Array{Tc,N}, Wl::Array{Tr,2},
+        Wf::wf, Wl::wl,
         grid::Tuple,σ::F = identity,
         λ::Tuple = (12), bf = true, bl = true) where
-        {F,Tc<:Complex{<:AbstractFloat},N,Tr<:AbstractFloat}
+        {F,wf,wl}
 
         # create the biases with one singleton dimension for broadcasting
        bf = Flux.create_bias(Wf, bf, size(Wf,2), grid..., 1)
        bl = Flux.create_bias(Wl, bl, size(Wl,1), grid..., 1)
-       new{F,Tc,N,Tr,typeof(bf),typeof(bl)}(Wf, Wl, grid, σ, λ, bf, bl)
+       new{F,wf,wl,typeof(bf),typeof(bl)}(Wf, Wl, grid, σ, λ, bf, bl)
     end
 end

@@ -131,20 +131,12 @@ this is implemented as a generated function =#
         σ = fast_act(a.σ, x)
     end
 
-    #= Do a permutation
-    DataLoader requires batch to be the last dim
-    for the rest, it's more convenient to have it in the first one
-    For this we need to generate the permutation tuple first
-    experm evaluates to a tuple (N,1,2,...,N-1) =#
-
     #= The linear path
     x -> Wl
     As an argument to the einsum macro we need a list of named grid dimensions
    grids evaluates to a tuple of names of schema (grid_1, grid_2, ..., grid_N) =#
     grids = [Symbol("grid_$(i)") for i ∈ 1:N-2]
-    linear_mul = :(@ein 𝔏[out, $(grids...), batch] :=
-        Wl[out, in] * x[in, $(grids...), batch])
-    linear_bias = :(𝔏 .+= bl)
+    linear_mul = :(𝔏 = Wl ⊡ x .+ bl)
 
     #= The convolution path
     x -> 𝔉 -> Wf -> i𝔉
@@ -162,18 +154,14 @@ this is implemented as a generated function =#
     We need to permute back to match the shape of the linear path =#
     fourier_inv = :(i𝔉 = ifft(𝔉, $(fourier_dims)))
 
-    #= Undo the initial permutation
-    experm_inv evaluates to a tuple (2,3,...,N,1) =#
-
     return Expr(
         :block,
         params,
         linear_mul,
-        linear_bias,
         fourier_mul,
-        fourier_bias,
+        #fourier_bias,
         fourier_inv,
-        :(return σ.(𝔏 + real(i𝔉)))
+        :(return σ.(𝔏 + real.(i𝔉)))
     )
 end

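A quick equivalence check (again not part of the commit, with illustrative sizes): the boxdot expression computes the same contraction over the shared in index as the @ein call it replaces.

    using OMEinsum, TensorCore

    Wl = rand(Float32, 8, 4)          # (out, in)
    x  = rand(Float32, 4, 16, 10)     # (in, grid_1, batch)

    # removed path: explicit einsum over the shared index
    @ein y_ein[out, grid_1, batch] := Wl[out, in] * x[in, grid_1, batch]

    # new path: ⊡ contracts the last dimension of Wl with the first dimension of x
    y_box = Wl ⊡ x

    y_ein ≈ y_box                     # true, up to floating-point roundoff

Folding the bias into the same broadcast (Wl ⊡ x .+ bl) is also why the separate linear_bias expression is removed from the generated block.
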
src/OperatorLearning.jl

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ using Random
 using Random: AbstractRNG
 using Flux: nfan, glorot_uniform, batch
 using OMEinsum
+using TensorCore
 using NNlib: fast_act
 
 export FourierLayer, DeepONet
