diff --git a/ext/MTKExt.jl b/ext/MTKExt.jl index a9ea56b..8f7dab1 100644 --- a/ext/MTKExt.jl +++ b/ext/MTKExt.jl @@ -33,4 +33,8 @@ function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, var::Nu SymbolicIndexingInterface.observed(sys, tosymbol(var; escape=false)) end +function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, vars::Union{Vector{Num}, Tuple{Vararg{Num}}}) + SymbolicIndexingInterface.observed(sys, tosymbol.(vars; escape=false)) +end + end diff --git a/src/symbolic_indexing.jl b/src/symbolic_indexing.jl index 7b44f90..46dcfb8 100644 --- a/src/symbolic_indexing.jl +++ b/src/symbolic_indexing.jl @@ -192,6 +192,14 @@ function SymbolicIndexingInterface.is_observed(sys::PartitionedGraphSystem, sym) haskey(sys.compu_namemap, sym) end +function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, syms::Union{Vector{Symbol}, Tuple{Vararg{Symbol}}}) + function (u, p, t) + map(syms) do sym + observed(sys, sym)(u, p, t) + end + end +end + function SymbolicIndexingInterface.observed(sys::PartitionedGraphSystem, sym) (; tup_index, v_index, prop, requires_inputs) = sys.compu_namemap[sym] diff --git a/test/runtests.jl b/test/runtests.jl index 3d677f5..c568799 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,3 +5,7 @@ using SafeTestsets solution_solve_test() sensitivity_test() end + +@safetestset "SII" begin + include("symbolic_indexing.jl") +end diff --git a/test/symbolic_indexing.jl b/test/symbolic_indexing.jl new file mode 100644 index 0000000..adc9d34 --- /dev/null +++ b/test/symbolic_indexing.jl @@ -0,0 +1,10 @@ +include("particle_osc_example.jl") +using SymbolicIndexingInterface + +@testset "Symbolic Indexing of Vectors" begin + sol = solve_particle_osc(x1=1.0, x2=-1.0) + + a = getsym(sol, :particle1₊a)(sol)[end] + ω = getsym(sol, :osc₊ω₀)(sol)[end] + @test getsym(sol, [:osc₊ω₀, :particle1₊a])(sol)[end] == [ω, a] +end