diff --git a/Project.toml b/Project.toml index b193577..d9524c7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Static" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" authors = ["chriselrod", "ChrisRackauckas", "Tokazama"] -version = "1.3.0" +version = "1.3.1" [deps] CommonWorldInvalidations = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8" diff --git a/src/Static.jl b/src/Static.jl index 207cf94..0e25f60 100644 --- a/src/Static.jl +++ b/src/Static.jl @@ -4,8 +4,7 @@ import IfElse: ifelse using SciMLPublic: @public export StaticInt, StaticFloat64, StaticSymbol, True, False, StaticBool, NDIndex -export dynamic, is_static, known, static, static_promote, static_first, static_step, - static_last +export dynamic, is_static, known, static, static_promote @public OptionallyStaticRange, OptionallyStaticUnitRange, OptionallyStaticStepRange, SUnitRange, SOneTo @@ -353,18 +352,20 @@ julia> static_promote(1:2:9, static(1):static(2):static(9)) static(1):static(2):static(9) ``` """ -Base.@propagate_inbounds @inline function static_promote(x::AbstractUnitRange{<:Integer}, - y::AbstractUnitRange{<:Integer}) - fst = static_promote(static_first(x), static_first(y)) - lst = static_promote(static_last(x), static_last(y)) - return OptionallyStaticUnitRange(fst, lst) -end -Base.@propagate_inbounds @inline function static_promote(x::AbstractRange{<:Integer}, - y::AbstractRange{<:Integer}) - fst = static_promote(static_first(x), static_first(y)) - stp = static_promote(static_step(x), static_step(y)) - lst = static_promote(static_last(x), static_last(y)) - return _OptionallyStaticStepRange(fst, stp, lst) +@inline function static_promote( + x0::AbstractRange{<:Integer}, + y0::AbstractRange{<:Integer}, +) + x = OptionallyStaticStepRange(x0) + y = OptionallyStaticStepRange(y0) + fst = static_promote(getfield(x, :start), getfield(y, :start)) + stp = static_promote(getfield(x, :step), getfield(y, :step)) + lst = static_promote(getfield(x, :stop), getfield(y, :stop)) + if isa(stp, One) + return _OptionallyStaticUnitRange(fst, lst) + else + return _OptionallyStaticStepRange(fst, stp, lst) + end end function static_promote(x::Base.Slice, y::Base.Slice) Base.Slice(static_promote(x.indices, y.indices)) diff --git a/src/ranges.jl b/src/ranges.jl index 3de504e..0b28dd7 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -12,26 +12,9 @@ struct OptionallyStaticUnitRange{F <: IntType, L <: IntType} <: start::F stop::L - function OptionallyStaticUnitRange(start::IntType, - stop::IntType) + global function _OptionallyStaticUnitRange(start::IntType, stop::IntType) new{typeof(start), typeof(stop)}(start, stop) end - function OptionallyStaticUnitRange(start, stop) - OptionallyStaticUnitRange(IntType(start), IntType(stop)) - end - OptionallyStaticUnitRange(@nospecialize x::OptionallyStaticUnitRange) = x - function OptionallyStaticUnitRange(x::AbstractRange) - step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x)) - - errmsg(x) = throw(ArgumentError("step must be 1, got $(step(x))")) # avoid GC frame - errmsg(x) - end - function OptionallyStaticUnitRange{F, L}(x::AbstractRange) where {F, L} - OptionallyStaticUnitRange(x) - end - function OptionallyStaticUnitRange{StaticInt{F}, StaticInt{L}}() where {F, L} - new{StaticInt{F}, StaticInt{L}}() - end end """ @@ -97,9 +80,20 @@ function OptionallyStaticStepRange(@nospecialize(start::IntType), end end OptionallyStaticStepRange(@nospecialize x::OptionallyStaticStepRange) = x +function OptionallyStaticStepRange(x::Union{Base.Slice, Base.IdentityUnitRange}) + OptionallyStaticStepRange(x.indices) +end +function OptionallyStaticStepRange(x::Base.OneTo) + _OptionallyStaticStepRange(static(1), static(1), Int(last(x))) +end +function OptionallyStaticStepRange(x::OptionallyStaticUnitRange) + _OptionallyStaticStepRange(getfield(x, :start), static(1), getfield(x, :stop)) +end +function OptionallyStaticStepRange(x::AbstractUnitRange) + _OptionallyStaticStepRange(Int(first(x)), static(1), Int(last(x))) +end function OptionallyStaticStepRange(x::AbstractRange) - _OptionallyStaticStepRange(IntType(static_first(x)), IntType(static_step(x)), - IntType(static_last(x))) + _OptionallyStaticStepRange(Int(first(x)), Int(step(x)), Int(last(x))) end # to make StepRange constructor inlineable, so optimizer can see `step` value @@ -129,6 +123,37 @@ end end end +OptionallyStaticUnitRange(@nospecialize x::OptionallyStaticUnitRange) = x +OptionallyStaticUnitRange(x::Base.OneTo) = OptionallyStaticUnitRange(static(1), Int(last(x))) +function OptionallyStaticUnitRange(x::Union{Base.Slice, Base.IdentityUnitRange}) + OptionallyStaticUnitRange(x.indices) +end +function OptionallyStaticUnitRange(x::OptionallyStaticStepRange) + assert_unit_step(step(x)) + _OptionallyStaticUnitRange(getfield(x, :start), getfield(x, :stop)) +end +function OptionallyStaticUnitRange(x::AbstractRange) + assert_unit_step(step(x)) + _OptionallyStaticUnitRange(first(x), last(x)) +end +function OptionallyStaticUnitRange{F, L}(x::AbstractRange) where {F, L} + OptionallyStaticUnitRange(x) +end +function OptionallyStaticUnitRange(start::IntType, stop::IntType) + _OptionallyStaticUnitRange(start, stop) +end +function OptionallyStaticUnitRange(start, stop) + OptionallyStaticUnitRange(IntType(start), IntType(stop)) +end +function OptionallyStaticUnitRange{StaticInt{F}, StaticInt{L}}() where {F, L} + _OptionallyStaticUnitRange(StaticInt{F}(), StaticInt{L}()) +end +function assert_unit_step(s::Int) + s == 1 && return nothing + errmsg(x) = throw(ArgumentError(LazyString("step must be 1, got ", s))) # avoid GC frame + errmsg(s) +end + """ SUnitRange(start::Int, stop::Int) @@ -150,82 +175,6 @@ const OptionallyStaticRange{ F, L} = Union{OptionallyStaticUnitRange{F, L}, OptionallyStaticStepRange{F, <:Any, L}} -""" - static_first(x::AbstractRange) - -Attempt to return `static(first(x))`, if known at compile time. Otherwise, return -`first(x)`. - -See also: [`static_step`](@ref), [`static_last`](@ref) - -# Examples - -```julia -julia> static_first(static(2):10) -static(2) - -julia> static_first(1:10) -1 - -julia> static_first(Base.OneTo(10)) -static(1) - -``` -""" -static_first(x::Base.OneTo) = StaticInt(1) -static_first(x::Union{Base.Slice, Base.IdentityUnitRange}) = static_first(x.indices) -static_first(x::OptionallyStaticRange) = getfield(x, :start) -static_first(x) = first(x) - -""" - static_step(x::AbstractRange) - -Attempt to return `static(step(x))`, if known at compile time. Otherwise, return -`step(x)`. - -See also: [`static_first`](@ref), [`static_last`](@ref) - -# Examples - -```julia -julia> static_step(static(1):static(3):9) -static(3) - -julia> static_step(1:3:9) -3 - -julia> static_step(1:9) -static(1) - -``` -""" -static_step(@nospecialize x::AbstractUnitRange) = StaticInt(1) -static_step(x::OptionallyStaticStepRange) = getfield(x, :step) -static_step(x) = step(x) - -""" - static_last(x::AbstractRange) - -Attempt to return `static(last(x))`, if known at compile time. Otherwise, return -`last(x)`. - -See also: [`static_first`](@ref), [`static_step`](@ref) - -# Examples - -```julia -julia> static_last(static(1):static(10)) -static(10) - -julia> static_last(static(1):10) -10 - -``` -""" -static_last(x::OptionallyStaticRange) = getfield(x, :stop) -static_last(x) = last(x) -static_last(x::Union{Base.Slice, Base.IdentityUnitRange}) = static_last(x.indices) - Base.first(x::OptionallyStaticRange{Int}) = getfield(x, :start) Base.first(::OptionallyStaticRange{StaticInt{F}}) where {F} = F Base.step(x::OptionallyStaticStepRange{<:Any, Int}) = getfield(x, :step) @@ -285,12 +234,15 @@ function Base.checkindex(::Type{Bool}, (F1::Int <= F2::Int) && (L1::Int >= L2::Int) end -function Base.getindex(r::OptionallyStaticUnitRange, - s::AbstractUnitRange{<:Integer}) +function Base.getindex( + r::OptionallyStaticUnitRange, + i::AbstractUnitRange{<:Integer}, +) + s = OptionallyStaticUnitRange(i) @boundscheck checkbounds(r, s) - f = static_first(r) + f = getfield(r, :start) fnew = f - one(f) - return (fnew + static_first(s)):(fnew + static_last(s)) + return (fnew + getfield(s, :start)):(fnew + getfield(s, :stop)) end function Base.getindex(x::OptionallyStaticUnitRange{StaticInt{1}}, i::Int) @@ -320,11 +272,10 @@ end Base.AbstractUnitRange{Int}(r::OptionallyStaticUnitRange) = r function Base.AbstractUnitRange{T}(r::OptionallyStaticUnitRange) where {T} - start = static_first(r) - if isa(start, StaticInt{1}) && T <: Integer - return Base.OneTo{T}(T(static_last(r))) + if isa(getfield(r, :start), StaticInt{1}) && T <: Integer + return Base.OneTo{T}(T(last(r))) else - return UnitRange{T}(T(static_first(r)), T(static_last(r))) + return UnitRange{T}(T(first(r)), T(last(r))) end end @@ -373,7 +324,10 @@ end Base.axes1(x::Base.Slice{<:OptionallyStaticUnitRange{One}}) = x.indices Base.axes1(x::Base.Slice{<:OptionallyStaticRange}) = Base.IdentityUnitRange(x.indices) -Base.:(-)(r::OptionallyStaticRange) = (-static_first(r)):(-static_step(r)):(-static_last(r)) +function Base.:(-)(r::OptionallyStaticRange) + s = isa(r, OptionallyStaticStepRange) ? -getfield(r, :step) : -One() + (-getfield(r, :start)):s:(-getfield(r, :stop)) +end function Base.reverse(x::OptionallyStaticUnitRange) _OptionallyStaticStepRange(getfield(x, :stop), StaticInt(-1), getfield(x, :start)) @@ -435,25 +389,25 @@ end function Base.first(x::OptionallyStaticUnitRange, n::IntType) n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) - start = static_first(x) - OptionallyStaticUnitRange(start, min(start - one(start) + n, static_last(x))) + start = getfield(x, :start) + OptionallyStaticUnitRange(start, min(start - one(start) + n, getfield(x, :stop))) end function Base.first(x::OptionallyStaticStepRange, n::IntType) n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) - start = static_first(x) - s = static_step(x) - stop = min(((n - one(n)) * s) + static_first(x), static_last(x)) + start = getfield(x, :start) + s = getfield(x, :step) + stop = min(((n - one(n)) * s) + start, getfield(x, :stop)) OptionallyStaticStepRange(start, s, stop) end function Base.last(x::OptionallyStaticUnitRange, n::IntType) n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) - stop = static_last(x) - OptionallyStaticUnitRange(max(stop + one(stop) - n, static_first(x)), stop) + stop = getfield(x, :stop) + OptionallyStaticUnitRange(max(stop + one(stop) - n, getfield(x, :start)), stop) end function Base.last(x::OptionallyStaticStepRange, n::IntType) n < 0 && throw(ArgumentError("Number of elements must be nonnegative")) - start = static_first(x) - s = static_step(x) - stop = static_last(x) + start = getfield(x, :start) + s = getfield(x, :step) + stop = getfield(x, :stop) OptionallyStaticStepRange(max(stop + one(stop) - (n * s), start), s, stop) end diff --git a/test/ranges.jl b/test/ranges.jl index 1e27ae9..dcbf194 100644 --- a/test/ranges.jl +++ b/test/ranges.jl @@ -116,9 +116,6 @@ end @test_throws BoundsError getindex(Static.OptionallyStaticUnitRange(static(1), 10), 11) @test_throws BoundsError getindex(Static.OptionallyStaticStepRange(static(1), 2, 10), 11) -@test Static.static_first(Base.OneTo(one(UInt))) === static(1) -@test Static.static_step(Base.OneTo(one(UInt))) === static(1) - @test @inferred(eachindex(static(-7):static(7))) === static(1):static(15) @test @inferred((static(-7):static(7))[first(eachindex(static(-7):static(7)))]) == -7