diff --git a/.gitignore b/.gitignore index 3ca71c5..b0e0865 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ Manifest.toml # Local reports (analysis, status reports, previews) should not be tracked reports/ + +# claude +CLAUDE.local.md diff --git a/Project.toml b/Project.toml index 3ca7759..44cd493 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CTParser" uuid = "32681960-a1b1-40db-9bff-a1ca817385d1" -version = "0.8.0" +version = "0.7.9" authors = ["Jean-Baptiste Caillau "] [deps] diff --git a/src/onepass.jl b/src/onepass.jl index 5e3aa3c..57e3b17 100644 --- a/src/onepass.jl +++ b/src/onepass.jl @@ -96,19 +96,6 @@ function e_prefix!(p) return nothing end -# Utils - -""" -$(TYPEDSIGNATURES) - -Generate a fresh symbol by concatenating the given components and a -`gensym()` suffix. - -This is used throughout the parser to create unique internal names that -do not collide with user-defined identifiers. -""" -__symgen(s...) = Symbol(s..., gensym()) - """ $(TYPEDEF) @@ -200,28 +187,6 @@ __wrap(e, n, line) = quote end end -""" -$(TYPEDSIGNATURES) - -Return `true` if `x` represents a range. - -This predicate is specialised for `AbstractRange` values and for -expressions of the form `i:j` or `i:p:j`. -""" -is_range(x) = false -is_range(x::T) where {T<:AbstractRange} = true -is_range(x::Expr) = (x.head == :call) && (x.args[1] == :(:)) - -""" -$(TYPEDSIGNATURES) - -Return `x` itself if it is a range, or a one-element array `[x]`. - -This is a normalisation helper used when interpreting constraint -indices. -""" -as_range(x) = is_range(x) ? x : [x] - # Main code """ @@ -580,7 +545,10 @@ function p_state_exa!(p, p_ocp, x, n, xx; components_names=nothing) )) code = __wrap(code, p.lnum, p.line) dyn_con = Symbol(:dyn_con, x) # name for the constraints associated with the dynamics - code = :($x = $code; $dyn_con = Vector{$pref.Constraint}(undef, $n)) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope) + code = quote + $x = $code + $dyn_con = Vector{$pref.Constraint}(undef, $n) # affectation must be done outside try ... catch (otherwise declaration known only to try local scope) + end return code end @@ -727,18 +695,21 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) code = :(length($e1) == length($e3) == 1 || throw("this constraint must be scalar")) # (vs. __throw) since raised at runtime x0 = __symgen(:x0) xf = __symgen(:xf) + k = __symgen(:k) e2 = replace_call(e2, p.x, p.t0, x0) e2 = replace_call(e2, p.x, p.tf, xf) e2 = subs2(e2, x0, p.x, 0) + e2 = subs(e2, x0, :([$(p.x)[$k, 0] for $k ∈ 1:$(p.dim_x)])) e2 = subs2(e2, xf, p.x, :grid_size) - concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1), ucon=($e3)))) + e2 = subs(e2, xf, :([$(p.x)[$k, grid_size] for $k ∈ 1:$(p.dim_x)])) + concat(code, :($pref.constraint($p_ocp, $e2; lcon=($e1), ucon=($e3)))) # debug: to vectorise end (:initial, rg) => begin if isnothing(rg) rg = :(1:($(p.dim_x))) # x(t0) implies rg == nothing but means x[1:p.dim_x](t0) e2 = subs(e2, p.x, :($(p.x)[$rg])) elseif !is_range(rg) - rg = as_range(rg) + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -757,7 +728,7 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) rg = :(1:($(p.dim_x))) e2 = subs(e2, p.x, :($(p.x)[$rg])) elseif !is_range(rg) - rg = as_range(rg) + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -776,7 +747,7 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) rg = :(1:($(p.dim_v))) e2 = subs(e2, p.v, :($(p.v)[$rg])) elseif !is_range(rg) - rg = as_range(rg) + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code_box = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -791,10 +762,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) end (:state_range, rg) => begin if isnothing(rg) - rg = :(1:($(p.dim_x))) - e2 = subs(e2, p.x, :($(p.x)[$rg])) + rg = :(1:($(p.dim_x))) # NB. no need to update e2 (unused) here elseif !is_range(rg) - rg = as_range(rg) + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code_box = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -809,10 +779,9 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) end (:control_range, rg) => begin if isnothing(rg) - rg = :(1:($(p.dim_u))) - e2 = subs(e2, p.u, :($(p.u)[$rg])) + rg = :(1:($(p.dim_u))) # NB. no need to update e2 (unused here) elseif !is_range(rg) - rg = as_range(rg) + rg = as_range(rg) # case rg = i (vs i:j or i:p:j) end code_box = :( length($e1) == length($e3) == length($rg) || throw("wrong bound dimension") @@ -831,8 +800,11 @@ function p_constraint_exa!(p, p_ocp, e1, e2, e3, c_type, label) ut = __symgen(:ut) e2 = replace_call(e2, [p.x, p.u], p.t, [xt, ut]) j = __symgen(:j) + k = __symgen(:k) e2 = subs2(e2, xt, p.x, j) + e2 = subs(e2, xt, :([$(p.x)[$k, $j] for $k ∈ 1:$(p.dim_x)])) e2 = subs2(e2, ut, p.u, j) + e2 = subs(e2, ut, :([$(p.u)[$k, $j] for $k ∈ 1:$(p.dim_u)])) e2 = subs(e2, p.t, :($(p.t0) + $j * $(p.dt))) concat( code, @@ -931,26 +903,33 @@ function p_dynamics_coord_exa!(p, p_ocp, x, i, t, e) j1 = __symgen(:j) j2 = :($j1 + 1) j12 = :($j1 + 0.5) + k = __symgen(:k) ej1 = subs2(e, xt, p.x, j1) + ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k ∈ 1:$(p.dim_x)])) ej1 = subs2(ej1, ut, p.u, j1) + ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt))) ej2 = subs2(e, xt, p.x, j2) + ej2 = subs(ej2, xt, :([$(p.x)[$k, $j2] for $k ∈ 1:$(p.dim_x)])) ej2 = subs2(ej2, ut, p.u, j2) + ej2 = subs(ej2, ut, :([$(p.u)[$k, $j2] for $k ∈ 1:$(p.dim_u)])) ej2 = subs(ej2, p.t, :($(p.t0) + $j2 * $(p.dt))) - ej12 = subs5(e, xt, p.x, j1) + ej12 = subs2m(e, xt, p.x, j1) + ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k ∈ 1:$(p.dim_x)])) ej12 = subs2(ej12, ut, p.u, j1) + ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt))) dxij = :($(p.x)[$i, $j2] - $(p.x)[$i, $j1]) code = quote if scheme == :euler - $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1)) + $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej1 for $j1 in 0:grid_size-1) elseif scheme ∈ (:euler_implicit, :euler_b) # euler_b is deprecated - $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:(grid_size - 1)) + $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej2 for $j1 in 0:grid_size-1) elseif scheme == :midpoint - $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1)) + $pref.constraint($p_ocp, $dxij - $(p.dt) * $ej12 for $j1 in 0:grid_size-1) elseif scheme ∈ (:trapeze, :trapezoidal) # trapezoidal is deprecated $pref.constraint( - $p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:(grid_size - 1) + $p_ocp, $dxij - $(p.dt) * ($ej1 + $ej2) / 2 for $j1 in 0:grid_size-1 ) else throw( @@ -1001,22 +980,27 @@ function p_lagrange_exa!(p, p_ocp, e, type) j1 = __symgen(:j) j2 = :($j1 + 1) j12 = :($j1 + 0.5) + k = __symgen(:k) ej1 = subs2(e, xt, p.x, j1) + ej1 = subs(ej1, xt, :([$(p.x)[$k, $j1] for $k ∈ 1:$(p.dim_x)])) ej1 = subs2(ej1, ut, p.u, j1) + ej1 = subs(ej1, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej1 = subs(ej1, p.t, :($(p.t0) + $j1 * $(p.dt))) - ej12 = subs5(e, xt, p.x, j1) + ej12 = subs2m(e, xt, p.x, j1) + ej12 = subs(ej12, xt, :([(($(p.x)[$k, $j1] + $(p.x)[$k, $j1 + 1]) / 2) for $k ∈ 1:$(p.dim_x)])) ej12 = subs2(ej12, ut, p.u, j1) + ej12 = subs(ej12, ut, :([$(p.u)[$k, $j1] for $k ∈ 1:$(p.dim_u)])) ej12 = subs(ej12, p.t, :($(p.t0) + $j12 * $(p.dt))) code = quote if scheme == :euler - $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:(grid_size - 1)) + $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 0:grid_size-1) elseif scheme ∈ (:euler_implicit, :euler_b) # euler_b is deprecated $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size) elseif scheme == :midpoint - $pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:(grid_size - 1)) + $pref.objective($p_ocp, $(p.dt) * $ej12 for $j1 in 0:grid_size-1) elseif scheme ∈ (:trapeze, :trapezoidal) # trapezoidal is deprecated $pref.objective($p_ocp, $(p.dt) * $ej1 / 2 for $j1 in (0, grid_size)) - $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:(grid_size - 1)) + $pref.objective($p_ocp, $(p.dt) * $ej1 for $j1 in 1:grid_size-1) else throw( "unknown numerical scheme: $scheme (possible choices are :euler, :euler_implicit, :midpoint, :trapeze)", @@ -1062,10 +1046,13 @@ function p_mayer_exa!(p, p_ocp, e, type) pref = prefix_exa() x0 = __symgen(:x0) xf = __symgen(:xf) + k = __symgen(:k) e = replace_call(e, p.x, p.t0, x0) e = replace_call(e, p.x, p.tf, xf) e = subs2(e, x0, p.x, 0) + e = subs(e, x0, :([$(p.x)[$k, 0] for $k ∈ 1:$(p.dim_x)])) e = subs2(e, xf, p.x, :grid_size) + e = subs(e, xf, :([$(p.x)[$k, grid_size] for $k ∈ 1:$(p.dim_x)])) # now, x[i](t0) has been replaced by x[i, 0] and x[i](tf) by x[i, grid_size] code = :($pref.objective($p_ocp, $e)) return __wrap(code, p.lnum, p.line) @@ -1295,7 +1282,7 @@ function def_fun(e; log=false) $p_ocp = $pref.PreModel() $code $pref.definition!($p_ocp, $ee) - $pref.time_dependence!($p_ocp; autonomous=$p.is_autonomous) + $pref.time_dependence!($p_ocp; autonomous=$p.is_autonomous) # not $(p.xxxx) as this info is known statically end if is_active_backend(:exa) @@ -1383,7 +1370,7 @@ function def_exa(e; log=false) $(p.box_u) # lvar and uvar for control $(p.box_v) # lvar and uvar for variable (after x and u for compatibility with CTDirect) $p_ocp = $pref.ExaCore( - base_type; backend=backend, minimize=($p.criterion == :min) + base_type; backend=backend, minimize=($p.criterion == :min) # not $(p.xxxx) as this info is known statically ) $code $dyn_check diff --git a/src/utils.jl b/src/utils.jl index 4e315d6..04a7d2c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,6 +4,39 @@ """ $(TYPEDSIGNATURES) +Generate a fresh symbol by concatenating the given components and a +`gensym()` suffix. + +This is used throughout the parser to create unique internal names that +do not collide with user-defined identifiers. +""" +__symgen(s...) = Symbol(s..., gensym()) + +""" +$(TYPEDSIGNATURES) + +Return `true` if `x` represents a range. + +This predicate is specialised for `AbstractRange` values and for +expressions of the form `i:j` or `i:p:j`. +""" +is_range(x) = false +is_range(x::T) where {T<:AbstractRange} = true +is_range(x::Expr) = (x.head == :call) && (x.args[1] == :(:)) + +""" +$(TYPEDSIGNATURES) + +Return `x` itself if it is a range, or a one-element array `[x]`. + +This is a normalisation helper used when interpreting constraint +indices. +""" +as_range(x) = is_range(x) ? x : [x] + +""" +$(TYPEDSIGNATURES) + Expr iterator: apply `_Expr` to nodes and `f` to leaves of the AST. # Example @@ -64,39 +97,47 @@ end """ $(TYPEDSIGNATURES) -Substitute x[i] by y[i, j], whatever i, in e. See also: subs5. +Substitute occurrences of symbol `x` in expression `e` with indexed access to `y` at time index `j`. +Handles two patterns: +- `x[i]` (scalar index) → `y[i, j]` +- `x[1:3]` (range index) → `[y[k, j] for k ∈ 1:3]` + +See also: subs2m. # Examples ```@example +julia> # Scalar indexing julia> e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) -:(x0[1] * (2 * xf[3]) - cos(xf[2]) * (2 * x0[2])) - julia> subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) -julia> e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) -:(x0 * (2 * xf[3]) - cos(xf) * (2 * x0[2])) +julia> # Range indexing +julia> e = :(x0[1:3]) +julia> subs2(e, :x0, :x, 0; k = :k) +:([x[k, 0] for k ∈ 1:3]) +julia> # Bare symbols are not substituted +julia> e = :(x0 * 2xf[3]) julia> subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) -:(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) +:(x0 * (2 * x[3, N])) ``` """ -function subs2(e, x, y, j) +function subs2(e, x, y, j; k = __symgen(:k)) foo(x, y, j) = (h, args...) -> begin f = Expr(h, args...) @match f begin - :($xx[$i]) && if (xx == x) - end => :($y[$i, $j]) + :($xx[$rg]) && if ((xx == x) && is_range(rg)) end => :([$y[$k, $j] for $k ∈ $rg]) + :($xx[$i]) && if (xx == x) end => :($y[$i, $j]) _ => f end end - expr_it(e, foo(x, y, j), x -> x) + expr_it(e, foo(x, y, j), x -> x) end """ $(TYPEDSIGNATURES) -Substitute x[rg] by y[i, j], whatever rg, in e. +Substitute x[rg] by y[i, j], whatever rg, in e. (Note: rg is then expected to be used to loop on i.) # Examples ```@example @@ -125,59 +166,42 @@ end """ $(TYPEDSIGNATURES) -Substitute x[rg] by y[i], whatever rg, in e. +Substitute x[i] or x[rg] in e for the midpoint scheme: +- x[i] → (y[i, j] + y[i, j + 1]) / 2 (scalar indexing) +- x[rg] → [(y[k, j] + y[k, j + 1]) / 2 for k ∈ rg] (range indexing) -# Examples -```@example -julia> e = :(v[1:2:d] * 2xf[1:3]) -:(v[1:2:d] * (2 * xf[1:3])) +Bare symbols like x (without indexing) are NOT substituted. -julia> subs4(e, :v, :v, :i) -:(v[i] * (2 * xf[1:3])) - -julia> subs4(e, :xf, :xf, 1) -:(v[1:2:d] * (2 * xf[1])) -``` -""" -function subs4(e, x, y, i) # currently unused - foo(x, y, i) = (h, args...) -> begin - f = Expr(h, args...) - @match f begin - :($xx[$rg]) && if (xx == x) - end => :($y[$i]) - _ => f - end - end - expr_it(e, foo(x, y, i), x -> x) -end - -""" -$(TYPEDSIGNATURES) - -Substitute x[i] by (y[i, j] + y[i, j + 1]) / 2, whatever i, in e. See also: subs2. +See also: subs2. # Examples ```@example julia> e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) :(x0[1] * (2 * xf[3]) - cos(xf[2]) * (2 * x0[2])) -julia> subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) +julia> subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) :(((x[1, 0] + x[1, 0 + 1]) / 2) * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - cos((x[2, N] + x[2, N + 1]) / 2) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2))) julia> e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) :(x0 * (2 * xf[3]) - cos(xf) * (2 * x0[2])) -julia> subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) +julia> subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) :(x0 * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - cos(xf) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2))) + +julia> e = :(x0[1:3]) +:(x0[1:3]) + +julia> subs2m(e, :x0, :x, 0) +:([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3]) ``` """ -function subs5(e, x, y, j) +function subs2m(e, x, y, j; k = __symgen(:k)) foo(x, y, j) = (h, args...) -> begin f = Expr(h, args...) @match f begin - :($xx[$i]) && if (xx == x) - end => :(($y[$i, $j] + $y[$i, $j + 1]) / 2) + :($xx[$rg]) && if ((xx == x) && is_range(rg)) end => :([($y[$k, $j] + $y[$k, $j + 1]) / 2 for $k ∈ $rg]) + :($xx[$i]) && if (xx == x) end => :(($y[$i, $j] + $y[$i, $j + 1]) / 2) _ => f end end diff --git a/test/runtests.jl b/test/runtests.jl index a38ceba..7a9a955 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,9 +5,8 @@ import CTParser: CTParser, subs, subs2, + subs2m, subs3, - subs4, - subs5, replace_call, has, concat, diff --git a/test/test_onepass_exa.jl b/test/test_onepass_exa.jl index 19c9e4f..ded1900 100644 --- a/test/test_onepass_exa.jl +++ b/test/test_onepass_exa.jl @@ -104,6 +104,538 @@ function __test_onepass_exa( @test CTParser.as_range(:(x + 1)) == [:(x + 1)] end + test_name = "bare symbols and ranges - costs ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Test: Lagrange with sum over all state components + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + x(0) == [1, 2, 3] + x(1) == [4, 5, 6] + ∫(sum(x(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Lagrange with sum over range of states + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + ∫(sum(x[1:2](t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Lagrange with sum over all controls + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∫(sum(u(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Lagrange with sum over range of controls + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∫(sum(u[1:2](t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over all states at t0 + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(0))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over all states at tf + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(1))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over range at t0 + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](0))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mayer with sum over range at tf + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[2:3](1))^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Bolza cost with bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + (sum(x(0))^2 + sum(x(1))^2) + ∫(sum(u(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Bolza cost with ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + (sum(x[1:2](0)) + sum(x[2:3](1))) + ∫(sum(u[1:2](t))) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + + test_name = "bare symbols and ranges - constraints ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Test: Initial constraint with bare symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(0)) == 6 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Initial constraint with range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](0)) == 3 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Final constraint with bare symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(1)) == 15 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Final constraint with range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[2:3](1)) == 11 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Boundary constraint combining t0 and tf with bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(0)) + sum(x(1)) == 21 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Boundary constraint with ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](0)) - sum(x[2:3](1)) == -8 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with bare state symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(t))^2 ≤ 100 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with state range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x[1:2](t)) ≤ 10 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with bare control symbol + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + sum(u(t))^2 ≤ 5 + ∫(x₁(t)^2 + x₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Path constraint with control range + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + sum(u[1:2](t)) ≤ 3 + ∫(x₁(t)^2 + x₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mixed constraint with bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + sum(x(t)) + sum(u(t)) ≤ 15 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Mixed constraint with ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R³, control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₃(t) + sum(x[1:2](t)) + sum(u[2:3](t)) ≤ 8 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + + test_name = "bare symbols and ranges - dynamics ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Test: Dynamics with sum over all states + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == sum(x(t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with sum over state range + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == sum(x[2:3](t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with sum over all controls + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == sum(u(t)) + ∂(x₂)(t) == u₁(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with sum over control range + o = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R³, control + ∂(x₁)(t) == sum(u[1:2](t)) + ∂(x₂)(t) == u₃(t) + ∫(u₁(t)^2 + u₂(t)^2 + u₃(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with mixed bare symbols + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == sum(x(t)) + sum(u(t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Dynamics with mixed ranges + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R³, control + ∂(x₁)(t) == sum(x[1:2](t)) + sum(u[2:3](t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + ∫(u₁(t)^2 + u₂(t)^2 + u₃(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + + test_name = "user-defined functions with ranges ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + # Define user functions outside @def + f(x, u) = x[1] * x[3] + u[1]^2 * cos(u[2]) + g(x) = x[1] + 2 * x[2] + h(u) = u[1]^2 + sin(u[2]) + + # Test: User-defined function in Lagrange cost + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + ∫(f(x(t), u(t))^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in Mayer cost at t0 + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(0), [0, 0])^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in Mayer cost at tf + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(1), [0, 0])^2 → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in Bolza cost + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + (f(x(0), [0, 0]) + f(x(1), [0, 0])) + ∫(h(u(t))) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in initial constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(0), [0, 0]) == 5 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in final constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(1), [0, 0]) == 10 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in boundary constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(0), [0, 0]) + f(x(1), [0, 0]) == 15 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in path constraint + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == u₁(t) + u₂(t) + f(x(t), u(t)) ≤ 10 + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: User-defined function in dynamics + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == u₁(t) + ∂(x₂)(t) == u₂(t) + ∂(x₃)(t) == g(x[1:2](t)) + ∫(u₁(t)^2 + u₂(t)^2) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + + # Test: Multiple user-defined functions + o = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + ∂(x₁)(t) == g(x[1:2](t)) + ∂(x₂)(t) == u₁(t) + ∂(x₃)(t) == u₂(t) + h(u(t)) ≤ 5 + (f(x(0), [0, 0])) + ∫(f(x(t), u(t))) → min + end + m = discretise_exa(o; backend=backend, scheme=scheme) + @test m isa ExaModels.ExaModel + end + test_name = "pragma ($backend_name, $scheme)" @testset "$test_name" begin println(test_name) @@ -1005,4 +1537,58 @@ function __test_onepass_exa( sol = madnlp(m; tol=tolerance, kwargs...) @test sol.status == MadNLP.SOLVE_SUCCEEDED end -end + + test_name = "use case no. 4: vectorised ($backend_name, $scheme)" + @testset "$test_name" begin + println(test_name) + + f₁(x, u) = 2x[1] * u[1] + x[2] * u[2] + f₂(x) = x[1] + 2x[2] - x[3] + f₃(x0, xf) = x0[2]^2 + sum(xf)^2 + f₄(u) = sum(u.^2) + + o1 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x[1:2:3](0) == [1, 3] + + ∂(x₁)(t) == sum(x(t)) + ∂(x₂)(t) == f₁(x(t), u(t)) + ∂(x₃)(t) == f₂(x(t)) + + f₃(x(0), x(1)) + 0.5∫( f₄(u(t)) ) → min + end + + N = 250 + m1, _ = discretise_exa_full(o1; grid_size=N, backend=backend, scheme=scheme) + @test m1 isa ExaModels.ExaModel + sol1 = madnlp(m1; tol=tolerance, kwargs...) + @test sol1.status == MadNLP.SOLVE_SUCCEEDED + obj1 = sol1.objective + + o2 = @def begin + t ∈ [0, 1], time + x ∈ R³, state + u ∈ R², control + + x[1:2:3](0) == [1, 3] + + ∂(x₁)(t) == x₁(t) + x₂(t) + x₃(t) + ∂(x₂)(t) == 2x₁(t) * u₁(t) + x₂(t) * u₂(t) + ∂(x₃)(t) == x₁(t) + 2x₂(t) - x₃(t) + + (x₂(0)^2 + (x₁(1) + x₂(1) + x₃(1))^2) + 0.5∫( u₁(t)^2 + u₂(t)^2 ) → min + end + + m2, _ = discretise_exa_full(o2; grid_size=N, backend=backend, scheme=scheme) + @test m2 isa ExaModels.ExaModel + sol2 = madnlp(m2; tol=tolerance, kwargs...) + @test sol2.status == MadNLP.SOLVE_SUCCEEDED + obj2 = sol2.objective + + __atol = 1e-6 + @test obj1 ≈ obj2 atol = __atol + end +end \ No newline at end of file diff --git a/test/test_utils.jl b/test/test_utils.jl index effa2d8..580a11e 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -27,13 +27,118 @@ function test_utils() @testset "subs2" begin println("subs2") - e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) - @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == - :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) + # ===== EXISTING FUNCTIONALITY (scalar indexing) ===== - e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) - @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == - :(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) + @testset "scalar indexing (existing)" begin + # Test 1: Basic scalar substitution + e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) + + # Test 2: Bare symbols are NOT substituted + e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) + + # Test 3: Numeric index + e = :(x0[5] + x0[10]) + @test subs2(e, :x0, :x, 0) == :(x[5, 0] + x[10, 0]) + + # Test 4: Symbolic index + e = :(x0[i] + x0[j]) + @test subs2(e, :x0, :x, 0) == :(x[i, 0] + x[j, 0]) + end + + # ===== NEW FUNCTIONALITY (range indexing) ===== + + @testset "range indexing (new)" begin + # Test 5: Simple range 1:3 + e = :(x0[1:3]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:3]) + + # Test 6: Range with step 1:2:5 + e = :(x0[1:2:5]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:2:5]) + + # Test 7: Range with symbolic bounds + e = :(x0[1:n]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:n]) + + # Test 8: Multiple ranges in same expression + e = :(x0[1:3] + xf[2:4]) + result = subs2(subs2(e, :x0, :x, 0; k = :k1), :xf, :x, :N; k = :k2) + @test result == :([x[k1, 0] for k1 ∈ 1:3] + [x[k2, N] for k2 ∈ 2:4]) + + # Test 9: Range inside function call + e = :(sum(x0[1:n])) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(sum([x[k, 0] for k ∈ 1:n])) + end + + @testset "mixed scalar and range" begin + # Test 10: Expression with both scalars and ranges + e = :(x0[1] + x0[2:4] + x0[5]) + result = subs2(e, :x0, :x, 0; k = :k) + # x0[1] → x[1, 0] + # x0[2:4] → [x[k, 0] for k ∈ 2:4] + # x0[5] → x[5, 0] + @test result == :(x[1, 0] + [x[k, 0] for k ∈ 2:4] + x[5, 0]) + end + + @testset "nested and complex expressions" begin + # Test 11: Nested function calls with ranges + e = :(norm(x0[1:3]) + cos(x0[4])) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(norm([x[k, 0] for k ∈ 1:3]) + cos(x[4, 0])) + + # Test 12: Range in matrix operations + e = :(A * x0[1:n]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(A * [x[k, 0] for k ∈ 1:n]) + + # Test 13: Multiple substitutions with symbolic j + e = :(x0[1:3] + xf[2:4]) + result = subs2(subs2(e, :x0, :x, :j; k = :k1), :xf, :x, :(j+1); k = :k2) + @test result == :([x[k1, j] for k1 ∈ 1:3] + [x[k2, j+1] for k2 ∈ 2:4]) + end + + @testset "edge cases" begin + # Test 14: Single-element range (should still create comprehension) + e = :(x0[1:1]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :([x[k, 0] for k ∈ 1:1]) + + # Test 15: Wrong variable name (should not substitute) + e = :(y0[1:3]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == e # Unchanged + + # Test 16: Complex symbolic j expression + e = :(x0[1:3]) + result = subs2(e, :x0, :x, :grid_size; k = :k) + @test result == :([x[k, grid_size] for k ∈ 1:3]) + + # Test 17: Scalar index that is a range expression (should not match) + # This tests that we properly distinguish i (scalar) from 1:3 (range) + e = :(x0[i]) + result = subs2(e, :x0, :x, 0; k = :k) + @test result == :(x[i, 0]) # Scalar behavior + end + + @testset "backward compatibility" begin + # Test 18: Scalar indexing still works + e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x[1, 0] * (2 * x[3, N]) - cos(x[2, N]) * (2 * x[2, 0])) + + # Test 19: Bare symbols are NOT substituted + e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) + @test subs2(subs2(e, :x0, :x, 0), :xf, :x, :N) == + :(x0 * (2 * x[3, N]) - cos(xf) * (2 * x[2, 0])) + end end @testset "subs3" begin @@ -44,28 +149,59 @@ function test_utils() @test subs3(e, :xf, :x, 1, :N) == :(x0[1:2:d] * (2 * x[1, N])) end - @testset "subs4" begin - println("subs4") + @testset "subs2m" begin + println("subs2m") - e = :(v[1:2:d] * 2xf[1:3]) - @test subs4(e, :v, :v, :i) == :(v[i] * (2 * xf[1:3])) - @test subs4(e, :xf, :xf, 1) == :(v[1:2:d] * (2 * xf[1])) - end + @testset "range indexing" begin + # Test 1: Basic range substitution + e = :(x0[1:3]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3]) + + # Test 2: Range with step + e = :(x0[1:2:5]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:2:5]) + + # Test 3: Range in arithmetic expression + e = :(2 * x0[1:3]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :(2 * [((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:3]) - @testset "subs5" begin - println("subs5") + # Test 4: Multiple ranges in same expression + e = :(x0[1:2] + xf[2:4]) + result = subs2m(subs2m(e, :x0, :x, 0; k = :k), :xf, :x, :N; k = :k) + @test result == :( + [((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 1:2] + + [((x[k, N] + x[k, N + 1]) / 2) for k ∈ 2:4] + ) - e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) - @test subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) == :( - ((x[1, 0] + x[1, 0 + 1]) / 2) * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - - cos((x[2, N] + x[2, N + 1]) / 2) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) - ) + # Test 5: Range with symbolic j + e = :(x0[1:3]) + result = subs2m(e, :x0, :x, :j; k = :k) + @test result == :([((x[k, j] + x[k, j + 1]) / 2) for k ∈ 1:3]) - e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) - @test subs5(subs5(e, :x0, :x, 0), :xf, :x, :N) == :( - x0 * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - - cos(xf) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) - ) + # Test 6: Single-element range + e = :(x0[2:2]) + result = subs2m(e, :x0, :x, 0; k = :k) + @test result == :([((x[k, 0] + x[k, 0 + 1]) / 2) for k ∈ 2:2]) + end + + @testset "backward compatibility" begin + # Test 7: Scalar indexing still works + e = :(x0[1] * 2xf[3] - cos(xf[2]) * 2x0[2]) + @test subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) == :( + ((x[1, 0] + x[1, 0 + 1]) / 2) * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - + cos((x[2, N] + x[2, N + 1]) / 2) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) + ) + + # Test 8: Bare symbols are NOT substituted + e = :(x0 * 2xf[3] - cos(xf) * 2x0[2]) + @test subs2m(subs2m(e, :x0, :x, 0), :xf, :x, :N) == :( + x0 * (2 * ((x[3, N] + x[3, N + 1]) / 2)) - + cos(xf) * (2 * ((x[2, 0] + x[2, 0 + 1]) / 2)) + ) + end end @testset "replace_call" begin diff --git a/test/test_utils_bis.jl b/test/test_utils_bis.jl index 4b39ef3..6f48764 100644 --- a/test/test_utils_bis.jl +++ b/test/test_utils_bis.jl @@ -78,16 +78,15 @@ function test_utils_bis() @test constraint_type(e, t, t0, tf, x, u, v) == :variable_fun end - @testset "subs2/3/4/5 (pathological cases)" begin - println("subs2/3/4/5 (bis)") + @testset "subs2/2m/3 (pathological cases)" begin + println("subs2/2m/3 (bis)") e = :(x0[1] * 2xf[3]) # symbol does not appear at all → expression unchanged @test subs2(e, :z, :x, 0) == e + @test subs2m(e, :z, :x, 0) == e @test subs3(e, :z, :x, :i, 0) == e - @test subs4(e, :z, :z, :i) == e - @test subs5(e, :z, :x, 0) == e end @testset "replace_call (errors)" begin