Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ function set_objective(model::Model, obj)
end

function add_constraint(
model::Model,
model::Model{T},
func,
set::Union{
MOI.GreaterThan{Float64},
MOI.LessThan{Float64},
MOI.Interval{Float64},
MOI.EqualTo{Float64},
MOI.GreaterThan{T},
MOI.LessThan{T},
MOI.Interval{T},
MOI.EqualTo{T},
},
)
) where {T}
f = parse_expression(model, func)
model.last_constraint_index += 1
index = ConstraintIndex(model.last_constraint_index)
model.constraints[index] = Constraint(f, set)
return index
end

function add_parameter(model::Model, value::Float64)
push!(model.parameters, value)
function add_parameter(model::Model{T}, value::Real) where {T}
push!(model.parameters, convert(T, value))
return ParameterIndex(length(model.parameters))
end

Expand Down
13 changes: 9 additions & 4 deletions src/parse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,19 @@ function parse_expression(
return
end

function parse_expression(data::Model, input)
expr = Expression()
function parse_expression(data::Model{T}, input) where {T}
expr = Expression{T}()
parse_expression(data, expr, input, -1)
return expr
end

function parse_expression(::Model, expr::Expression, x::Real, parent_index::Int)
push!(expr.values, convert(Float64, x)::Float64)
function parse_expression(
::Model,
expr::Expression{T},
x::Real,
parent_index::Int,
) where {T}
push!(expr.values, convert(T, x)::T)
push!(expr.nodes, Node(NODE_VALUE, length(expr.values), parent_index))
return
end
Expand Down
6 changes: 3 additions & 3 deletions src/parse_moi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function _parse_moi_stack!(
::Vector{Tuple{Int,Any}},
data::Model,
expr::Expression,
x::Union{Float64,MOI.VariableIndex},
x::Union{Real,MOI.VariableIndex},
parent_index::Int,
)
return parse_expression(data, expr, x, parent_index)
Expand Down Expand Up @@ -188,7 +188,7 @@ function _parse_moi_stack!(
stack::Vector{Tuple{Int,Any}},
data::Model,
expr::Expression,
x::Matrix{Float64},
x::AbstractMatrix{<:Real},
parent_index::Int,
)
m, n = size(x)
Expand All @@ -210,7 +210,7 @@ function _parse_moi_stack!(
stack::Vector{Tuple{Int,Any}},
data::Model,
expr::Expression,
x::Vector{Float64},
x::AbstractVector{<:Real},
parent_index::Int,
)
vect_id = data.operators.multivariate_operator_to_id[:vect]
Expand Down
57 changes: 29 additions & 28 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ The core type that represents a nonlinear expression. See the MathOptInterface
documentation for information on how the nodes and values form an expression
tree.
"""
struct Expression
struct Expression{T}
nodes::Vector{Node}
values::Vector{Float64}
Expression() = new(Node[], Float64[])
values::Vector{T}
Expression{T}() where {T} = new{T}(Node[], T[])
end

function Base.:(==)(x::Expression, y::Expression)
Expand All @@ -38,13 +38,13 @@ end
A type to hold information relating to the nonlinear constraint `f(x) in S`,
where `f(x)` is defined by `.expression`, and `S` is `.set`.
"""
struct Constraint
expression::Expression
struct Constraint{T}
expression::Expression{T}
set::Union{
MOI.LessThan{Float64},
MOI.GreaterThan{Float64},
MOI.EqualTo{Float64},
MOI.Interval{Float64},
MOI.LessThan{T},
MOI.GreaterThan{T},
MOI.EqualTo{T},
MOI.Interval{T},
}
end

Expand Down Expand Up @@ -101,7 +101,7 @@ function _subexpression_and_linearity(
return _SubexpressionStorage(
nodes,
adj,
expr.values,
convert(Vector{Float64}, expr.values),
partials_storage_ϵ,
linearity[1],
),
Expand Down Expand Up @@ -192,38 +192,39 @@ It has the following fields:
* `parameters::Vector{Float64}` : holds the current values of the parameters.
* `operators::OperatorRegistry` : stores the operators used in the model.
"""
mutable struct Model
objective::Union{Nothing,Expression}
expressions::Vector{Expression}
constraints::OrderedCollections.OrderedDict{ConstraintIndex,Constraint}
parameters::Vector{Float64}
mutable struct Model{T}
objective::Union{Nothing,Expression{T}}
expressions::Vector{Expression{T}}
constraints::OrderedCollections.OrderedDict{ConstraintIndex,Constraint{T}}
parameters::Vector{T}
operators::OperatorRegistry
# This is a private field, used only to increment the ConstraintIndex.
last_constraint_index::Int64
function Model()
model = new(
function Model{T}() where {T}
return new{T}(
nothing,
Expression[],
OrderedCollections.OrderedDict{ConstraintIndex,Constraint}(),
Float64[],
Expression{T}[],
OrderedCollections.OrderedDict{ConstraintIndex,Constraint{T}}(),
T[],
OperatorRegistry(),
0,
)
return model
end
end

mutable struct Evaluator{B} <: MOI.AbstractNLPEvaluator
Model() = Model{Float64}()

mutable struct Evaluator{T,B} <: MOI.AbstractNLPEvaluator
# The internal datastructure.
model::Model
model::Model{T}
# The abstract-differentiation backend
backend::B
# ordered_constraints is needed because `OrderedDict` doesn't support
# looking up a key by the linear index.
ordered_constraints::Vector{ConstraintIndex}
# Storage for the NLPBlockDual, so that we can query the dual of individual
# constraints without needing to query the full vector each time.
constraint_dual::Vector{Float64}
constraint_dual::Vector{T}
# Timers
initialize_timer::Float64
eval_objective_timer::Float64
Expand All @@ -236,14 +237,14 @@ mutable struct Evaluator{B} <: MOI.AbstractNLPEvaluator
eval_hessian_lagrangian_timer::Float64

function Evaluator(
model::Model,
model::Model{T},
backend::B = nothing,
) where {B<:Union{Nothing,MOI.AbstractNLPEvaluator}}
return new{B}(
) where {T,B<:Union{Nothing,MOI.AbstractNLPEvaluator}}
return new{T,B}(
model,
backend,
MOI.ConstraintIndex[],
Float64[],
T[],
0.0,
0.0,
0.0,
Expand Down
85 changes: 85 additions & 0 deletions test/ArrayDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,91 @@ function test_objective_broadcasted_pow_cubed()
return
end

function test_model_typed_default_is_float64()
model = ArrayDiff.Model()
@test model isa ArrayDiff.Model{Float64}
@test model.parameters isa Vector{Float64}
@test model.expressions isa Vector{ArrayDiff.Expression{Float64}}
@test model.constraints isa ArrayDiff.OrderedCollections.OrderedDict{
ArrayDiff.ConstraintIndex,
ArrayDiff.Constraint{Float64},
}
return
end

function test_model_typed_float32_parse_value()
model = ArrayDiff.Model{Float32}()
x = MOI.VariableIndex(1)
ArrayDiff.set_objective(model, :($x + 1.5))
obj = something(model.objective)
@test obj isa ArrayDiff.Expression{Float32}
@test obj.values isa Vector{Float32}
@test obj.values == Float32[1.5]
return
end

function test_model_typed_float32_add_parameter()
model = ArrayDiff.Model{Float32}()
p = ArrayDiff.add_parameter(model, 2.5)
@test p isa ArrayDiff.ParameterIndex
@test model.parameters isa Vector{Float32}
@test model.parameters == Float32[2.5]
return
end

function test_model_typed_float32_add_constraint()
model = ArrayDiff.Model{Float32}()
x = MOI.VariableIndex(1)
set = MOI.LessThan{Float32}(3.0f0)
idx = ArrayDiff.add_constraint(model, :($x + 1.0), set)
@test idx isa ArrayDiff.ConstraintIndex
c = model.constraints[idx]
@test c isa ArrayDiff.Constraint{Float32}
@test c.expression isa ArrayDiff.Expression{Float32}
@test c.expression.values == Float32[1.0]
@test c.set === set
return
end

function test_model_typed_float32_add_expression()
model = ArrayDiff.Model{Float32}()
x = MOI.VariableIndex(1)
idx = ArrayDiff.add_expression(model, :($x * 2.0))
@test idx isa ArrayDiff.ExpressionIndex
e = model[idx]
@test e isa ArrayDiff.Expression{Float32}
@test e.values == Float32[2.0]
return
end

function test_model_typed_bigfloat_constraint_set()
model = ArrayDiff.Model{BigFloat}()
x = MOI.VariableIndex(1)
set = MOI.GreaterThan{BigFloat}(big"1.0")
idx = ArrayDiff.add_constraint(model, :($x), set)
c = model.constraints[idx]
@test c isa ArrayDiff.Constraint{BigFloat}
@test c.set === set
return
end

function test_model_typed_float32_evaluator_runs()
# End-to-end smoke test: parsing happens in T = Float32, AD evaluation
# converts to Float64 internally.
model = ArrayDiff.Model{Float32}()
x = MOI.VariableIndex(1)
ArrayDiff.set_objective(model, :(2 * dot([$x], [$x]) + 1.0))
evaluator = ArrayDiff.Evaluator(model, ArrayDiff.Mode(), [x])
@test evaluator isa ArrayDiff.Evaluator{Float32}
MOI.initialize(evaluator, [:Grad])
xv = [1.5]
@test MOI.eval_objective(evaluator, xv) ≈ 2 * xv[1]^2 + 1.0
g = ones(1)
MOI.eval_objective_gradient(evaluator, g, xv)
@test g[1] ≈ 4 * xv[1]
return
end

end # module

TestArrayDiff.runtests()
Loading