Skip to content

Commit c2cef34

Browse files
authored
[Utilities] add distance_to_set for more sets (#2925)
1 parent 7f1c089 commit c2cef34

2 files changed

Lines changed: 193 additions & 23 deletions

File tree

src/Utilities/distance_to_set.jl

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,39 @@ function _check_dimension(v::AbstractVector, s)
175175
return
176176
end
177177

178+
function _reshape(
179+
x::AbstractVector,
180+
set::Union{
181+
MOI.PositiveSemidefiniteConeSquare,
182+
MOI.LogDetConeSquare,
183+
MOI.RootDetConeSquare,
184+
},
185+
)
186+
n = isqrt(length(x))
187+
return reshape(x, (n, n))
188+
end
189+
190+
function _reshape(
191+
x::AbstractVector{T},
192+
set::Union{
193+
MOI.PositiveSemidefiniteConeTriangle,
194+
MOI.LogDetConeTriangle,
195+
MOI.RootDetConeTriangle,
196+
},
197+
) where {T}
198+
n = isqrt(2 * length(x))
199+
# The type annotation is needed for JET.
200+
X = zeros(T, n, n)::Matrix{T}
201+
k = 1
202+
for i in 1:n
203+
for j in 1:i
204+
X[j, i] = X[i, j] = x[k]
205+
k += 1
206+
end
207+
end
208+
return LinearAlgebra.Symmetric(X)
209+
end
210+
178211
# This is the minimal L2-norm.
179212
function distance_to_set(
180213
::ProjectionUpperBoundDistance,
@@ -499,28 +532,6 @@ function distance_to_set(
499532
return LinearAlgebra.norm(elements, 2)
500533
end
501534

502-
function _reshape(x::AbstractVector, set::MOI.PositiveSemidefiniteConeSquare)
503-
n = MOI.side_dimension(set)
504-
return reshape(x, (n, n))
505-
end
506-
507-
function _reshape(
508-
x::AbstractVector{T},
509-
set::MOI.PositiveSemidefiniteConeTriangle,
510-
) where {T}
511-
n = MOI.side_dimension(set)
512-
# The type annotation is needed for JET.
513-
X = zeros(T, n, n)::Matrix{T}
514-
k = 1
515-
for i in 1:n
516-
for j in 1:i
517-
X[j, i] = X[i, j] = x[k]
518-
k += 1
519-
end
520-
end
521-
return LinearAlgebra.Symmetric(X)
522-
end
523-
524535
"""
525536
distance_to_set(
526537
::ProjectionUpperBoundDistance,
@@ -608,3 +619,86 @@ function distance_to_set(
608619
sqrt(x[1]^2 + distance_to_set(distance, x[2], set.set)^2),
609620
)
610621
end
622+
623+
"""
624+
distance_to_set(::ProjectionUpperBoundDistance, x, set::MOI.NormNuclearCone)
625+
626+
Let `(t, y...) = x`. Return the epigraph distance `d` such that `(t + d, y...)`
627+
belongs to the set.
628+
"""
629+
function distance_to_set(
630+
::ProjectionUpperBoundDistance,
631+
x::AbstractVector{T},
632+
set::MOI.NormNuclearCone,
633+
) where {T}
634+
_check_dimension(x, set)
635+
X = reshape(x[2:end], set.row_dim, set.column_dim)
636+
return max(sum(LinearAlgebra.svdvals(X)) - x[1], zero(T))
637+
end
638+
639+
"""
640+
distance_to_set(::ProjectionUpperBoundDistance, x, set::MOI.NormSpectralCone)
641+
642+
Let `(t, y...) = x`. Return the epigraph distance `d` such that `(t + d, y...)`
643+
belongs to the set.
644+
"""
645+
function distance_to_set(
646+
::ProjectionUpperBoundDistance,
647+
x::AbstractVector{T},
648+
set::MOI.NormSpectralCone,
649+
) where {T}
650+
_check_dimension(x, set)
651+
X = reshape(x[2:end], set.row_dim, set.column_dim)
652+
return max(maximum(LinearAlgebra.svdvals(X)) - x[1], zero(T))
653+
end
654+
655+
"""
656+
distance_to_set(
657+
::ProjectionUpperBoundDistance,
658+
x::AbstractVector,
659+
set::Union{MOI.RootDetConeSquare,MOI.RootDetConeTriangle},
660+
)
661+
662+
Let ``Y`` be `y` in `x = (t, y)`, reshaped into the appropriate matrix. The
663+
returned distance is ``||Y - Z||_2^2`` where ``Z`` is the eigen decomposition of
664+
``Y`` with negative eigen values removed, plus the epigraph distance in `t`
665+
needed to satisfy the root-determinant constraint.
666+
"""
667+
function distance_to_set(
668+
::ProjectionUpperBoundDistance,
669+
x::AbstractVector{T},
670+
set::Union{MOI.RootDetConeSquare,MOI.RootDetConeTriangle},
671+
) where {T<:Real}
672+
_check_dimension(x, set)
673+
eigvals = LinearAlgebra.eigvals(_reshape(x[2:end], set))
674+
eigvals_neg = min.(zero(T), eigvals)
675+
eigvals_pos = max.(zero(T), eigvals)
676+
rootdet = prod(eigvals_pos)^(1 / set.side_dimension)
677+
push!(eigvals_neg, max(x[1] - rootdet, zero(T)))
678+
return LinearAlgebra.norm(eigvals_neg, 2)
679+
end
680+
681+
"""
682+
distance_to_set(
683+
::ProjectionUpperBoundDistance,
684+
x::AbstractVector,
685+
set::Union{MOI.LogDetConeSquare,MOI.LogDetConeTriangle},
686+
)
687+
688+
Let ``Y`` be `y` in `x = (t, y)`, reshaped into the appropriate matrix. The
689+
returned distance is ``||Y/u - Z||_2^2`` where ``Z`` is the eigen decomposition
690+
of ``Y`` with negative eigen values removed, plus the epigraph distance in `t`
691+
needed to satisfy the log-determinant constraint.
692+
"""
693+
function distance_to_set(
694+
::ProjectionUpperBoundDistance,
695+
x::AbstractVector{T},
696+
set::Union{MOI.LogDetConeSquare,MOI.LogDetConeTriangle},
697+
) where {T<:Real}
698+
_check_dimension(x, set)
699+
eigvals = LinearAlgebra.eigvals(_reshape(x[3:end] ./ x[2], set))
700+
eigvals_neg = min.(eps(T), eigvals)
701+
eigvals_pos = max.(eps(T), eigvals)
702+
push!(eigvals_neg, max(x[1] - x[2] * sum(log.(eigvals_pos)), zero(T)))
703+
return LinearAlgebra.norm(eigvals_neg, 2)
704+
end

test/Utilities/distance_to_set.jl

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,83 @@ function test_indicator()
394394
[0.99, 1.0] => 0.01,
395395
[0.99, 1.1] => sqrt(0.01^2 + 0.1^2),
396396
[0.5, 1.1] => 0.5,
397-
[0.8, 1.1] => sqrt(0.2^2 + 0.1^2),
397+
[0.8, 1.1] => sqrt(0.2^2 + 0.1^2);
398+
mismatch = [1.0],
399+
)
400+
return
401+
end
402+
403+
function test_NormNuclearCone()
404+
_test_set(
405+
MOI.NormNuclearCone(2, 3),
406+
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] => 10.039818672223756,
407+
[11.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] => 0.0;
408+
mismatch = [1.0],
409+
)
410+
return
411+
end
412+
413+
function test_NormSpectralCone()
414+
_test_set(
415+
MOI.NormSpectralCone(2, 3),
416+
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] => 9.525518091565111,
417+
[9.6, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] => 0.0;
418+
mismatch = [1.0],
419+
)
420+
return
421+
end
422+
423+
function test_RootDetConeTriangle()
424+
_test_set(
425+
MOI.RootDetConeTriangle(2),
426+
[2.0, 1.0, 0.0, 1.0] => 1.0,
427+
[0.9, 1.0, 0.0, 1.0] => 0.0,
428+
[2.0, 1.0, 0.0, 2.0] => 2 - sqrt(2),
429+
# Projection onto PSD
430+
[0.0, 1.0, 2.0, 3.0] => 0.2360679774997897,
431+
# Projection onto PSD+t
432+
[1.0, 1.0, 2.0, 3.0] => sqrt(1 + 0.2360679774997897^2);
433+
mismatch = [1.0],
434+
)
435+
return
436+
end
437+
438+
function test_RootDetConeSquare()
439+
_test_set(
440+
MOI.RootDetConeSquare(2),
441+
[2.0, 1.0, 0.0, 0.0, 1.0] => 1.0,
442+
[0.9, 1.0, 0.0, 0.0, 1.0] => 0.0,
443+
[2.0, 1.0, 0.0, 0.0, 2.0] => 2 - sqrt(2),
444+
# Projection onto PSD
445+
[0.0, 1.0, 2.0, 2.0, 3.0] => 0.2360679774997897,
446+
# Projection onto PSD+t
447+
[1.0, 1.0, 2.0, 2.0, 3.0] => sqrt(1 + 0.2360679774997897^2);
448+
mismatch = [1.0],
449+
)
450+
return
451+
end
452+
453+
function test_LogDetConeTriangle()
454+
_test_set(
455+
MOI.LogDetConeTriangle(2),
456+
[2.0, 1.0, 2.0, 0.0, 1.0] => 2 - 0.6931471805599453,
457+
[0.69, 1.0, 2.0, 0.0, 1.0] => 0.0,
458+
[0.0, 2.0, 2.0, 0.0, 1.0] => 1.3862943611198906,
459+
mismatch = [1.0],
460+
)
461+
return
462+
end
463+
464+
function test_LogDetConeSquare()
465+
_test_set(
466+
MOI.LogDetConeSquare(2),
467+
[2.0, 1.0, 2.0, 0.0, 0.0, 1.0] => 2 - 0.6931471805599453,
468+
[0.69, 1.0, 2.0, 0.0, 0.0, 1.0] => 0.0,
469+
[0.0, 2.0, 2.0, 0.0, 0.0, 1.0] => 1.3862943611198906,
470+
# Projection onto PSD
471+
[0.0, 1.0, 1.0, 2.0, 2.0, 3.0] => 34.60082322336934,
472+
# Projection onto PSD+t
473+
[1.0, 1.0, 1.0, 2.0, 2.0, 3.0] => 35.60080060283381;
398474
mismatch = [1.0],
399475
)
400476
return

0 commit comments

Comments
 (0)