Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 59d69cd

Browse files
committed
Prototype SimpleJNFK
1 parent 1600f7c commit 59d69cd

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.20"
4+
version = "0.1.21"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1213
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1314
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1415
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1516
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
17+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1618
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1719

1820
[weakdeps]

src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ module SimpleNonlinearSolve
22

33
using Reexport
44
using FiniteDiff, ForwardDiff
5-
using ForwardDiff: Dual
5+
using ForwardDiff: Dual, Partials, Tag
66
using StaticArraysCore
77
using LinearAlgebra
8+
using LinearSolve
89
import ArrayInterface
910
using DiffEqBase
1011

@@ -39,6 +40,7 @@ include("ad.jl")
3940
include("halley.jl")
4041
include("alefeld.jl")
4142
include("itp.jl")
43+
include("jnfk.jl")
4244

4345
# Batched Solver Support
4446
include("batched/utils.jl")
@@ -77,7 +79,7 @@ end
7779

7880
# DiffEq styled algorithms
7981
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
80-
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP
82+
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld, ITP, SimpleJFNK
8183
export BatchedBroyden, BatchedSimpleNewtonRaphson, BatchedSimpleDFSane
8284

8385
end # module

src/jnfk.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
struct SimpleJNFKJacVecTag end
2+
3+
function jvp_forwarddiff(f, x::AbstractArray{T}, v) where {T}
4+
v_ = reshape(v, axes(x))
5+
y = (Dual{Tag{SimpleJNFKJacVecTag, T}, T, 1}).(x, Partials.(tuple.(v_)))
6+
return vec(ForwardDiff.partials.(vec(f(y)), 1))
7+
end
8+
9+
struct JacVecOperator{F, X}
10+
f::F
11+
x::X
12+
end
13+
14+
(jvp::JacVecOperator)(v, _, _) = jvp_forwarddiff(jvp.f, jvp.x, v)
15+
16+
"""
17+
SimpleJNFK()
18+
19+
"""
20+
struct SimpleJFNK end
21+
22+
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleJFNK, args...;
23+
abstol = nothing, reltol= nothing, maxiters = 1000, linsolve_kwargs = (;), kwargs...)
24+
iip = SciMLBase.isinplace(prob)
25+
@assert !iip "SimpleJFNK does not support inplace problems"
26+
27+
f = Base.Fix2(prob.f, prob.p)
28+
x = float(prob.u0)
29+
fx = f(x)
30+
T = typeof(x)
31+
32+
atol = abstol !== nothing ? abstol :
33+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
34+
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
35+
36+
op = FunctionOperator(JacVecOperator(f, x), x)
37+
linprob = LinearProblem(op, -fx)
38+
lincache = init(linprob, SimpleGMRES(); abstol, reltol, maxiters, linsolve_kwargs...)
39+
40+
for i in 1:maxiters
41+
iszero(fx) &&
42+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
43+
44+
linsol = solve!(lincache)
45+
x .-= linsol.u
46+
lincache = linsol.cache
47+
48+
# FIXME: not nothing
49+
if isapprox(x, nothing; atol, rtol)
50+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
51+
end
52+
end
53+
54+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
55+
end

0 commit comments

Comments
 (0)