Skip to content

Commit 87cbf74

Browse files
orebas and claude committed
Add S2/S3 composite interpolators and Matérn-5/2 kernel variant
New interpolator pipeline stages:
- S2 (AAA→MLE): barycentric AAA on raw data, refined via MLE optimization
- S3 (GP→AAA→MLE): GP smoothing, then AAA extraction, then MLE refinement
- S3 variants: SE, RQ, SE+RQ, SE×RQ, Matérn-5/2 kernels

Also adds InterpolatorAGPRobustMatern52, helper exports (is_gp_interpolator, s3_symbol, s3_refine_gp), and gp_s3_refinement option flag.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 600fe9e commit 87cbf74

4 files changed

Lines changed: 372 additions & 20 deletions

File tree

src/ODEParameterEstimation.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,11 @@ export substr_test, global_unident_test, sum_test, trivial_unident
142142
export EstimationOptions, SystemSolverMethod, InterpolatorMethod, PolishMethod, EstimationFlow
143143
export FlowDeprecated, FlowStandard, FlowDirectOpt
144144
export SolverRS, SolverHC, SolverNLOpt, SolverFastNLOpt, SolverRobust
145-
export InterpolatorAAAD, InterpolatorAAADGPR, InterpolatorAAADOld, InterpolatorFHD, InterpolatorAGP, InterpolatorAGPRobust, InterpolatorAGPRobustRQ, InterpolatorAGPRobustSEpRQ, InterpolatorAGPRobustSExRQ, InterpolatorCustom
145+
export InterpolatorAAAD, InterpolatorAAADGPR, InterpolatorAAADOld, InterpolatorFHD, InterpolatorAGP, InterpolatorAGPRobust, InterpolatorAGPRobustRQ, InterpolatorAGPRobustSEpRQ, InterpolatorAGPRobustSExRQ, InterpolatorAGPRobustMatern52, InterpolatorS2AAAMLE, InterpolatorS3SE, InterpolatorS3RQ, InterpolatorS3SEpRQ, InterpolatorS3SExRQ, InterpolatorS3Matern52, InterpolatorCustom
146146
export PolishNewtonTrust, PolishLevenberg, PolishGaussNewton, PolishBFGS, PolishLBFGS
147147
export get_solver_function, get_interpolator_function, get_polish_optimizer, get_ad_backend
148148
export interpolator_method_to_symbol, resolve_interpolator_list, setup_identifiability, compute_shooting_indices
149+
export is_gp_interpolator, is_matern_interpolator, s3_symbol, s3_refine_gp
149150
export merge_options, validate_options, print_options, get_solver_options_dict
150151
export optimized_multishot_parameter_estimation
151152

src/core/derivatives.jl

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,3 +1198,166 @@ function agp_gpr_robust(xs::AbstractArray{T}, ys::AbstractArray{T};
11981198

11991199
return AGPInterpolator(mean_pred, std_pred, f_posterior, 0.0, 1.0, y_mean, y_std)
12001200
end
1201+
1202+
# ============================================================================
1203+
# S3 Refinement: GP → AAA → MLE composite interpolator
1204+
# ============================================================================
1205+
"""
    _BaryResult(x, w, f)

Lightweight container exposing the field names `x`, `w`, `f` of a
`BaryRational` AAA approximation, so that MLE-refined barycentric parameters
can be wrapped into an `AAADapprox`.

# Fields
- `x`: support points
- `w`: barycentric weights
- `f`: function values at the support points

# Throws
- `DimensionMismatch` if `x`, `w` and `f` do not all have the same length.
"""
struct _BaryResult
    x::Vector{Float64}  # support points
    w::Vector{Float64}  # barycentric weights
    f::Vector{Float64}  # function values at support points

    # A barycentric form needs exactly one weight and one value per support
    # point; reject inconsistent input here instead of failing during evaluation.
    # `new` converts each argument to `Vector{Float64}`, as the default
    # constructor would.
    function _BaryResult(x, w, f)
        if !(length(x) == length(w) == length(f))
            throw(DimensionMismatch(
                "support points, weights and values must have equal lengths, got " *
                "$(length(x)), $(length(w)) and $(length(f))"))
        end
        return new(x, w, f)
    end
end
1217+
"""
    _mle_refine_bary(z, w_init, f_init, t_data, y_data; maxiter, g_tol) -> (z, w_opt, f_opt)

Refine the weights and function values of a barycentric rational function
against raw (noisy) data by minimizing the sum of squared residuals with LBFGS.

The support points `z` stay fixed. A hard gauge removes the scale ambiguity of
the weights: one weight is pinned to 1 and excluded from the optimization.
Normally that is `w[1]`; if dividing by `w_init[1]` produces non-finite values
(zero or subnormal first weight), the largest-magnitude weight is pinned
instead.

# Arguments
- `z`: support points (returned unchanged, same object).
- `w_init`, `f_init`: initial weights and values, one per support point.
- `t_data`, `y_data`: sample locations and observed values.

# Keywords
- `maxiter`: iteration limit passed to each of the two LBFGS runs.
- `g_tol`: gradient tolerance passed to the optimizer.

# Returns
`(z, w_opt, f_opt)` in the original support-point order. The pinned weight is
exactly 1 in `w_opt`. The inputs are never mutated.

Fallbacks (nothing is optimized, or the optimizer output is discarded):
- fewer than two support points, or no finite gauge exists: copies of
  `w_init`/`f_init` are returned;
- the initial loss is not finite, or the optimizer ends at a non-finite or
  worse point: the gauge-normalized initial weights and `f_init` are returned.

# Throws
- `DimensionMismatch` if `w_init`/`f_init` do not match `z` in length, or if
  `t_data` and `y_data` differ in length.
"""
function _mle_refine_bary(z::Vector{Float64}, w_init::Vector{Float64}, f_init::Vector{Float64},
                          t_data::Vector{Float64}, y_data::Vector{Float64};
                          maxiter::Int = 20000, g_tol::Float64 = 1e-15)
    m = length(z)
    if length(w_init) != m || length(f_init) != m
        throw(DimensionMismatch(
            "expected $m weights and values, got $(length(w_init)) and $(length(f_init))"))
    end
    if length(t_data) != length(y_data)
        throw(DimensionMismatch(
            "t_data has length $(length(t_data)) but y_data has length $(length(y_data))"))
    end
    m < 2 && return z, copy(w_init), copy(f_init)

    # Hard gauge: pin one weight to 1. Keep w₁ whenever that is well defined and
    # only switch to the largest-magnitude weight when dividing by w₁ blows up.
    pivot = 1
    w_scaled = w_init ./ w_init[pivot]
    if !all(isfinite, w_scaled)
        pivot = last(findmax(abs.(w_init)))
        w_scaled = w_init ./ w_init[pivot]
        all(isfinite, w_scaled) || return z, copy(w_init), copy(f_init)
    end

    # Move the pinned support point to slot 1. Swapping two entries is its own
    # inverse, so indexing with `order` again restores the caller's ordering.
    order = collect(1:m)
    order[1], order[pivot] = pivot, 1
    zs = z[order]

    # Parameter layout: θ = [w₂ … w_m, f₁ … f_m] (in the permuted ordering).
    θ0 = vcat(w_scaled[order][2:end], f_init[order])

    # Objective and gradient in the `Optim.only_fg!` calling convention:
    # the gradient is written into `G` when requested, the loss is returned.
    fg! = function (F, G, θ)
        T = eltype(θ)
        w = vcat(one(T), view(θ, 1:m-1))
        fv = view(θ, m:2m-1)
        loss = zero(T)
        G !== nothing && fill!(G, zero(eltype(G)))
        for j in eachindex(t_data, y_data)
            tj = t_data[j]
            hit = findfirst(==(tj), zs)
            if hit !== nothing
                # At a support point the barycentric form equals the stored value.
                res = y_data[j] - fv[hit]
                loss += res^2
                G !== nothing && (G[m-1+hit] -= 2 * res)
            else
                num = zero(T)
                den = zero(T)
                for i in 1:m
                    a = w[i] / (tj - zs[i])
                    num += a * fv[i]
                    den += a
                end
                rv = num / den
                res = y_data[j] - rv
                loss += res^2
                if G !== nothing
                    inv_den = inv(den)
                    for i in 1:m
                        ai = inv(tj - zs[i])
                        # ∂r/∂fᵢ = wᵢ aᵢ / den
                        G[m-1+i] -= 2 * res * w[i] * ai * inv_den
                        # ∂r/∂wᵢ = aᵢ (fᵢ - r) / den; w₁ is pinned and has no slot
                        i >= 2 && (G[i-1] -= 2 * res * ai * (fv[i] - rv) * inv_den)
                    end
                end
            end
        end
        return loss
    end

    loss0 = fg!(nothing, nothing, θ0)
    if !isfinite(loss0)
        @debug "_mle_refine_bary: non-finite initial loss, skipping refinement" loss0
        return z, w_scaled, copy(f_init)
    end

    od = Optim.OnceDifferentiable(Optim.only_fg!(fg!), θ0)
    opts = Optim.Options(iterations = maxiter, g_tol = g_tol,
                         f_reltol = 1e-16, x_reltol = 1e-16, show_trace = false)
    first_run = Optim.optimize(od, θ0, Optim.LBFGS(), opts)
    θ_best = Optim.minimizer(first_run)
    loss_best = Optim.minimum(first_run)

    # Warm restart, only from a usable point.
    if all(isfinite, θ_best)
        second_run = Optim.optimize(od, θ_best, Optim.LBFGS(), opts)
        if Optim.minimum(second_run) < loss_best
            θ_best = Optim.minimizer(second_run)
            loss_best = Optim.minimum(second_run)
        end
    end

    # Never hand back something worse (or less finite) than the starting point.
    usable = isfinite(loss_best) && all(isfinite, θ_best) && loss_best <= loss0
    if !usable
        @debug "_mle_refine_bary: optimizer result rejected, keeping initial values" loss0 loss_best
    end
    θ_final = usable ? θ_best : θ0

    w_perm = vcat(1.0, θ_final[1:m-1])
    f_perm = θ_final[m:2m-1]
    return z, w_perm[order], f_perm[order]
end
1286+
"""
    s3_refine_gp(gp_interpolant::AbstractInterpolator, t_vector, y_raw;
                 aaa_tol, mmax, maxiter, g_tol) -> AAADapprox

S3 strategy (GP → AAA → MLE):

1. evaluate `gp_interpolant` at every entry of `t_vector` to obtain denoised
   values;
2. run `BaryRational.aaa` on those denoised values with tolerance `aaa_tol`
   and at most `mmax` support points;
3. refine the resulting barycentric weights and values against the RAW data
   `y_raw` via `_mle_refine_bary`;
4. wrap the refined parameters in an `AAADapprox`.

# Arguments
- `gp_interpolant`: callable interpolant, evaluated as `gp_interpolant(t)`.
- `t_vector`: sample locations; any real-valued vector (converted to `Float64`).
- `y_raw`: raw observations at `t_vector`; any real-valued vector.

# Keywords
- `aaa_tol`, `mmax`: forwarded to `BaryRational.aaa` as `tol` and `mmax`.
- `maxiter`, `g_tol`: forwarded to `_mle_refine_bary`.

# Throws
- `DimensionMismatch` if `t_vector` and `y_raw` differ in length.
"""
function s3_refine_gp(gp_interpolant::AbstractInterpolator,
                      t_vector::AbstractVector{<:Real}, y_raw::AbstractVector{<:Real};
                      aaa_tol::Float64 = 1e-14, mmax::Int = 200,
                      maxiter::Int = 20000, g_tol::Float64 = 1e-15)
    if length(t_vector) != length(y_raw)
        throw(DimensionMismatch(
            "t_vector has length $(length(t_vector)) but y_raw has length $(length(y_raw))"))
    end

    # `convert` is a no-op for inputs that already are Vector{Float64}.
    t_vec = convert(Vector{Float64}, t_vector)
    y_vec = convert(Vector{Float64}, y_raw)

    # Step 1: evaluate the GP at the data points → denoised values
    y_gp = [gp_interpolant(t) for t in t_vec]

    # Step 2: AAA on the denoised values
    aaa_result = BaryRational.aaa(t_vec, y_gp; tol = aaa_tol, mmax = mmax)

    # Step 3: refine weights and values against the raw noisy data.
    # Copies keep the AAA result itself untouched.
    z, w, f = _mle_refine_bary(copy(aaa_result.x), copy(aaa_result.w), copy(aaa_result.f),
                               t_vec, y_vec; maxiter = maxiter, g_tol = g_tol)

    # Step 4: wrap in AAADapprox
    return AAADapprox(_BaryResult(z, w, f))
end
1317+
"""
    s2_aaa_mle_interpolator(xs, ys; aaa_tol, mmax, maxiter, g_tol) -> AAADapprox

S2 strategy (AAA → MLE), the GP-free counterpart of S3: `BaryRational.aaa` is
run directly on the raw data to obtain support points, weights and values,
which are then refined against the same data by `_mle_refine_bary` and wrapped
in an `AAADapprox`.

# Keywords
- `aaa_tol`, `mmax`: forwarded to `BaryRational.aaa` as `tol` and `mmax`.
- `maxiter`, `g_tol`: forwarded to `_mle_refine_bary`.

Author's note carried over from the original docstring: at high noise,
skipping the GP avoids over-smoothing; at low noise, S3 typically wins.
"""
function s2_aaa_mle_interpolator(xs::AbstractArray, ys::AbstractArray;
                                 aaa_tol::Float64 = 1e-10, mmax::Int = 200,
                                 maxiter::Int = 20000, g_tol::Float64 = 1e-15)
    times = Vector{Float64}(xs)
    samples = Vector{Float64}(ys)

    # Seed the barycentric form straight from the noisy samples.
    seed = BaryRational.aaa(times, samples; tol = aaa_tol, mmax = mmax)

    # Fit weights and values to the very same samples; work on copies of the seed.
    support, weights, fvals = _mle_refine_bary(
        copy(seed.x), copy(seed.w), copy(seed.f), times, samples;
        maxiter = maxiter, g_tol = g_tol)

    return AAADapprox(_BaryResult(support, weights, fvals))
end
1345+
"""
    _s3_interpolator(xs, ys; kernel_type=:se) -> AAADapprox

Standalone factory for the S3 composite interpolator: converts the inputs to
`Vector{Float64}`, fits a GP through `agp_gpr_robust` with the requested
`kernel_type`, and passes that GP together with the data to `s3_refine_gp`.
"""
function _s3_interpolator(xs::AbstractArray, ys::AbstractArray; kernel_type::Symbol = :se)
    times = Vector{Float64}(xs)
    samples = Vector{Float64}(ys)
    smoother = agp_gpr_robust(times, samples; kernel_type = kernel_type)
    return s3_refine_gp(smoother, times, samples)
end
1358+
1359+
# Kernel-specific S3 factories. Each one forwards to `_s3_interpolator` with a
# fixed `kernel_type` symbol.

# S3 with kernel_type = :se
function s3_se_interpolator(xs, ys)
    return _s3_interpolator(xs, ys; kernel_type = :se)
end

# S3 with kernel_type = :rq
function s3_rq_interpolator(xs, ys)
    return _s3_interpolator(xs, ys; kernel_type = :rq)
end

# S3 with kernel_type = :se_plus_rq
function s3_se_plus_rq_interpolator(xs, ys)
    return _s3_interpolator(xs, ys; kernel_type = :se_plus_rq)
end

# S3 with kernel_type = :se_times_rq
function s3_se_times_rq_interpolator(xs, ys)
    return _s3_interpolator(xs, ys; kernel_type = :se_times_rq)
end

# S3 with kernel_type = :matern52
function s3_matern52_interpolator(xs, ys)
    return _s3_interpolator(xs, ys; kernel_type = :matern52)
end

src/core/optimized_multishot_estimation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,14 +1541,14 @@ function optimized_multishot_parameter_estimation(PEP::ParameterEstimationProble
15411541
end
15421542
end # end if use_param_homotopy
15431543

1544-
# Accumulate this interpolator's solutions into the global pool
1544+
# Accumulate this pass's solutions into the global pool
15451545
append!(all_solutions, interp_solutions)
15461546
append!(solution_time_indices, interp_time_indices)
15471547
for _ in 1:length(interp_solutions)
15481548
push!(solution_interpolator_sources, interp_sym)
15491549
end
15501550

1551-
if !opts.nooutput && n_interpolators > 1
1551+
if !opts.nooutput
15521552
println(" -> $interp_sym produced $(length(interp_solutions)) solutions")
15531553
end
15541554

0 commit comments

Comments
 (0)