Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions src/algorithms/ctmrg/gaugefix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,9 @@ function gauge_fix(envfinal::CTMRGEnv{C, T}, envprev::CTMRGEnv{C, T}, ::Scrambli
end

# Find right fixed points of mixed transfer matrices
ρinit = randn(
scalartype(T), space(Tsfinal[end], numind(Tsfinal[end]))' ← space(M[end], numind(M[end]))'
)
ρprev = transfermatrix_fixedpoint(Tsprev, M, ρinit)
ρfinal = transfermatrix_fixedpoint(Tsfinal, M, ρinit)
eigsolve_alg = Arnoldi()
ρprev = right_transfermatrix_fixedpoint(Tsprev, M, eigsolve_alg)
ρfinal = right_transfermatrix_fixedpoint(Tsfinal, M, eigsolve_alg)

# Decompose and multiply
Qprev, = left_orth!(ρprev; positive = true)
Expand Down Expand Up @@ -108,11 +106,9 @@ function gauge_fix(envfinal::CTMRGEnv{C, T}, envprev::CTMRGEnv{C, T}, ::Scrambli
M = _project_hermitian(randn(scalartype(Tfinal), space(Tfinal)))

# Find right fixed points of mixed transfer matrices
ρinit = randn(
scalartype(T), MPSKit._lastspace(Tfinal)' ← MPSKit._lastspace(M)'
)
ρprev = c4v_transfermatrix_fixedpoint(Tprev, M, ρinit)
ρfinal = c4v_transfermatrix_fixedpoint(Tfinal, M, ρinit)
eigsolve_alg = Lanczos() # real eigenvalues
ρprev = right_transfermatrix_fixedpoint([Tprev], [M], eigsolve_alg)
ρfinal = right_transfermatrix_fixedpoint([Tfinal], [M], eigsolve_alg)

# Decompose and multiply
Qprev, = left_orth!(ρprev; positive = true)
Expand Down Expand Up @@ -143,8 +139,18 @@ end
@__MODULE__, :(return @tensor $t_out := $t_top * conj($t_bot) * $t_in)
)
end
function transfermatrix_fixedpoint(tops, bottoms, ρinit)
_, vecs, info = eigsolve(ρinit, 1, :LM, Arnoldi()) do ρ

"""
    initialize_right_fixedpoint(tops, bottoms)

Draw a random (`randn`) starting tensor for the right fixed-point eigenproblem of the
transfer matrix built from `tops` and `bottoms`.

The codomain is the dual of the space of the last index of `tops[end]`, the domain is the
dual of the space of the last index of `bottoms[end]`, and the scalar type is
`scalartype(tops)`.
"""
function initialize_right_fixedpoint(tops, bottoms)
    top, bottom = tops[end], bottoms[end]
    # `numind(t)` is the number of indices, so this picks out the last index space
    codom = space(top, numind(top))'
    dom = space(bottom, numind(bottom))'
    return randn(scalartype(tops), codom ← dom)
end
function right_transfermatrix_fixedpoint(
tops, bottoms, alg = Arnoldi(),
ρ0 = initialize_right_fixedpoint(tops, bottoms),
)
_, vecs, info = eigsolve(ρ0, 1, :LM, alg) do ρ
return foldr(zip(tops, bottoms); init = ρ) do (top, bottom), ρ
return mps_transfer_right(ρ, top, bottom)
end
Expand All @@ -154,13 +160,6 @@ function transfermatrix_fixedpoint(tops, bottoms, ρinit)
end
return first(vecs)
end
function c4v_transfermatrix_fixedpoint(top, bottom, ρinit)
_, vecs, info = eigsolve(ρinit, 1, :LM, Lanczos()) do ρ
Comment thread
lkdvos marked this conversation as resolved.
PEPSKit.mps_transfer_right(ρ, top, bottom)
end
info.converged > 0 || @warn "eigsolve did not converge"
return first(vecs)
end

# Explicit fixing of relative phases (doing this compactly in a loop is annoying)
function fix_relative_phases(envfinal::CTMRGEnv, signs)
Expand Down
67 changes: 39 additions & 28 deletions src/algorithms/optimization/fixed_point_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ Evaluating the gradient of the cost function for CTMRG:
- With explicit evaluation of the geometric sum, the gradient is computed by differentiating the cost function with the environment kept fixed, and then manually adding the gradient contributions from the environments.
=#

# Pick the scrambling-based environment gauge-fixing algorithm by dispatching on the type
# of the CTMRG algorithm: `C4vCTMRG` maps to the C4v variant, a `CTMRGAlgorithm` to the
# generic one.
function _scrambling_env_gauge(::CTMRGAlgorithm)
    return ScramblingEnvGauge()
end
function _scrambling_env_gauge(::C4vCTMRG)
    return ScramblingEnvGaugeC4v()
end

function _rrule(
gradmode::GradMode{:diffgauge},
config::RuleConfig,
Expand All @@ -217,20 +220,22 @@ function _rrule(
)
env, info = leading_boundary(envinit, state, alg)
alg_fixed = @set alg.projector_alg.trunc = FixedSpaceTruncation() # fix spaces during differentiation
alg_gauge = ScramblingEnvGauge() # TODO: make this a field in GradMode?
alg_gauge = _scrambling_env_gauge(alg) # TODO: make this a field in GradMode?

# prepare iterating function corresponding to a single gauge-fixed CTMRG iteration
function f(A, x)
return gauge_fix(ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1], x, alg_gauge)[1]
end
# compute its pullback
_, env_vjp = rrule_via_ad(config, f, state, env)
# split off state and environment parts
∂f∂A(x)::typeof(state) = env_vjp(x)[2]
∂f∂x(x)::typeof(env) = env_vjp(x)[3]

function leading_boundary_diffgauge_pullback((Δenv′, Δinfo))
Δenv = unthunk(Δenv′)

# find partial gradients of gauge-fixed single CTMRG iteration
function f(A, x)
return gauge_fix(ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1], x, alg_gauge)[1]
end
_, env_vjp = rrule_via_ad(config, f, state, env)

# evaluate the geometric sum
∂f∂A(x)::typeof(state) = env_vjp(x)[2]
∂f∂x(x)::typeof(env) = env_vjp(x)[3]
∂F∂env = fpgrad(Δenv, ∂f∂x, ∂f∂A, Δenv, gradmode)

return NoTangent(), ZeroTangent(), ∂F∂env, NoTangent()
Expand All @@ -254,22 +259,25 @@ function _rrule(
env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg_fixed)
env_fixed, signs = gauge_fix(env_conv, env, alg_gauge)

# Fix SVD
# fix decomposition
alg_fixed = gauge_fix(alg, signs, info)

# prepare iterating function corresponding to a single CTMRG iteration with a gauge-fixed projector
function f(A, x)
return fix_global_phases(
ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1], x,
)
end
# prepare its pullback
_, env_vjp = rrule_via_ad(config, f, state, env_fixed)
# split off state and environment parts
∂f∂A(x)::typeof(state) = env_vjp(x)[2]
∂f∂x(x)::typeof(env) = env_vjp(x)[3]

function leading_boundary_fixed_pullback((Δenv′, Δinfo))
Δenv = unthunk(Δenv′)

function f(A, x)
return fix_global_phases(
ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1], x
)
end
_, env_vjp = rrule_via_ad(config, f, state, env_fixed)

# evaluate the geometric sum
∂f∂A(x)::typeof(state) = env_vjp(x)[2]
∂f∂x(x)::typeof(env) = env_vjp(x)[3]
∂F∂env = fpgrad(Δenv, ∂f∂x, ∂f∂A, Δenv, gradmode)

return NoTangent(), ZeroTangent(), ∂F∂env, NoTangent()
Expand Down Expand Up @@ -362,19 +370,22 @@ function _rrule(
env_conv, info = ctmrg_iteration(InfiniteSquareNetwork(state), env, alg_fixed)
_, signs = gauge_fix(env_conv, env, alg_gauge)

# Fix eigendecomposition
# fix eigendecomposition
alg_fixed = gauge_fix(alg, signs, info)

# prepare iterating function corresponding to a single CTMRG iteration with a gauge-fixed projector
f(A, x) = ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1]
# compute its pullback
_, env_vjp = rrule_via_ad(config, f, state, env)
# split off state and environment parts
∂f∂A(x)::typeof(state) = env_vjp(x)[2]
# ∂f∂x(x)::typeof(env) = env_vjp(x)[3] # TODO: why is this derivative type-instable? The corner gradient is a complex DiagonalTensorMap
∂f∂x(x) = env_vjp(x)[3]

function leading_boundary_fixed_pullback((Δenv′, Δinfo))
Δenv = unthunk(Δenv′)

f(A, x) = ctmrg_iteration(InfiniteSquareNetwork(A), x, alg_fixed)[1]
_, env_vjp = rrule_via_ad(config, f, state, env)

# evaluate the geometric sum
∂f∂A(x)::typeof(state) = env_vjp(x)[2]
# ∂f∂x(x)::typeof(env) = env_vjp(x)[3] # TODO: why is this derivative type-instable? The corner gradient is a complex DiagonalTensorMap
∂f∂x(x) = env_vjp(x)[3]
∂F∂env = fpgrad(Δenv, ∂f∂x, ∂f∂A, Δenv, gradmode)

return NoTangent(), ZeroTangent(), ∂F∂env, NoTangent()
Expand Down Expand Up @@ -460,7 +471,7 @@ end
function fpgrad(∂F∂x, ∂f∂x, ∂f∂A, x₀, alg::EigSolver)
function f(X)
y = ∂f∂x(X[1])
return (y + X[2] * ∂F∂x, X[2])
return (VI.add!!(y, ∂F∂x, X[2]), X[2])
end
X₀ = (x₀, one(scalartype(x₀)))
_, vecs, info = realeigsolve(f, X₀, 1, :LM, alg.solver_alg)
Expand All @@ -482,7 +493,7 @@ function fpgrad(∂F∂x, ∂f∂x, ∂f∂A, x₀, alg::EigSolver)
)
return fpgrad(∂F∂x, ∂f∂x, ∂f∂A, x₀, backup_ls_alg)
else
y = scale(vecs[1][1], 1 / vecs[1][2])
y = VI.scale!!(vecs[1][1], inv(vecs[1][2]))
end

return ∂f∂A(y)
Expand Down
133 changes: 133 additions & 0 deletions test/gradients/c4v_ctmrg_gradients.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using Test
using Random
using PEPSKit
using TensorKit
using Zygote
using OptimKit
using KrylovKit

# RNG seed; `Random.seed!(sd)` is called before each tested configuration
sd = 42039482052

## Test C4v CTMRG gradients
# -------------------------------------------
χbond = 2
χenv = 6
symmetry = RotateReflect()
# one entry per model; the vectors below are all indexed by the model index
Pspaces = [ComplexSpace(2)]
Vspaces = [ComplexSpace(χbond)]
Espaces = [ComplexSpace(χenv)]
models = [heisenberg_XYZ(InfiniteSquare())]
names = ["Heisenberg"]

gradtol = 1.0e-4
ctmrg_verbosity = 1
# per-model lists of algorithm choices; their Cartesian product is tested
ctmrg_algs = [[:c4v]]
projector_algs = [[:c4v_eigh]]
eigh_rrule_algs = [[:full, :trunc]] # TODO: handle projector-algorithm-dependence
gradient_algs = [[nothing, :geomsum, :manualiter, :linsolver, :eigsolver]]
gradient_iterschemes = [[:fixed, :diffgauge]]
# step sizes passed to `OptimKit.optimtest` as `alpha`
steps = -0.01:0.005:0.01

# be selective on which configurations to test the naive gradient for
naive_gradient_combinations = [
    (:c4v, :c4v_eigh, :full, :fixed),
    (:c4v, :c4v_eigh, :trunc, :diffgauge),
]
# combinations already covered; concretely typed since entries are 4-tuples of symbols
# (ctmrg_alg, projector_alg, eigh_rrule_alg, gradient_iterscheme)
naive_gradient_done = Set{NTuple{4, Symbol}}()

# mark the broken configurations as broken explicitly
broken_gradients = Dict(
    "eigh_rrule_alg" => Set([:full]),
    "gradient_iterscheme" => Set([:diffgauge]),
)
"""
    _check_broken(ctmrg_alg, projector_alg, eigh_rrule_alg, gradient_alg, gradient_iterscheme)

Return `true` when the configuration is expected to fail, i.e. when `eigh_rrule_alg` or
`gradient_iterscheme` is listed in the global `broken_gradients` table. Naive gradients
(`gradient_alg === nothing`) are never considered broken. `ctmrg_alg` and `projector_alg`
are accepted but not inspected.
"""
function _check_broken(
        ctmrg_alg, projector_alg, eigh_rrule_alg, gradient_alg, gradient_iterscheme
    )
    # naive gradients should always work
    isnothing(gradient_alg) && return false
    # broken as soon as either setting appears in its corresponding broken set
    return (eigh_rrule_alg in broken_gradients["eigh_rrule_alg"]) ||
        (gradient_iterscheme in broken_gradients["gradient_iterscheme"])
end


## Tests
# ------
@testset "AD C4v CTMRG energy gradients for $(names[i]) model" verbose = true for i in
    eachindex(models)

    # spaces belonging to the i-th model
    Pspace, Vspace, Espace = Pspaces[i], Vspaces[i], Espaces[i]
    # all combinations of the algorithm choices configured for the i-th model
    alg_combinations = Iterators.product(
        ctmrg_algs[i], projector_algs[i], eigh_rrule_algs[i],
        gradient_algs[i], gradient_iterschemes[i],
    )
    @testset "ctmrg_alg=:$ctmrg_alg, projector_alg=:$projector_alg, eigh_rrule_alg=:$eigh_rrule_alg, gradient_alg=:$gradient_alg, gradient_iterscheme=:$gradient_iterscheme" for (
            ctmrg_alg, projector_alg, eigh_rrule_alg, gradient_alg, gradient_iterscheme,
        ) in alg_combinations

        # the naive gradient is only tested for whitelisted combinations, and only once each
        if isnothing(gradient_alg)
            naive_combo = (ctmrg_alg, projector_alg, eigh_rrule_alg, gradient_iterscheme)
            if !(naive_combo in naive_gradient_combinations) ||
                    (naive_combo in naive_gradient_done)
                continue
            end
            push!(naive_gradient_done, naive_combo)
        end

        @info "optimtest of ctmrg_alg=:$ctmrg_alg, projector_alg=:$projector_alg, eigh_rrule_alg=:$eigh_rrule_alg, gradient_alg=:$gradient_alg and gradient_iterscheme=:$gradient_iterscheme on $(names[i])"

        # reseed so that every configuration draws the same direction and state
        Random.seed!(sd)
        dir = InfinitePEPS(Pspace, Vspace)
        psi = InfinitePEPS(Pspace, Vspace)
        symmetrize!(psi, symmetry)
        symmetrize!(dir, symmetry)

        # instantiate to avoid having to type this twice...
        concrete_ctmrg_alg = PEPSKit.CTMRGAlgorithm(;
            alg = ctmrg_alg,
            verbosity = ctmrg_verbosity,
            projector_alg = projector_alg,
            decomposition_alg = (; rrule_alg = (; alg = eigh_rrule_alg)),
        )
        # instantiate because hook_pullback doesn't go through the keyword selector...
        # TODO: add the `nothing` case to the PEPSKit.GradMode selector?
        concrete_gradient_alg = isnothing(gradient_alg) ? nothing :
            PEPSKit.GradMode(;
                alg = gradient_alg, tol = gradtol, iterscheme = gradient_iterscheme
            )

        env0 = PEPSKit.initialize_random_c4v_env(psi, Espace)
        env, = leading_boundary(env0, psi, concrete_ctmrg_alg)

        # compare the AD gradient along `dir` against finite differences of the energy
        _, _, dfs1, dfs2 = OptimKit.optimtest(
            (psi, env),
            dir;
            alpha = steps,
            retract = PEPSKit.peps_retract,
            inner = PEPSKit.real_inner,
        ) do (peps, env_x)
            E, grads = Zygote.withgradient(peps) do ψ
                env_new, = PEPSKit.hook_pullback(
                    leading_boundary,
                    env_x,
                    ψ,
                    concrete_ctmrg_alg;
                    alg_rrule = concrete_gradient_alg,
                )
                return cost_function(ψ, env_new, models[i])
            end
            g = only(grads)
            symmetrize!(g, symmetry)
            return E, g
        end

        expected_broken = _check_broken(
            ctmrg_alg, projector_alg, eigh_rrule_alg, gradient_alg, gradient_iterscheme
        )
        if expected_broken
            @test_broken dfs1 ≈ dfs2 atol = 1.0e-2
        else
            @test dfs1 ≈ dfs2 atol = 1.0e-2
        end
    end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ end
@time @safetestset "CTMRG gradients" begin
include("gradients/ctmrg_gradients.jl")
end
@time @safetestset "C4v CTMRG gradients" begin
include("gradients/c4v_ctmrg_gradients.jl")
end
end
if GROUP == "ALL" || GROUP == "BOUNDARYMPS"
@time @safetestset "VUMPS" begin
Expand Down
Loading