diff --git a/src/FSharp.Stats/Testing/ComparisonMetrics.fs b/src/FSharp.Stats/Testing/ComparisonMetrics.fs index 8b02e132..0c4dc97b 100644 --- a/src/FSharp.Stats/Testing/ComparisonMetrics.fs +++ b/src/FSharp.Stats/Testing/ComparisonMetrics.fs @@ -368,29 +368,28 @@ type ComparisonMetrics = { static member multiLabelThresholdMap( actual: #IConvertible [], - predictions: (#IConvertible * float []) [] - ) = - - // we have to use a global threshold collection for all binary threshold maps, otherwise we do not necessarily have values for macro/micro averaging for each label. - let allDistinctThresholds = + predictions: (#IConvertible * float []) [], + thresholds: float [] + ) = + // Use global max as prefix so micro/macro averages use a consistent threshold label. + let globalMax = predictions |> Array.map snd |> Array.concat - |> Array.distinct - |> Array.sortDescending + |> Array.max - let prefixedThresholds = [|allDistinctThresholds[0] + 1.; yield! allDistinctThresholds|] + let prefixedThresholds = [|globalMax + 1.; yield! thresholds|] let labelMetrics = - predictions + predictions |> Array.map (fun (label, preds) -> let labelTruth = actual |> Array.map (fun x -> x = label) - label, BinaryConfusionMatrix.thresholdMap(labelTruth,preds,allDistinctThresholds) + label, BinaryConfusionMatrix.thresholdMap(labelTruth, preds, thresholds) ) let transposedBCMs = labelMetrics - |> Array.map (fun (x,y) -> y) + |> Array.map snd |> JaggedArray.transpose |> JaggedArray.map snd @@ -405,12 +404,27 @@ type ComparisonMetrics = { |> Array.zip prefixedThresholds [| - yield! labelMetrics |> Array.map (fun (label, thrs) -> string label, thrs |> Array.map (fun (thr,bcm) -> thr, ComparisonMetrics.create bcm)) + yield! labelMetrics |> Array.map (fun (label, thrs) -> string label, thrs |> Array.map (fun (thr, bcm) -> thr, ComparisonMetrics.create bcm)) "micro-average", microAverages "macro-average", macroAverages |] |> Map.ofArray + static member multiLabelThresholdMap( + actual: #IConvertible [], + predictions: (#IConvertible * float []) [] + ) = + + // we have to use a global threshold collection for all binary threshold maps, otherwise we do not necessarily have values for macro/micro averaging for each label. + let allDistinctThresholds = + predictions + |> Array.map snd + |> Array.concat + |> Array.distinct + |> Array.sortDescending + + ComparisonMetrics.multiLabelThresholdMap(actual, predictions, allDistinctThresholds) + static member calculateROC( actual: seq, predictions: seq, @@ -439,6 +453,18 @@ type ComparisonMetrics = { metrics.FallOut, metrics.Sensitivity ) + static member calculateMultiLabelROC( + actual: #IConvertible [], + predictions: (#IConvertible * float []) [], + thresholds: float [] + ) = + ComparisonMetrics.multiLabelThresholdMap( + actual, + predictions, + thresholds + ) + |> Map.map (fun _k v -> v |> Array.map (fun (_,cm) -> cm.FallOut, cm.Sensitivity)) + static member calculateMultiLabelROC( actual: #IConvertible [], predictions: (#IConvertible * float []) [] @@ -447,6 +473,4 @@ type ComparisonMetrics = { actual, predictions ) - |> Map.map (fun k v -> v |> Array.map (fun (_,cm) -> cm.FallOut, cm.Sensitivity) - - ) \ No newline at end of file + |> Map.map (fun _k v -> v |> Array.map (fun (_,cm) -> cm.FallOut, cm.Sensitivity)) \ No newline at end of file diff --git a/tests/FSharp.Stats.Tests/Testing.fs b/tests/FSharp.Stats.Tests/Testing.fs index 1adfd493..488db789 100644 --- a/tests/FSharp.Stats.Tests/Testing.fs +++ b/tests/FSharp.Stats.Tests/Testing.fs @@ -1170,6 +1170,41 @@ let comparisonMetricsTests = testCase "C: threshold 0-1" (fun _ -> TestExtensions.comparisonMetricsEqualRounded 3 (snd (actual["C"][9])) (snd (expectedMetricsMap["C"][9])) "Incorrect metrics for threshold 0.1") testCase "C: threshold 0-0" (fun _ -> TestExtensions.comparisonMetricsEqualRounded 3 (snd (actual["C"][10])) (snd (expectedMetricsMap["C"][10])) "Incorrect metrics for threshold 0.0") ] + testList "multi-label threshold map with explicit thresholds" [ + // Use a coarse threshold list [0.9; 0.5; 0.1] — a subset of all distinct thresholds. + // Expected values are taken from the full-threshold test above (same data). + let explicitThresholds = [|0.9; 0.5; 0.1|] + let actualExplicit = + ComparisonMetrics.multiLabelThresholdMap( + actual = [|"A"; "A"; "A"; "A"; "A"; "B"; "B"; "B"; "C"; "C"; "C"; "C"; "C"; "C"|], + predictions = [| + "A", [|0.8; 0.7; 0.9; 0.4; 0.3; 0.1; 0.2; 0.5; 0.1; 0.1; 0.1; 0.3; 0.5; 0.4|] + "B", [|0.0; 0.1; 0.0; 0.5; 0.1; 0.8; 0.7; 0.4; 0.0; 0.1; 0.1; 0.0; 0.1; 0.3|] + "C", [|0.2; 0.2; 0.1; 0.1; 0.6; 0.1; 0.1; 0.1; 0.9; 0.8; 0.8; 0.7; 0.4; 0.3|] + |], + thresholds = explicitThresholds + ) + // With 3 explicit thresholds the result should have 4 entries per label (prefix + 3) + testCase "explicit thresholds: result length" (fun _ -> + Expect.equal actualExplicit["A"].Length 4 "Expected 4 threshold entries for label A with 3 explicit thresholds" + ) + // Values at threshold 0.9 should match the full-threshold result at that threshold + testCase "A: explicit threshold 0-9" (fun _ -> + TestExtensions.comparisonMetricsEqualRounded 3 (snd (actualExplicit["A"][1])) (BinaryConfusionMatrix.create(1,9,0,4) |> ComparisonMetrics.create) "Incorrect A metrics at threshold 0.9" + ) + testCase "B: explicit threshold 0-5" (fun _ -> + TestExtensions.comparisonMetricsEqualRounded 3 (snd (actualExplicit["B"][2])) (BinaryConfusionMatrix.create(2,10,1,1) |> ComparisonMetrics.create) "Incorrect B metrics at threshold 0.5" + ) + testCase "C: explicit threshold 0-1" (fun _ -> + TestExtensions.comparisonMetricsEqualRounded 3 (snd (actualExplicit["C"][3])) (BinaryConfusionMatrix.create(6,0,8,0) |> ComparisonMetrics.create) "Incorrect C metrics at threshold 0.1" + ) + testCase "micro-average present" (fun _ -> + Expect.isTrue (actualExplicit.ContainsKey("micro-average")) "micro-average key should be present" + ) + testCase "macro-average present" (fun _ -> + Expect.isTrue (actualExplicit.ContainsKey("macro-average")) "macro-average key should be present" + ) + ] ]