Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ITensorNetworks"
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
authors = ["Matthew Fishman <[email protected]>, Joseph Tindall <[email protected]> and contributors"]
version = "0.15.4"
version = "0.15.5"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/applyexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ end
function default_sweep_callback(
sweep_iterator::SweepIterator{<:ApplyExpProblem};
exponent_description = "exponent",
outputlevel = 0,
process_time = identity,
)
outputlevel = get(region_kwargs(region_iterator(sweep_iterator)), :outputlevel, 0)
return if outputlevel >= 1
the_problem = problem(sweep_iterator)
@printf(
Expand Down
14 changes: 8 additions & 6 deletions src/solvers/eigsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ end
function update!(
region_iter::RegionIterator{<:EigsolveProblem},
local_state;
outputlevel = 0,
solver = eigsolve_solver,
solver = eigsolve_solver
)
prob = problem(region_iter)

Expand All @@ -34,15 +33,17 @@ function update!(

prob.eigenvalue = eigval

outputlevel = get(region_kwargs(region_iter), :outputlevel, 0)
if outputlevel >= 2
@printf(" Region %s: energy = %.12f\n", current_region(region_iter), eigenvalue(prob))
end
return region_iter, local_state
end

function default_sweep_callback(
sweep_iterator::SweepIterator{<:EigsolveProblem}; outputlevel = 0
sweep_iterator::SweepIterator{<:EigsolveProblem}
)
outputlevel = get(region_kwargs(region_iterator(sweep_iterator)), :outputlevel, 0)
return if outputlevel >= 1
nsweeps = length(sweep_iterator)
current_sweep = sweep_iterator.which_sweep
Expand All @@ -51,9 +52,10 @@ function default_sweep_callback(
else
@printf("After sweep %d/%d ", current_sweep, nsweeps)
end
@printf("eigenvalue=%.12f", eigenvalue(problem))
@printf(" maxlinkdim=%d", maxlinkdim(state(problem)))
@printf(" max truncerror=%d", max_truncerror(problem))
current_problem = problem(sweep_iterator)
@printf("eigenvalue=%.12f", eigenvalue(current_problem))
@printf(" maxlinkdim=%d", maxlinkdim(current_problem))
@printf(" max truncerror=%d", max_truncerror(current_problem))
println()
flush(stdout)
end
Expand Down
29 changes: 21 additions & 8 deletions src/solvers/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ abstract type AbstractNetworkIterator end
islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)

function Base.iterate(iterator::AbstractNetworkIterator, init = true)
islaststep(iterator) && return nothing
# The assumption is that first "increment!" is implicit, therefore we must skip the
# the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
# defined when length < 1,
init || islaststep(iterator) && return nothing
# We seperate increment! from step! and demand that any AbstractNetworkIterator *must*
# define a method for increment! This way we avoid cases where one may wish to nest
# calls to different step! methods accidentaly incrementing multiple times.
Expand Down Expand Up @@ -44,6 +47,9 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
which_region::Int
const which_sweep::Int
function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
if length(region_plan) == 0
throw(BoundsError("Cannot construct a region iterator with 0 elements."))
end
return new{P, R}(problem, region_plan, 1, sweep)
end
end
Expand Down Expand Up @@ -115,26 +121,33 @@ region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs

mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator
region_iter::RegionIterator{Problem}
sweep_kwargs::Iterators.Stateful{Iter}
sweep_kwargs::Iter
which_sweep::Int
nsweeps::Int
function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter}
stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs)
first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs)
first_state = Iterators.peel(sweep_kwargs)
if isnothing(first_state)
throw(BoundsError("Cannot construct a sweep iterator with 0 elements."))
end
first_kwargs, sweep_kwargs_rest = first_state
region_iter = RegionIterator(problem; sweep = 1, first_kwargs...)
return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1)
return new{Prob, typeof(sweep_kwargs_rest)}(region_iter, sweep_kwargs_rest, 1, length(sweep_kwargs))
end
end

islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs))
islaststep(sweep_iter::SweepIterator) = isempty(sweep_iter.sweep_kwargs)

region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter

problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter))

state(sweep_iter::SweepIterator) = sweep_iter.which_sweep
Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs)

Base.length(sweep_iter::SweepIterator) = sweep_iter.nsweeps

function increment!(sweep_iter::SweepIterator)
sweep_iter.which_sweep += 1
sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs)
sweep_kwargs, sweep_iter.sweep_kwargs = Iterators.peel(sweep_iter.sweep_kwargs)
update_region_iterator!(sweep_iter; sweep_kwargs...)
return sweep_iter
end
Expand Down
2 changes: 1 addition & 1 deletion src/solvers/region_plans/euler_plans.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Graphs: dst, src
using NamedGraphs.GraphsExtensions: default_root_vertex

function euler_sweep(graph; nsites, root_vertex = default_root_vertex(graph), sweep_kwargs...)
function euler_sweep(graph; nsites = 1, root_vertex = default_root_vertex(graph), sweep_kwargs...)
sweep_kwargs = (; nsites, root_vertex, sweep_kwargs...)

if nsites == 1
Expand Down
2 changes: 1 addition & 1 deletion test/solvers/test_applyexp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ end

nsites = 2
factorize_kwargs = (; cutoff, maxdim)
E, gs_psi = dmrg(H, psi0; factorize_kwargs, nsites, nsweeps, outputlevel)
E, gs_psi = dmrg(H, psi0; factorize_kwargs, nsites, nsweeps, outputlevel = 0)
(outputlevel >= 1) && println("2-site DMRG energy = ", E)

nsites = 1
Expand Down
60 changes: 60 additions & 0 deletions test/solvers/test_sweepiterator.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using Test: @test, @testset
using ITensorNetworks: ITensorNetworks, AbstractProblem, RegionIterator, SweepIterator, compute!, region_iterator, region_kwargs

include("utilities/tree_graphs.jl")

# TestProblem type for testing
struct TestProblem <: AbstractProblem
graph
end

ITensorNetworks.state(T::TestProblem) = T.graph

ITensorNetworks.compute!(R::RegionIterator{<:TestProblem}) = "TestProblem Compute"


@testset "SweepIterator Basics" begin
g = build_tree(; nbranch = 3, nbranch_sites = 3)
prob = TestProblem(g)

nsweeps = 5

# Basic construction, taking length
sweep_iter = SweepIterator(prob, nsweeps)
@test length(sweep_iter) == nsweeps

# Pass keyword parameters
test_kwarg_a = 1
test_kwarg_b = "b"
sweep_iter = SweepIterator(prob, nsweeps; test_kwarg_a, test_kwarg_b)
@test region_kwargs(region_iterator(sweep_iter)).test_kwarg_a == test_kwarg_a
@test region_kwargs(region_iterator(sweep_iter)).test_kwarg_b == test_kwarg_b

# Pass array of parameters
kws_array = [(; outputlevel = 0), (; outputlevel = 1)]
sweep_iter = SweepIterator(prob, kws_array)
@test length(sweep_iter) == length(kws_array)
@test region_kwargs(region_iterator(sweep_iter)).outputlevel == 0
end

@testset "SweepIterator Iteration" begin
g = build_tree(; nbranch = 3, nbranch_sites = 3)
prob = TestProblem(g)

nsweeps = 5
sweep_iter = SweepIterator(prob, nsweeps)
count = 0
for _ in sweep_iter
count += 1
end
@test count == nsweeps

# Test case of one iteration
nsweeps = 1
sweep_iter = SweepIterator(prob, nsweeps)
count = 0
for _ in sweep_iter
count += 1
end
@test count == nsweeps
end
Loading