Skip to content

Commit 1cbe497

Browse files
authored
[Utilities] add distance_to_set for more sets (#2926)
1 parent c2cef34 commit 1cbe497

2 files changed

Lines changed: 102 additions & 5 deletions

File tree

src/Utilities/distance_to_set.jl

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ function _reshape(
194194
MOI.LogDetConeTriangle,
195195
MOI.RootDetConeTriangle,
196196
},
197-
) where {T}
197+
) where {T<:Real}
198198
n = isqrt(2 * length(x))
199199
# The type annotation is needed for JET.
200200
X = zeros(T, n, n)::Matrix{T}
@@ -208,6 +208,24 @@ function _reshape(
208208
return LinearAlgebra.Symmetric(X)
209209
end
210210

211+
function _reshape(
212+
x::AbstractVector{T},
213+
set::MOI.PositiveSemidefiniteConeTriangle,
214+
) where {T<:Complex}
215+
n = isqrt(2 * length(x))
216+
# The type annotation is needed for JET.
217+
X = zeros(T, n, n)::Matrix{T}
218+
k = 1
219+
for i in 1:n
220+
for j in 1:i
221+
X[i, j] = conj(x[k])
222+
X[j, i] = x[k]
223+
k += 1
224+
end
225+
end
226+
return LinearAlgebra.Hermitian(X)
227+
end
228+
211229
# This is the minimal L2-norm.
212230
function distance_to_set(
213231
::ProjectionUpperBoundDistance,
@@ -553,7 +571,7 @@ function distance_to_set(
553571
MOI.PositiveSemidefiniteConeSquare,
554572
MOI.PositiveSemidefiniteConeTriangle,
555573
},
556-
) where {T<:Real}
574+
) where {T<:Union{Real,Complex}}
557575
_check_dimension(x, set)
558576
# We should return the norm of `A` defined by:
559577
# ```julia
@@ -565,8 +583,7 @@ function distance_to_set(
565583
# The norm should correspond to `MOI.Utilities.set_dot` so it's the
566584
# Frobenius norm, which is the Euclidean norm of the vector of eigenvalues.
567585
eigvals = LinearAlgebra.eigvals(_reshape(x, set))
568-
eigvals .= min.(zero(T), eigvals)
569-
return LinearAlgebra.norm(eigvals, 2)
586+
return LinearAlgebra.norm(min.(0, eigvals), 2)
570587
end
571588

572589
"""
@@ -702,3 +719,44 @@ function distance_to_set(
702719
push!(eigvals_neg, max(x[1] - x[2] * sum(log.(eigvals_pos)), zero(T)))
703720
return LinearAlgebra.norm(eigvals_neg, 2)
704721
end
722+
723+
"""
724+
distance_to_set(
725+
::ProjectionUpperBoundDistance,
726+
x::AbstractVector,
727+
set::MOI.Scaled{S},
728+
)
729+
730+
This is the distance in the un-scaled space.
731+
"""
732+
function distance_to_set(
733+
dist::ProjectionUpperBoundDistance,
734+
x::AbstractVector{T},
735+
set::MOI.Scaled{S},
736+
) where {T,S<:MOI.AbstractVectorSet}
737+
_check_dimension(x, set)
738+
scale = MOI.Utilities.SetDotScalingVector{T}(set.set)
739+
return distance_to_set(dist, x ./ scale, set.set)
740+
end
741+
742+
function distance_to_set(
743+
dist::ProjectionUpperBoundDistance,
744+
x::AbstractVector{T},
745+
set::MOI.HermitianPositiveSemidefiniteConeTriangle,
746+
) where {T<:Real}
747+
_check_dimension(x, set)
748+
output_set = MOI.PositiveSemidefiniteConeTriangle(set.side_dimension)
749+
y = zeros(Complex{T}, MOI.dimension(output_set))
750+
real_offset, imag_offset = 0, length(y)
751+
for col in 1:set.side_dimension
752+
for row in 1:col
753+
real_offset += 1
754+
y[real_offset] = x[real_offset]
755+
if row != col
756+
imag_offset += 1
757+
y[real_offset] += x[imag_offset] * im
758+
end
759+
end
760+
end
761+
return distance_to_set(dist, y, output_set)
762+
end

test/Utilities/distance_to_set.jl

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ function test_positivesemidefiniteconesquare()
320320
[1.0, 0.0, 0.0, 1.0] => 0.0,
321321
[1.0, -1.0, -1.0, 1.0] => 0.0,
322322
[1.0, -2.0, -2.0, 1.0] => 1.0,
323-
[1.0, 1.1, 1.1, -2.3] => 2.633053201505194;
323+
[1.0, 1.1, 1.1, -2.3] => 2.633053201505194,
324+
[1.0, -2.0, -2.0, 1.0] => 1.0;
324325
mismatch = [1.0],
325326
)
326327
return
@@ -476,6 +477,44 @@ function test_LogDetConeSquare()
476477
return
477478
end
478479

480+
function test_Scaled()
481+
_test_set(
482+
MOI.Scaled(MOI.PositiveSemidefiniteConeTriangle(2)),
483+
[1.0, 0.0, 1.0] => 0.0,
484+
[1.0, -1.0, 1.0] => 0.0,
485+
[1.0, -2.0 * sqrt(2), 1.0] => 1.0,
486+
[1.0, 1.1 * sqrt(2), -2.3] => 2.633053201505194;
487+
mismatch = [1.0],
488+
)
489+
return
490+
end
491+
492+
function test_PositiveSemidefiniteConeTriangle_Complex()
493+
_test_set(
494+
MOI.PositiveSemidefiniteConeTriangle(2),
495+
ComplexF64[1.0, 0.0, 1.0] => 0.0,
496+
ComplexF64[1.0, -1.0, 1.0] => 0.0,
497+
ComplexF64[1.0, -2.0, 1.0] => 1.0,
498+
ComplexF64[1.0, 1.1, -2.3] => 2.633053201505194,
499+
ComplexF64[1.0, 1-im, 1.0] => 0.41421356237309537;
500+
mismatch = [1.0],
501+
)
502+
return
503+
end
504+
505+
function test_HermitianPositiveSemidefiniteConeTriangle()
506+
_test_set(
507+
MOI.HermitianPositiveSemidefiniteConeTriangle(2),
508+
[1.0, 0.0, 1.0, 0.0] => 0.0,
509+
[1.0, -1.0, 1.0, 0.0] => 0.0,
510+
[1.0, -2.0, 1.0, 0.0] => 1.0,
511+
[1.0, 1.1, -2.3, 0.0] => 2.633053201505194,
512+
[1.0, 1.0, 1.0, -1.0] => 0.41421356237309537;
513+
mismatch = [1.0],
514+
)
515+
return
516+
end
517+
479518
end
480519

481520
TestFeasibilityChecker.runtests()

0 commit comments

Comments
 (0)