Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into MTK
Browse files Browse the repository at this point in the history
  • Loading branch information
vyudu committed Dec 1, 2024
2 parents b3da813 + e9fe9a1 commit 86c82ce
Show file tree
Hide file tree
Showing 23 changed files with 409 additions and 44 deletions.
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
group:
- InterfaceI
- InterfaceII
- Initialization
- SymbolicIndexingInterface
- Extended
- Extensions
Expand Down
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ModelingToolkit"
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
authors = ["Yingbo Ma <[email protected]>", "Chris Rackauckas <[email protected]> and contributors"]
version = "9.53.0"
version = "9.54.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down Expand Up @@ -64,6 +64,7 @@ BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
HomotopyContinuation = "f213a82b-91d6-5c5d-acf7-10f1c761b327"
InfiniteOpt = "20393b10-9daf-11e9-18c9-8db751c92c57"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"

[extensions]
Expand All @@ -72,6 +73,7 @@ MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"

[compat]
AbstractTrees = "0.3, 0.4"
Expand Down Expand Up @@ -104,6 +106,7 @@ FunctionWrappers = "1.1"
FunctionWrappersWrappers = "0.1"
Graphs = "1.5.2"
HomotopyContinuation = "2.11"
InfiniteOpt = "0.5"
InteractiveUtils = "1"
JuliaFormatter = "1.0.47"
JumpProcesses = "9.13.1"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/initialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ the `initialization_eqs` keyword argument, for example:

```@example init
prob = ODEProblem(pend, [x => 1], (0.0, 1.5), [g => 1], guesses = [λ => 0, y => 1],
initialization_eqs = [y ~ 1])
initialization_eqs = [y ~ 0])
sol = solve(prob, Rodas5P())
plot(sol, idxs = (x, y))
```
Expand Down
26 changes: 26 additions & 0 deletions ext/MTKInfiniteOptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module MTKInfiniteOptExt
import ModelingToolkit
import SymbolicUtils
import NaNMath
import InfiniteOpt
import InfiniteOpt: JuMP, GeneralVariableRef

# This file contains method definitions to make it possible to trace through functions generated by MTK using JuMP variables

for ff in [acos, log1p, acosh, log2, asin, tan, atanh, cos, log, sin, log10, sqrt]
f = nameof(ff)
# These need to be defined so that JuMP can trace through functions built by Symbolics
@eval NaNMath.$f(x::GeneralVariableRef) = Base.$f(x)
end

# JuMP variables and Symbolics variables never compare equal. When tracing through dynamics, a function argument can be either a JuMP variable or A Symbolics variable, it can never be both.
function Base.isequal(::SymbolicUtils.Symbolic,
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr})
false
end
function Base.isequal(
::Union{JuMP.GenericAffExpr, JuMP.GenericQuadExpr, InfiniteOpt.AbstractInfOptExpr},
::SymbolicUtils.Symbolic)
false
end
end
8 changes: 8 additions & 0 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@ function tearing_sub(expr, dict, s)
s ? simplify(expr) : expr
end

function tearing_substitute_expr(sys::AbstractSystem, expr; simplify = false)
empty_substitutions(sys) && return expr
substitutions = get_substitutions(sys)
@unpack subs = substitutions
solved = Dict(eq.lhs => eq.rhs for eq in subs)
return tearing_sub(expr, solved, simplify)
end

"""
$(TYPEDSIGNATURES)
Expand Down
57 changes: 41 additions & 16 deletions src/structural_transformation/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,41 @@ end
###
### Structural check
###
function check_consistency(state::TransformationState, orig_inputs)

"""
$(TYPEDSIGNATURES)
Check if the `state` represents a singular system, and return the unmatched variables.
"""
function singular_check(state::TransformationState)
@unpack graph, var_to_diff = state.structure
fullvars = get_fullvars(state)
# This is defined to check if Pantelides algorithm terminates. For more
# details, check the equation (15) of the original paper.
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
map(collect, edges(var_to_diff))])
extended_var_eq_matching = maximal_matching(extended_graph)

nvars = ndsts(graph)
unassigned_var = []
for (vj, eq) in enumerate(extended_var_eq_matching)
vj > nvars && break
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
push!(unassigned_var, fullvars[vj])
end
end
return unassigned_var
end

"""
$(TYPEDSIGNATURES)
Check the consistency of `state`, given the inputs `orig_inputs`. If `nothrow == false`,
throws an error if the system is under-/over-determined or singular. In this case, if the
function returns it will return `true`. If `nothrow == true`, it will return `false`
instead of throwing an error. The singular case will print a warning.
"""
function check_consistency(state::TransformationState, orig_inputs; nothrow = false)
fullvars = get_fullvars(state)
neqs = n_concrete_eqs(state)
@unpack graph, var_to_diff = state.structure
Expand All @@ -72,6 +106,7 @@ function check_consistency(state::TransformationState, orig_inputs)
is_balanced = n_highest_vars == neqs

if neqs > 0 && !is_balanced
nothrow && return false
varwhitelist = var_to_diff .== nothing
var_eq_matching = maximal_matching(graph, eq -> true, v -> varwhitelist[v]) # not assigned
# Just use `error_reporting` to do conditional
Expand All @@ -85,22 +120,12 @@ function check_consistency(state::TransformationState, orig_inputs)
error_reporting(state, bad_idxs, n_highest_vars, iseqs, orig_inputs)
end

# This is defined to check if Pantelides algorithm terminates. For more
# details, check the equation (15) of the original paper.
extended_graph = (@set graph.fadjlist = Vector{Int}[graph.fadjlist;
map(collect, edges(var_to_diff))])
extended_var_eq_matching = maximal_matching(extended_graph)

nvars = ndsts(graph)
unassigned_var = []
for (vj, eq) in enumerate(extended_var_eq_matching)
vj > nvars && break
if eq === unassigned && !isempty(𝑑neighbors(graph, vj))
push!(unassigned_var, fullvars[vj])
end
end
unassigned_var = singular_check(state)

if !isempty(unassigned_var) || !is_balanced
if nothrow
return false
end
io = IOBuffer()
Base.print_array(io, unassigned_var)
unassigned_var_str = String(take!(io))
Expand All @@ -110,7 +135,7 @@ function check_consistency(state::TransformationState, orig_inputs)
throw(InvalidSystemException(errmsg))
end

return nothing
return true
end

###
Expand Down
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3335,7 +3335,8 @@ function parse_variable(sys::AbstractSystem, str::AbstractString)
# I'd write a regex to validate `str`, but https://xkcd.com/1171/
str = strip(str)
derivative_level = 0
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) && endswith(str, ")")
while ((cond1 = startswith(str, "D(")) || startswith(str, "Differential(")) &&
endswith(str, ")")
if cond1
derivative_level += 1
str = _string_view_inner(str, 2, 1)
Expand Down
22 changes: 20 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
checkbounds = false,
initializeprob = nothing,
initializeprobmap = nothing,
initializeprobpmap = nothing,
update_initializeprob! = nothing,
kwargs...) where {iip}
if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`")
Expand Down Expand Up @@ -643,7 +645,9 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
jac_prototype = jac_prototype,
observed = observedfun,
initializeprob = initializeprob,
initializeprobmap = initializeprobmap)
initializeprobmap = initializeprobmap,
initializeprobpmap = initializeprobpmap,
update_initializeprob! = update_initializeprob!)
end

function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...)
Expand Down Expand Up @@ -1387,7 +1391,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
check_length = true,
warn_initialize_determined = true,
initialization_eqs = [],
fully_determined = false,
fully_determined = nothing,
check_units = true,
kwargs...) where {iip, specialize}
if !iscomplete(sys)
Expand All @@ -1405,6 +1409,19 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
end

ts = get_tearing_state(isys)
if warn_initialize_determined &&
(unassigned_vars = StructuralTransformations.singular_check(ts); !isempty(unassigned_vars))
errmsg = """
The initialization system is structurally singular. Guess values may \
significantly affect the initial values of the ODE. The problematic variables \
are $unassigned_vars.
Note that the identification of problematic variables is a best-effort heuristic.
"""
@warn errmsg
end

uninit = setdiff(unknowns(sys), [unknowns(isys); getfield.(observed(isys), :lhs)])

# TODO: throw on uninitialized arrays
Expand Down Expand Up @@ -1448,6 +1465,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
u0T = promote_type(u0T, typeof(fullmap[eq.lhs]))
end
if u0T != Union{}
u0T = eltype(u0T)
u0map = Dict(k => if symbolic_type(v) == NotSymbolic() && !is_array_of_symbolics(v)
v isa AbstractArray ? u0T.(v) : u0T(v)
else
Expand Down
29 changes: 19 additions & 10 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ const TunableIndexMap = Dict{BasicSymbolic,
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
const TimeseriesSetType = Set{Union{ContinuousTimeseries, Int}}

const SymbolicParam = Union{BasicSymbolic, CallWithMetadata}

struct IndexCache
unknown_idx::UnknownIndexMap
# sym => (bufferidx, idx_in_buffer)
discrete_idx::Dict{BasicSymbolic, DiscreteIndex}
discrete_idx::Dict{SymbolicParam, DiscreteIndex}
# sym => (clockidx, idx_in_clockbuffer)
callback_to_clocks::Dict{Any, Vector{Int}}
tunable_idx::TunableIndexMap
Expand All @@ -56,13 +58,13 @@ struct IndexCache
tunable_buffer_size::BufferTemplate
constant_buffer_sizes::Vector{BufferTemplate}
nonnumeric_buffer_sizes::Vector{BufferTemplate}
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
symbol_to_variable::Dict{Symbol, SymbolicParam}
end

function IndexCache(sys::AbstractSystem)
unks = solved_unknowns(sys)
unk_idxs = UnknownIndexMap()
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()
symbol_to_variable = Dict{Symbol, SymbolicParam}()

let idx = 1
for sym in unks
Expand Down Expand Up @@ -95,18 +97,18 @@ function IndexCache(sys::AbstractSystem)

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()

function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
sym = unwrap(sym)
buf = get!(buffers, ctype, S())
push!(buf, sym)
end

disc_param_callbacks = Dict{BasicSymbolic, Set{Int}}()
disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
events = vcat(continuous_events(sys), discrete_events(sys))
for (i, event) in enumerate(events)
discs = Set{BasicSymbolic}()
discs = Set{SymbolicParam}()
affs = affects(event)
if !(affs isa AbstractArray)
affs = [affs]
Expand All @@ -130,26 +132,32 @@ function IndexCache(sys::AbstractSystem)
isequal(only(arguments(sym)), get_iv(sys))
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
push!(clocks, i)
else
elseif is_variable_floatingpoint(sym)
insert_by_type!(constant_buffers, sym, symtype(sym))
else
stype = symtype(sym)
if stype <: FnType
stype = fntype_to_function_type(stype)
end
insert_by_type!(nonnumeric_buffers, sym, stype)
end
end
end
clock_partitions = unique(collect(values(disc_param_callbacks)))
disc_symtypes = unique(symtype.(keys(disc_param_callbacks)))
disc_symtype_idx = Dict(disc_symtypes .=> eachindex(disc_symtypes))
disc_syms_by_symtype = [BasicSymbolic[] for _ in disc_symtypes]
disc_syms_by_symtype = [SymbolicParam[] for _ in disc_symtypes]
for sym in keys(disc_param_callbacks)
push!(disc_syms_by_symtype[disc_symtype_idx[symtype(sym)]], sym)
end
disc_syms_by_symtype_by_partition = [Vector{BasicSymbolic}[] for _ in disc_symtypes]
disc_syms_by_symtype_by_partition = [Vector{SymbolicParam}[] for _ in disc_symtypes]
for (i, buffer) in enumerate(disc_syms_by_symtype)
for partition in clock_partitions
push!(disc_syms_by_symtype_by_partition[i],
[sym for sym in buffer if disc_param_callbacks[sym] == partition])
end
end
disc_idxs = Dict{BasicSymbolic, DiscreteIndex}()
disc_idxs = Dict{SymbolicParam, DiscreteIndex}()
callback_to_clocks = Dict{
Union{SymbolicContinuousCallback, SymbolicDiscreteCallback}, Set{Int}}()
for (typei, disc_syms_by_partition) in enumerate(disc_syms_by_symtype_by_partition)
Expand Down Expand Up @@ -191,6 +199,7 @@ function IndexCache(sys::AbstractSystem)
end
haskey(disc_idxs, p) && continue
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
haskey(nonnumeric_buffers, ctype) && p in nonnumeric_buffers[ctype] && continue
insert_by_type!(
if ctype <: Real || ctype <: AbstractArray{<:Real}
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown() &&
Expand Down
2 changes: 1 addition & 1 deletion src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,7 +923,7 @@ function parse_variable_arg(dict, mod, arg, varclass, kwargs, where_types)
end

push!(varexpr.args, metadata_expr)
return vv isa Num ? name : :($name...), varexpr
return symbolic_type(vv) == ScalarSymbolic() ? name : :($name...), varexpr
else
return vv
end
Expand Down
Loading

0 comments on commit 86c82ce

Please sign in to comment.