From f78872b3e426c4c834ec605301b0a9d8ee77d3d8 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Thu, 12 Sep 2024 11:36:11 -0400 Subject: [PATCH] go back to closure but single definition --- Project.toml | 2 +- ext/OptimizationEnzymeExt.jl | 2 +- ext/OptimizationZygoteExt.jl | 54 ++++++++-------------------------- src/OptimizationDIExt.jl | 54 ++++++++-------------------------- src/OptimizationDISparseExt.jl | 54 ++++++++-------------------------- 5 files changed, 38 insertions(+), 128 deletions(-) diff --git a/Project.toml b/Project.toml index 3f39d0b..46b0c6b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimizationBase" uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb" authors = ["Vaibhav Dixit and contributors"] -version = "2.0.2" +version = "2.0.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/OptimizationEnzymeExt.jl b/ext/OptimizationEnzymeExt.jl index 5bc6c2e..35e223a 100644 --- a/ext/OptimizationEnzymeExt.jl +++ b/ext/OptimizationEnzymeExt.jl @@ -563,7 +563,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x return res_vjp end elseif cons_vjp == true && cons !== nothing - cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p) + cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p) else cons_vjp! = nothing end diff --git a/ext/OptimizationZygoteExt.jl b/ext/OptimizationZygoteExt.jl index d830d3a..138d604 100644 --- a/ext/OptimizationZygoteExt.jl +++ b/ext/OptimizationZygoteExt.jl @@ -43,10 +43,7 @@ function OptimizationBase.instantiate_function( end end elseif g == true - grad = (G, θ) -> f.grad(G, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - grad = (G, θ, p) -> f.grad(G, θ, p) - end + grad = (G, θ, p = p) -> f.grad(G, θ, p) else grad = nothing end @@ -67,10 +64,7 @@ function OptimizationBase.instantiate_function( end end elseif fg == true - fg! = (G, θ) -> f.fg(G, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fg! = (G, θ, p) -> f.fg(G, θ, p) - end + fg! = (G, θ, p = p) -> f.fg(G, θ, p) else fg! = nothing end @@ -89,10 +83,7 @@ function OptimizationBase.instantiate_function( end end elseif h == true - hess = (H, θ) -> f.hess(H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hess = (H, θ, p) -> f.hess(H, θ, p) - end + hess = (H, θ, p = p) -> f.hess(H, θ, p) else hess = nothing end @@ -110,10 +101,7 @@ function OptimizationBase.instantiate_function( end end elseif fgh == true - fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p) - end + fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p) else fgh! = nothing end @@ -130,10 +118,7 @@ function OptimizationBase.instantiate_function( end end elseif hv == true - hv! = (H, θ, v) -> f.hv(H, θ, v, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hv! = (H, θ, v, p) -> f.hv(H, θ, v, p) - end + hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p) else hv! = nothing end @@ -268,7 +253,7 @@ function OptimizationBase.instantiate_function( end end elseif cons !== nothing && lag_h == true - lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p) else lag_h! = nothing end @@ -324,10 +309,7 @@ function OptimizationBase.instantiate_function( end end elseif g == true - grad = (G, θ) -> f.grad(G, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - grad = (G, θ, p) -> f.grad(G, θ, p) - end + grad = (G, θ, p = p) -> f.grad(G, θ, p) else grad = nothing end @@ -348,10 +330,7 @@ function OptimizationBase.instantiate_function( end end elseif fg == true - fg! = (G, θ) -> f.fg(G, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fg! = (G, θ, p) -> f.fg(G, θ, p) - end + fg! = (G, θ, p = p) -> f.fg(G, θ, p) else fg! = nothing end @@ -373,10 +352,7 @@ function OptimizationBase.instantiate_function( end end elseif h == true - hess = (H, θ) -> f.hess(H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hess = (H, θ, p) -> f.hess(H, θ, p) - end + hess = (H, θ, p = p) -> f.hess(H, θ, p) else hess = nothing end @@ -395,10 +371,7 @@ function OptimizationBase.instantiate_function( end end elseif fgh == true - fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p) - end + fgh!(G, H, θ, p = p) = f.fgh(G, H, θ, p) else fgh! = nothing end @@ -415,10 +388,7 @@ function OptimizationBase.instantiate_function( end end elseif hv == true - hv! = (H, θ, v) -> f.hv(H, θ, v, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hv! = (H, θ, v, p) -> f.hv(H, θ, v, p) - end + hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p) else hv! = nothing end @@ -564,7 +534,7 @@ function OptimizationBase.instantiate_function( end end elseif cons !== nothing && cons_h == true - lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p) else lag_h! = nothing end diff --git a/src/OptimizationDIExt.jl b/src/OptimizationDIExt.jl index 82826aa..494f6b6 100644 --- a/src/OptimizationDIExt.jl +++ b/src/OptimizationDIExt.jl @@ -50,10 +50,7 @@ function instantiate_function( end end elseif g == true - grad = (G, θ) -> f.grad(G, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - grad = (G, θ, p) -> f.grad(G, θ, p) - end + grad = (G, θ, p = p) -> f.grad(G, θ, p) else grad = nothing end @@ -74,10 +71,7 @@ function instantiate_function( end end elseif fg == true - fg! = (G, θ) -> f.fg(G, θ, p) - if p !== SciMLBase.NullParameters() - fg! = (G, θ, p) -> f.fg(G, θ, p) - end + fg! = (G, θ, p = p) -> f.fg(G, θ, p) else fg! = nothing end @@ -96,10 +90,7 @@ function instantiate_function( end end elseif h == true - hess = (H, θ) -> f.hess(H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hess = (H, θ, p) -> f.hess(H, θ, p) - end + hess = (H, θ, p = p) -> f.hess(H, θ, p) else hess = nothing end @@ -119,10 +110,7 @@ function instantiate_function( end end elseif fgh == true - fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p) - end + fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p) else fgh! = nothing end @@ -139,10 +127,7 @@ function instantiate_function( end end elseif hv == true - hv! = (H, θ, v) -> f.hv(H, θ, v, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hv! = (H, θ, v, p) -> f.hv(H, θ, v, p) - end + hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p) else hv! = nothing end @@ -277,7 +262,7 @@ function instantiate_function( end end elseif lag_h == true && cons !== nothing - lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p) + lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p) else lag_h! = nothing end @@ -334,10 +319,7 @@ function instantiate_function( end end elseif g == true - grad = (θ) -> f.grad(θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - grad = (θ, p) -> f.grad(θ, p) - end + grad = (θ, p = p) -> f.grad(θ, p) else grad = nothing end @@ -358,10 +340,7 @@ function instantiate_function( end end elseif fg == true - fg! = (θ) -> f.fg(θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fg! = (θ, p) -> f.fg(θ, p) - end + fg! = (θ, p = p) -> f.fg(θ, p) else fg! = nothing end @@ -380,10 +359,7 @@ function instantiate_function( end end elseif h == true - hess = (θ) -> f.hess(θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hess = (θ, p) -> f.hess(θ, p) - end + hess = (θ, p = p) -> f.hess(θ, p) else hess = nothing end @@ -401,10 +377,7 @@ function instantiate_function( end end elseif fgh == true - fgh! = (θ) -> f.fgh(θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fgh! = (θ, p) -> f.fgh(θ, p) - end + fgh! = (θ, p = p) -> f.fgh(θ, p) else fgh! = nothing end @@ -421,10 +394,7 @@ function instantiate_function( end end elseif hv == true - hv! = (θ, v) -> f.hv(θ, v, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hv! = (θ, v, p) -> f.hv(θ, v, p) - end + hv! = (θ, v, p = p) -> f.hv(θ, v, p) else hv! = nothing end @@ -530,7 +500,7 @@ function instantiate_function( end end elseif lag_h == true && cons !== nothing - lag_h! = (θ, σ, λ) -> f.lag_h(θ, σ, λ, p) + lag_h! = (θ, σ, λ, p = p) -> f.lag_h(θ, σ, λ, p) else lag_h! = nothing end diff --git a/src/OptimizationDISparseExt.jl b/src/OptimizationDISparseExt.jl index b0ec48b..d9b17e7 100644 --- a/src/OptimizationDISparseExt.jl +++ b/src/OptimizationDISparseExt.jl @@ -127,10 +127,7 @@ function instantiate_function( end end elseif g == true - grad = (G, θ) -> f.grad(G, θ, p) - if p !== SciMLBase.NullParameters() - grad = (G, θ, p) -> f.grad(G, θ, p) - end + grad = (G, θ, p = p) -> f.grad(G, θ, p) else grad = nothing end @@ -151,10 +148,7 @@ function instantiate_function( end end elseif fg == true - fg! = (G, θ) -> f.fg(G, θ, p) - if p !== SciMLBase.NullParameters() - fg! = (G, θ, p) -> f.fg(G, θ, p) - end + fg! = (G, θ, p = p) -> f.fg(G, θ, p) else fg! = nothing end @@ -176,10 +170,7 @@ function instantiate_function( end end elseif h == true - hess = (H, θ) -> f.hess(H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hess = (H, θ, p) -> f.hess(H, θ, p) - end + hess = (H, θ, p = p) -> f.hess(H, θ, p) else hess = nothing end @@ -198,10 +189,7 @@ function instantiate_function( end end elseif fgh == true - fgh! = (G, H, θ) -> f.fgh(G, H, θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p) - end + fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p) else fgh! = nothing end @@ -218,10 +206,7 @@ function instantiate_function( end end elseif hv == true - hv! = (H, θ, v) -> f.hv(H, θ, v, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hv! = (H, θ, v, p) -> f.hv(H, θ, v, p) - end + hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p) else hv! = nothing end @@ -369,10 +354,7 @@ function instantiate_function( end end elseif lag_h == true - lag_h! = (H, θ, σ, λ) -> f.lag_h(H, θ, σ, λ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - lag_h! = (H, θ, σ, λ, p) -> f.lag_h(H, θ, σ, λ, p) - end + lag_h! = (H, θ, σ, λ, p = p) -> f.lag_h(H, θ, σ, λ, p) else lag_h! = nothing end @@ -428,7 +410,7 @@ function instantiate_function( end end elseif g == true - grad = (θ) -> f.grad(θ, p) + grad = (θ, p = p) -> f.grad(θ, p) else grad = nothing end @@ -449,7 +431,7 @@ function instantiate_function( end end elseif fg == true - fg! = (θ) -> f.fg(θ, p) + fg! = (θ, p = p) -> f.fg(θ, p) else fg! = nothing end @@ -469,10 +451,7 @@ function instantiate_function( end end elseif fgh == true - fgh! = (θ) -> f.fgh(θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - fgh! = (θ, p) -> f.fgh(θ, p) - end + fgh! = (θ, p = p) -> f.fgh(θ, p) else fgh! = nothing end @@ -494,10 +473,7 @@ function instantiate_function( end end elseif h == true - hess = (θ) -> f.hess(θ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hess = (θ, p) -> f.hess(θ, p) - end + hess = (θ, p = p) -> f.hess(θ, p) else hess = nothing end @@ -515,10 +491,7 @@ function instantiate_function( end end elseif hv == true - hv! = (θ, v) -> f.hv(θ, v, p) - if p !== SciMLBase.NullParameters() && p !== nothing - hv! = (θ, v, p) -> f.hv(θ, v, p) - end + hv! = (θ, v, p = p) -> f.hv(θ, v, p) else hv! = nothing end @@ -632,10 +605,7 @@ function instantiate_function( end end elseif lag_h == true && cons !== nothing - lag_h! = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p) - if p !== SciMLBase.NullParameters() && p !== nothing - lag_h! = (θ, σ, μ, p) -> f.lag_h(θ, σ, μ, p) - end + lag_h! = (θ, σ, μ, p = p) -> f.lag_h(θ, σ, μ, p) else lag_h! = nothing end