Skip to content

Commit 7f1c089

Browse files
authored
[Utilities] add distance_to_set for MOI.Indicator (#2923)
1 parent 9263978 commit 7f1c089

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

src/Utilities/distance_to_set.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,3 +579,32 @@ function distance_to_set(
579579
set.eval_f(y, x)
580580
return LinearAlgebra.norm(y .- clamp.(y, set.l, set.u))
581581
end
582+
583+
# This is the minimal L2-norm.
584+
function distance_to_set(
585+
distance::ProjectionUpperBoundDistance,
586+
x::AbstractVector{T},
587+
set::MOI.Indicator{MOI.ACTIVATE_ON_ONE},
588+
) where {T<:Real}
589+
_check_dimension(x, set)
590+
return min(
591+
# Distance of x[1] from 0
592+
abs(x[1]),
593+
# Distance of x[1] from 1 + distance to set
594+
sqrt((1 - x[1])^2 + distance_to_set(distance, x[2], set.set)^2),
595+
)
596+
end
597+
598+
function distance_to_set(
599+
distance::ProjectionUpperBoundDistance,
600+
x::AbstractVector{T},
601+
set::MOI.Indicator{MOI.ACTIVATE_ON_ZERO},
602+
) where {T}
603+
_check_dimension(x, set)
604+
return min(
605+
# Distance of x[1] from 1
606+
abs(one(T) - x[1]),
607+
# Distance of x[1] from 0 + distance to set
608+
sqrt(x[1]^2 + distance_to_set(distance, x[2], set.set)^2),
609+
)
610+
end

test/Utilities/distance_to_set.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,30 @@ function test_vectornonlinearoracle()
376376
return
377377
end
378378

379+
function test_indicator()
380+
_test_set(
381+
MOI.Indicator{MOI.ACTIVATE_ON_ZERO}(MOI.EqualTo(1.0)),
382+
[0.0, 1.0] => 0.0,
383+
[0.01, 1.0] => 0.01,
384+
[-0.01, 1.0] => 0.01,
385+
[-0.01, 1.1] => sqrt(0.01^2 + 0.1^2),
386+
[0.5, 1.1] => 0.5,
387+
[0.2, 1.1] => sqrt(0.2^2 + 0.1^2);
388+
mismatch = [1.0],
389+
)
390+
_test_set(
391+
MOI.Indicator{MOI.ACTIVATE_ON_ONE}(MOI.EqualTo(1.0)),
392+
[1.0, 1.0] => 0.0,
393+
[1.01, 1.0] => 0.01,
394+
[0.99, 1.0] => 0.01,
395+
[0.99, 1.1] => sqrt(0.01^2 + 0.1^2),
396+
[0.5, 1.1] => 0.5,
397+
[0.8, 1.1] => sqrt(0.2^2 + 0.1^2),
398+
mismatch = [1.0],
399+
)
400+
return
401+
end
402+
379403
end
380404

381405
TestFeasibilityChecker.runtests()

0 commit comments

Comments
 (0)