Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial implementation of SCCNonlinearProblem codegen #3213

Merged
merged 22 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
17359ce
feat: initial implementation of `SCCNonlinearProblem` codegen
AayushSabharwal Nov 15, 2024
7d3b3f4
refactor: no need to re-sort SCCs
AayushSabharwal Nov 18, 2024
3e4a648
fix: minor bug fix
AayushSabharwal Nov 18, 2024
b99ab7d
fix: fix observed equations not being generated
AayushSabharwal Nov 18, 2024
04c0cf8
test: add tests for `SCCNonlinearProblem` codegen
AayushSabharwal Nov 18, 2024
5639bd1
feat: pre-compute observed equations of previous SCCs
AayushSabharwal Nov 19, 2024
26269d9
feat: subset system and pass to SCC problems
AayushSabharwal Nov 19, 2024
beb7070
refactor: improve `observed_dependency_graph` and add docstrings
AayushSabharwal Nov 20, 2024
27d7f70
feat: cache subexpressions dependent only on previous SCCs
AayushSabharwal Nov 20, 2024
691747a
feat: use SCCNonlinearProblem for initialization
AayushSabharwal Nov 26, 2024
b5fc3ed
refactor: better handle inputs in `wrap_array_vars`
AayushSabharwal Nov 26, 2024
e79dc90
fix: properly sort SCCs
AayushSabharwal Nov 26, 2024
c1e1523
feat: better handle observed variables, constants in SCCNonlinearProblem
AayushSabharwal Nov 26, 2024
205d76a
test: fix tests
AayushSabharwal Nov 26, 2024
ed9cdf3
fix: reorder system in `SCCNonlinearProblem`
AayushSabharwal Nov 26, 2024
a381930
build: bump OrdinaryDiffEqCore, OrdinaryDiffEqNonlinearSolve compats
AayushSabharwal Nov 30, 2024
52eba50
refactor: update to new `SCCNonlinearProblem` constructor
AayushSabharwal Dec 4, 2024
24d39af
test: fix test to new SciMLBase error message
AayushSabharwal Dec 4, 2024
1320fc0
refactor: separate out operating point and initializeprob construction
AayushSabharwal Dec 5, 2024
35b407c
fix: fix `remake_initialization_data` on problems with no initprob
AayushSabharwal Dec 4, 2024
9943532
refactor: add better warnings when SCC initialization cannot be used
AayushSabharwal Dec 5, 2024
0537715
refactor: propagate `use_scc` to `remake_initialization_data`
AayushSabharwal Dec 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Expand Down Expand Up @@ -120,13 +121,15 @@ NonlinearSolve = "3.14, 4"
OffsetArrays = "1"
OrderedCollections = "1"
OrdinaryDiffEq = "6.82.0"
OrdinaryDiffEqCore = "1.7.0"
OrdinaryDiffEqCore = "1.13.0"
OrdinaryDiffEqNonlinearSolve = "1.3.0"
PrecompileTools = "1"
REPL = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.64"
SCCNonlinearSolve = "1.0.0"
SciMLBase = "2.66"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down Expand Up @@ -160,6 +163,7 @@ OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -174,4 +178,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ using Distributed
import JuliaFormatter
using MLStyle
using NonlinearSolve
import SCCNonlinearSolve
using Reexport
using RecursiveArrayTools
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
Expand Down
14 changes: 10 additions & 4 deletions src/structural_transformation/bipartite_tearing/modia_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars
return nothing
end

function build_var_eq_matching(structure::SystemStructure, ::Type{U} = Unassigned;
varfilter::F2 = v -> true, eqfilter::F3 = eq -> true) where {U, F2, F3}
@unpack graph, solvable_graph = structure
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
matching_len = max(length(var_eq_matching),
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
return complete(var_eq_matching, matching_len), matching_len
end

function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
::Type{U} = Unassigned;
varfilter::F2 = v -> true,
Expand All @@ -78,10 +87,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
# find them here [TODO: It would be good to have an explicit example of this.]

@unpack graph, solvable_graph = structure
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
matching_len = max(length(var_eq_matching),
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
var_eq_matching = complete(var_eq_matching, matching_len)
var_eq_matching, matching_len = build_var_eq_matching(structure, U; varfilter, eqfilter)
full_var_eq_matching = copy(var_eq_matching)
var_sccs = find_var_sccs(graph, var_eq_matching)
vargraph = DiCMOBiGraph{true}(graph, 0, Matching(matching_len))
Expand Down
191 changes: 54 additions & 137 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,12 @@ object.
"""
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__,
cachesyms::Tuple = (), kwargs...)
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
end
p = reorder_parameters(sys, unwrap.(ps))
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
isscalar = !(exprs isa AbstractArray)
if wrap_code === nothing
wrap_code = isscalar ? identity : (identity, identity)
Expand All @@ -187,7 +188,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs) .∘
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
wrap_parameter_dependencies(sys, isscalar),
expression = Val{true}
)
Expand All @@ -199,7 +200,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
postprocess_fbody,
states,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
wrap_array_vars(sys, exprs; dvs) .∘
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
wrap_parameter_dependencies(sys, isscalar),
expression = Val{true}
)
Expand Down Expand Up @@ -231,133 +232,59 @@ end

function wrap_array_vars(
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
inputs = nothing, history = false)
inputs = nothing, history = false, cachesyms::Tuple = ())
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
if dvs !== nothing
for (j, x) in enumerate(dvs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], array_vars, arg)
push!(inds, j)
end
end
for (k, inds) in array_vars
if inds == (inds′ = inds[1]:inds[end])
array_vars[k] = inds′
end
end
var_to_arridxs = Dict()

uind = 1
else
if dvs === nothing
uind = 0
end
# values are (indexes, index of buffer, size of parameter)
array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
# If for some reason different elements of an array parameter are in different buffers
other_array_parameters = Dict{Any, Any}()

hasinputs = inputs !== nothing
input_vars = Dict{Any, AbstractArray{Int}}()
if hasinputs
for (j, x) in enumerate(inputs)
if iscall(x) && operation(x) == getindex
arg = arguments(x)[1]
inds = get!(() -> Int[], input_vars, arg)
push!(inds, j)
end
end
for (k, inds) in input_vars
if inds == (inds′ = inds[1]:inds[end])
input_vars[k] = inds′
end
end
end
if has_index_cache(sys)
ic = get_index_cache(sys)
else
ic = nothing
end
if ps isa Tuple && eltype(ps) <: AbstractArray
ps = Iterators.flatten(ps)
end
for p in ps
p = unwrap(p)
if iscall(p) && operation(p) == getindex
p = arguments(p)[1]
end
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
scal = collect(p)
# all scalarized variables are in `ps`
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
(haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue

idx = parameter_index(sys, p)
idx isa Int && continue
if idx isa ParameterIndex
if idx.portion != SciMLStructures.Tunable()
continue
end
array_parameters[p] = (vec(idx.idx), 1, size(idx.idx))
uind = 1
for (i, x) in enumerate(dvs)
iscall(x) && operation(x) == getindex || continue
arg = arguments(x)[1]
inds = get!(() -> [], var_to_arridxs, arg)
push!(inds, (uind, i))
end
end
p_start = uind + 1 + history
rps = (reorder_parameters(sys, ps)..., cachesyms...)
if inputs !== nothing
rps = (inputs, rps...)
end
for sym in reduce(vcat, rps; init = [])
iscall(sym) && operation(sym) == getindex || continue
arg = arguments(sym)[1]

bufferidx = findfirst(buf -> any(isequal(sym), buf), rps)
idxinbuffer = findfirst(isequal(sym), rps[bufferidx])
inds = get!(() -> [], var_to_arridxs, arg)
push!(inds, (p_start + bufferidx - 1, idxinbuffer))
end

viewsyms = Dict()
splitsyms = Dict()
for (arrsym, idxs) in var_to_arridxs
length(idxs) == length(arrsym) || continue
# allequal(first, idxs) is a 1.11 feature
if allequal(Iterators.map(first, idxs))
viewsyms[arrsym] = (first(first(idxs)), reshape(last.(idxs), size(arrsym)))
else
# idx === nothing
idxs = map(Base.Fix1(parameter_index, sys), scal)
if first(idxs) isa ParameterIndex
buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs)
if allequal(buffer_idxs)
buffer_idx = first(buffer_idxs)
if first(idxs).portion == SciMLStructures.Tunable()
idxs = map(x -> x.idx, idxs)
else
idxs = map(x -> x.idx[end], idxs)
end
else
other_array_parameters[p] = scal
continue
end
else
buffer_idx = 1
end

sz = size(idxs)
if vec(idxs) == idxs[begin]:idxs[end]
idxs = idxs[begin]:idxs[end]
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
idxs = idxs[begin]:-1:idxs[end]
end
idxs = vec(idxs)
array_parameters[p] = (idxs, buffer_idx, sz)
splitsyms[arrsym] = reshape(idxs, size(arrsym))
end
end

inputind = if history
uind + 2
else
uind + 1
end
params_offset = if history && hasinputs
uind + 2
elseif history || hasinputs
uind + 1
else
uind
end
if isscalar
function (expr)
Func(
expr.args,
[],
Let(
vcat(
[k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k ← :(view($(expr.args[inputind].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[params_offset + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
[sym ← :(view($(expr.args[i].name), $idxs))
for (sym, (i, idxs)) in viewsyms],
[sym ←
MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs],
expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms]
),
expr.body,
false
Expand All @@ -371,15 +298,11 @@ function wrap_array_vars(
[],
Let(
vcat(
[k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k ← :(view($(expr.args[inputind].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[params_offset + buffer_idx].name), $idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
[sym ← :(view($(expr.args[i].name), $idxs))
for (sym, (i, idxs)) in viewsyms],
[sym ←
MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs],
expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms]
),
expr.body,
false
Expand All @@ -392,17 +315,11 @@ function wrap_array_vars(
[],
Let(
vcat(
[k ← :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k ← :(view($(expr.args[inputind + 1].name), $v))
for (k, v) in input_vars],
[k ← :(reshape(
view($(expr.args[params_offset + buffer_idx + 1].name),
$idxs),
$sz))
for (k, (idxs, buffer_idx, sz)) in array_parameters],
[k ← Code.MakeArray(v, symtype(k))
for (k, v) in other_array_parameters]
[sym ← :(view($(expr.args[i + 1].name), $idxs))
for (sym, (i, idxs)) in viewsyms],
[sym ← MakeArray(
[expr.args[bufi + 1].elems[vali] for (bufi, vali) in idxs],
expr.args[idxs[1][1] + 1]) for (sym, idxs) in splitsyms]
),
expr.body,
false
Expand Down
Loading
Loading