Skip to content

Commit

Permalink
Merge pull request #103 from SciML/optjlintegration
Browse files Browse the repository at this point in the history
Single closure definition to avoid overwriting in manual supplied derivatives cases
  • Loading branch information
Vaibhavdixit02 committed Sep 12, 2024
2 parents b2cbc1a + f78872b commit ba5e519
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 128 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationBase"
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com> and contributors"]
version = "2.0.2"
version = "2.0.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
return res_vjp
end
elseif cons_vjp == true && cons !== nothing
cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p)
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
else
cons_vjp! = nothing
end
Expand Down
54 changes: 12 additions & 42 deletions ext/OptimizationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif g == true
grad = (G, θ) -> f.grad(G, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
grad = (G, θ, p) -> f.grad(G, θ, p)
end
grad = (G, θ, p = p) -> f.grad(G, θ, p)
else
grad = nothing
end
Expand All @@ -67,10 +64,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif fg == true
fg! = (G, θ) -> f.fg(G, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fg! = (G, θ, p) -> f.fg(G, θ, p)
end
fg! = (G, θ, p = p) -> f.fg(G, θ, p)
else
fg! = nothing
end
Expand All @@ -89,10 +83,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif h == true
hess = (H, θ) -> f.hess(H, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hess = (H, θ, p) -> f.hess(H, θ, p)
end
hess = (H, θ, p = p) -> f.hess(H, θ, p)
else
hess = nothing
end
Expand All @@ -110,10 +101,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif fgh == true
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p)
end
fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p)
else
fgh! = nothing
end
Expand All @@ -130,10 +118,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif hv == true
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hv! = (H, θ, v, p) -> f.hv(H, θ, v, p)
end
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
else
hv! = nothing
end
Expand Down Expand Up @@ -268,7 +253,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif cons !== nothing && lag_h == true
lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p)
else
lag_h! = nothing
end
Expand Down Expand Up @@ -324,10 +309,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif g == true
grad = (G, θ) -> f.grad(G, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
grad = (G, θ, p) -> f.grad(G, θ, p)
end
grad = (G, θ, p = p) -> f.grad(G, θ, p)
else
grad = nothing
end
Expand All @@ -348,10 +330,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif fg == true
fg! = (G, θ) -> f.fg(G, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fg! = (G, θ, p) -> f.fg(G, θ, p)
end
fg! = (G, θ, p = p) -> f.fg(G, θ, p)
else
fg! = nothing
end
Expand All @@ -373,10 +352,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif h == true
hess = (H, θ) -> f.hess(H, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hess = (H, θ, p) -> f.hess(H, θ, p)
end
hess = (H, θ, p = p) -> f.hess(H, θ, p)
else
hess = nothing
end
Expand All @@ -395,10 +371,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif fgh == true
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p)
end
fgh!(G, H, θ, p = p) = f.fgh(G, H, θ, p)
else
fgh! = nothing
end
Expand All @@ -415,10 +388,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif hv == true
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hv! = (H, θ, v, p) -> f.hv(H, θ, v, p)
end
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
else
hv! = nothing
end
Expand Down Expand Up @@ -564,7 +534,7 @@ function OptimizationBase.instantiate_function(
end
end
elseif cons !== nothing && cons_h == true
lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p)
else
lag_h! = nothing
end
Expand Down
54 changes: 12 additions & 42 deletions src/OptimizationDIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ function instantiate_function(
end
end
elseif g == true
grad = (G, θ) -> f.grad(G, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
grad = (G, θ, p) -> f.grad(G, θ, p)
end
grad = (G, θ, p = p) -> f.grad(G, θ, p)
else
grad = nothing
end
Expand All @@ -74,10 +71,7 @@ function instantiate_function(
end
end
elseif fg == true
fg! = (G, θ) -> f.fg(G, θ, p)
if p !== SciMLBase.NullParameters()
fg! = (G, θ, p) -> f.fg(G, θ, p)
end
fg! = (G, θ, p = p) -> f.fg(G, θ, p)
else
fg! = nothing
end
Expand All @@ -96,10 +90,7 @@ function instantiate_function(
end
end
elseif h == true
hess = (H, θ) -> f.hess(H, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hess = (H, θ, p) -> f.hess(H, θ, p)
end
hess = (H, θ, p = p) -> f.hess(H, θ, p)
else
hess = nothing
end
Expand All @@ -119,10 +110,7 @@ function instantiate_function(
end
end
elseif fgh == true
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p)
end
fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p)
else
fgh! = nothing
end
Expand All @@ -139,10 +127,7 @@ function instantiate_function(
end
end
elseif hv == true
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hv! = (H, θ, v, p) -> f.hv(H, θ, v, p)
end
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
else
hv! = nothing
end
Expand Down Expand Up @@ -277,7 +262,7 @@ function instantiate_function(
end
end
elseif lag_h == true && cons !== nothing
lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p)
else
lag_h! = nothing
end
Expand Down Expand Up @@ -334,10 +319,7 @@ function instantiate_function(
end
end
elseif g == true
grad = (θ) -> f.grad(θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
grad = (θ, p) -> f.grad(θ, p)
end
grad = (θ, p = p) -> f.grad(θ, p)
else
grad = nothing
end
Expand All @@ -358,10 +340,7 @@ function instantiate_function(
end
end
elseif fg == true
fg! = (θ) -> f.fg(θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fg! = (θ, p) -> f.fg(θ, p)
end
fg! = (θ, p = p) -> f.fg(θ, p)
else
fg! = nothing
end
Expand All @@ -380,10 +359,7 @@ function instantiate_function(
end
end
elseif h == true
hess = (θ) -> f.hess(θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hess = (θ, p) -> f.hess(θ, p)
end
hess = (θ, p = p) -> f.hess(θ, p)
else
hess = nothing
end
Expand All @@ -401,10 +377,7 @@ function instantiate_function(
end
end
elseif fgh == true
fgh! = (θ) -> f.fgh(θ, p)
if p !== SciMLBase.NullParameters() && p !== nothing
fgh! = (θ, p) -> f.fgh(θ, p)
end
fgh! = (θ, p = p) -> f.fgh(θ, p)
else
fgh! = nothing
end
Expand All @@ -421,10 +394,7 @@ function instantiate_function(
end
end
elseif hv == true
hv! = (θ, v) -> f.hv(θ, v, p)
if p !== SciMLBase.NullParameters() && p !== nothing
hv! = (θ, v, p) -> f.hv(θ, v, p)
end
hv! = (θ, v, p = p) -> f.hv(θ, v, p)
else
hv! = nothing
end
Expand Down Expand Up @@ -530,7 +500,7 @@ function instantiate_function(
end
end
elseif lag_h == true && cons !== nothing
lag_h! = (θ, σ, λ) -> f.lag_h(θ, σ, λ, p)
lag_h! = (θ, σ, λ, p = p) -> f.lag_h(θ, σ, λ, p)
else
lag_h! = nothing
end
Expand Down
Loading

0 comments on commit ba5e519

Please sign in to comment.