Skip to content

Commit e4079c3

Browse files
committed
Add benchmark phase diagnostics and scoring gate
1 parent 59cf290 commit e4079c3

4 files changed

Lines changed: 95 additions & 4 deletions

File tree

src/main/java/com/bioinceptionlabs/reactionblast/mapping/CallableAtomMappingTool.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package com.bioinceptionlabs.reactionblast.mapping;
2020

2121
import java.io.Serializable;
22+
import static java.lang.System.currentTimeMillis;
2223
import static java.lang.System.getProperty;
2324
import static java.util.Collections.unmodifiableMap;
2425
import java.util.EnumMap;
@@ -93,6 +94,7 @@ private void generateAtomAtomMapping(
9394
StandardizeReaction standardizer,
9495
boolean removeHydrogen,
9596
boolean checkComplex) {
97+
long mappingStart = currentTimeMillis();
9698
/*
9799
* Standardize the reaction ONCE.
98100
*/
@@ -199,6 +201,11 @@ private void generateAtomAtomMapping(
199201
LOGGER.debug("ERROR: in AtomMappingTool: " + e.getMessage());
200202
LOGGER.error(e);
201203
} finally {
204+
if (standardizedReaction != null && standardizedReaction.getID() != null) {
205+
MappingDiagnostics.recordMappingPhase(
206+
standardizedReaction.getID(),
207+
currentTimeMillis() - mappingStart);
208+
}
202209
executor.shutdown();
203210
LOGGER.debug("!!!!Atom-Atom Mapping Done!!!!");
204211
ThreadSafeCache.getInstance().cleanup();

src/main/java/com/bioinceptionlabs/reactionblast/mapping/MappingDiagnostics.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,19 @@ public static void recordQuickMappingSearch(String reactionId, String algorithm)
9191
.quickMappingSearches.incrementAndGet();
9292
}
9393

94+
public static void recordMappingPhase(String reactionId, long elapsedMillis) {
95+
reactionStats(reactionId).mappingPhaseMillis.set(elapsedMillis);
96+
}
97+
98+
public static void recordEvaluationPhase(String reactionId, long elapsedMillis) {
99+
reactionStats(reactionId).evaluationPhaseMillis.set(elapsedMillis);
100+
}
101+
94102
public static ReactionSnapshot snapshot(String reactionId) {
95103
ReactionStats stats = REACTIONS.get(reactionId);
96-
return stats == null ? new ReactionSnapshot(reactionId, Collections.emptyList()) : stats.snapshot(reactionId);
104+
return stats == null
105+
? new ReactionSnapshot(reactionId, 0L, 0L, Collections.emptyList())
106+
: stats.snapshot(reactionId);
97107
}
98108

99109
private static ReactionStats reactionStats(String reactionId) {
@@ -105,6 +115,8 @@ private static ReactionStats reactionStats(String reactionId) {
105115
private static final class ReactionStats {
106116

107117
private final ConcurrentMap<String, AlgorithmStats> algorithms = new ConcurrentHashMap<>();
118+
private final AtomicLong mappingPhaseMillis = new AtomicLong();
119+
private final AtomicLong evaluationPhaseMillis = new AtomicLong();
108120

109121
private AlgorithmStats algorithmStats(String algorithm) {
110122
String key = algorithm == null ? "UNKNOWN" : algorithm;
@@ -117,7 +129,11 @@ private ReactionSnapshot snapshot(String reactionId) {
117129
algorithmSnapshots.add(stats.snapshot());
118130
}
119131
algorithmSnapshots.sort(Comparator.comparing(snapshot -> snapshot.algorithm));
120-
return new ReactionSnapshot(reactionId, algorithmSnapshots);
132+
return new ReactionSnapshot(
133+
reactionId,
134+
mappingPhaseMillis.get(),
135+
evaluationPhaseMillis.get(),
136+
algorithmSnapshots);
121137
}
122138
}
123139

@@ -211,10 +227,17 @@ private MatcherInvocationSnapshot snapshot() {
211227
public static final class ReactionSnapshot {
212228

213229
public final String reactionId;
230+
public final long mappingPhaseMillis;
231+
public final long evaluationPhaseMillis;
214232
public final List<AlgorithmSnapshot> algorithms;
215233

216-
public ReactionSnapshot(String reactionId, List<AlgorithmSnapshot> algorithms) {
234+
public ReactionSnapshot(String reactionId,
235+
long mappingPhaseMillis,
236+
long evaluationPhaseMillis,
237+
List<AlgorithmSnapshot> algorithms) {
217238
this.reactionId = reactionId;
239+
this.mappingPhaseMillis = mappingPhaseMillis;
240+
this.evaluationPhaseMillis = evaluationPhaseMillis;
218241
this.algorithms = Collections.unmodifiableList(new ArrayList<>(algorithms));
219242
}
220243
}

src/main/java/com/bioinceptionlabs/reactionblast/mechanism/ReactionMechanismTool.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import java.io.Serializable;
2222
import static java.lang.Integer.MIN_VALUE;
23+
import static java.lang.System.currentTimeMillis;
2324

2425
import java.util.ArrayList;
2526
import java.util.Collection;
@@ -58,6 +59,7 @@
5859
import com.bioinceptionlabs.reactionblast.fingerprints.IPatternFingerprinter;
5960
import com.bioinceptionlabs.reactionblast.tools.StandardizeReaction;
6061
import com.bioinceptionlabs.reactionblast.mapping.CallableAtomMappingTool;
62+
import com.bioinceptionlabs.reactionblast.mapping.MappingDiagnostics;
6163
import com.bioinceptionlabs.reactionblast.mapping.Reactor;
6264
import com.bioinceptionlabs.reactionblast.mapping.IMappingAlgorithm;
6365
import static com.bioinceptionlabs.reactionblast.mapping.IMappingAlgorithm.USER_DEFINED;
@@ -278,6 +280,7 @@ && getAtomCount(reaction.getReactants())
278280
CallableAtomMappingTool amt = new CallableAtomMappingTool(reaction, standardizer,
279281
onlyCoreMappingByMCS, checkComplex);
280282
Map<IMappingAlgorithm, Reactor> solutions = amt.getSolutions();
283+
long evaluationStart = currentTimeMillis();
281284
List<EvaluationCandidate> orderedSolutions = orderSolutionsForEvaluation(solutions);
282285
List<EvaluationCandidate> candidates = collectCandidatesForEvaluation(orderedSolutions);
283286

@@ -289,6 +292,9 @@ && getAtomCount(reaction.getReactants())
289292
LOGGER.debug("is solution: " + mappingSolution.getAlgorithmID()
290293
+ " selected: " + selected);
291294
}
295+
MappingDiagnostics.recordEvaluationPhase(
296+
reaction.getID(),
297+
currentTimeMillis() - evaluationStart);
292298
} catch (Exception e) {
293299
LOGGER.error(SEVERE, "Bond change calculation error", e);
294300
throw new Exception(NEW_LINE + "ERROR: Unable to calculate bond changes: " + e.getMessage(), e);
@@ -927,13 +933,22 @@ private boolean isIdentityLike(List<EvaluationCandidate> candidates) {
927933
private List<EvaluationCandidate> limitCandidatesForFullScoring(
928934
List<EvaluationCandidate> candidates,
929935
boolean identityLike) {
930-
if (candidates.size() <= DEFAULT_FULL_SCORING_CANDIDATES) {
936+
if (candidates.size() <= 1) {
931937
return candidates;
932938
}
933939

934940
List<EvaluationCandidate> ranked = new ArrayList<>(candidates);
935941
ranked.sort(evaluationCandidateComparator(identityLike));
936942

943+
if (hasDominantTopCandidate(ranked, identityLike)) {
944+
LOGGER.debug("Top candidate dominates quick-score ranking; scoring 1 candidate only");
945+
return new ArrayList<>(ranked.subList(0, 1));
946+
}
947+
948+
if (candidates.size() <= DEFAULT_FULL_SCORING_CANDIDATES) {
949+
return ranked;
950+
}
951+
937952
int limit = Math.min(DEFAULT_FULL_SCORING_CANDIDATES, ranked.size());
938953
if (hasAmbiguousTopTier(ranked)) {
939954
limit = Math.min(MAX_FULL_SCORING_CANDIDATES, ranked.size());
@@ -957,6 +972,23 @@ private List<EvaluationCandidate> limitCandidatesForFullScoring(
957972
return retained;
958973
}
959974

975+
private boolean hasDominantTopCandidate(List<EvaluationCandidate> ranked, boolean identityLike) {
976+
if (identityLike || ranked.size() < 2) {
977+
return false;
978+
}
979+
980+
EvaluationCandidate best = ranked.get(0);
981+
EvaluationCandidate challenger = ranked.get(1);
982+
if (!best.coverage.isComplete() || !best.coverage.isBalancedMapped()) {
983+
return false;
984+
}
985+
if (!challenger.coverage.isComplete() || !challenger.coverage.isBalancedMapped()) {
986+
return true;
987+
}
988+
return !best.quickScore.isNear(challenger.quickScore)
989+
&& best.quickScore.totalScore() + 2 <= challenger.quickScore.totalScore();
990+
}
991+
960992
private boolean hasAmbiguousTopTier(List<EvaluationCandidate> ranked) {
961993
if (ranked.size() < 2) {
962994
return false;

src/test/java/com/bioinceptionlabs/aamtool/GoldenDatasetBenchmarkTest.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ public void benchmarkGoldenDataset() throws Exception {
100100
double totalQualityScore = 0;
101101
int qualityScored = 0;
102102
int mismatchReports = 0;
103+
int totalAlgorithmsExecuted = 0;
104+
Map<Integer, Integer> algorithmsPerReaction = new HashMap<>();
105+
Map<String, Integer> selectedAlgorithms = new HashMap<>();
106+
long totalMappingPhaseMs = 0;
107+
long totalEvaluationPhaseMs = 0;
103108

104109
long startTime = System.currentTimeMillis();
105110

@@ -140,9 +145,16 @@ public void benchmarkGoldenDataset() throws Exception {
140145
// Run RDT mapping
141146
ReactionMechanismTool rmt = performAtomAtomMapping(rdtRxn, "GOLDEN_" + (i + 1));
142147
MappingSolution solution = rmt.getSelectedSolution();
148+
MappingDiagnostics.ReactionSnapshot snapshot = MappingDiagnostics.snapshot("GOLDEN_" + (i + 1));
149+
int executedAlgorithms = snapshot.algorithms.size();
150+
totalAlgorithmsExecuted += executedAlgorithms;
151+
algorithmsPerReaction.merge(executedAlgorithms, 1, Integer::sum);
152+
totalMappingPhaseMs += snapshot.mappingPhaseMillis;
153+
totalEvaluationPhaseMs += snapshot.evaluationPhaseMillis;
143154

144155
if (solution != null && solution.getBondChangeCalculator() != null) {
145156
success++;
157+
selectedAlgorithms.merge(solution.getAlgorithmID().name(), 1, Integer::sum);
146158

147159
// --- Metric 1: Atom-level accuracy ---
148160
IReaction mappedRxn = solution.getReaction();
@@ -326,6 +338,14 @@ public void benchmarkGoldenDataset() throws Exception {
326338
System.out.println("Errors: " + errors);
327339
System.out.println("Speed: " + String.format("%.1f", rxnPerSec) + " rxn/sec");
328340
System.out.println("Total time: " + (totalTime / 1000) + "s");
341+
System.out.println("Avg algorithms/run: " + String.format("%.2f",
342+
total == 0 ? 0.0 : (double) totalAlgorithmsExecuted / total));
343+
System.out.println("Algorithms/reaction: " + formatDistribution(algorithmsPerReaction));
344+
System.out.println("Selected algorithms: " + formatDistribution(selectedAlgorithms));
345+
System.out.println("Avg mapping phase: " + String.format("%.1f ms",
346+
total == 0 ? 0.0 : (double) totalMappingPhaseMs / total));
347+
System.out.println("Avg evaluation phase: " + String.format("%.1f ms",
348+
total == 0 ? 0.0 : (double) totalEvaluationPhaseMs / total));
329349
System.out.println();
330350
System.out.println("=== Comparison with Published Results (Lin et al. 2022) ===");
331351
System.out.println("| Tool | Exact Match | Atom Acc. | Bond Acc. | Training | Deterministic |");
@@ -782,6 +802,15 @@ private double pct_d(int num, int den) {
782802
return den == 0 ? 0.0 : 100.0 * num / den;
783803
}
784804

805+
private String formatDistribution(Map<?, Integer> distribution) {
806+
List<String> entries = new ArrayList<>();
807+
for (Map.Entry<?, Integer> entry : distribution.entrySet()) {
808+
entries.add(entry.getKey() + "=" + entry.getValue());
809+
}
810+
Collections.sort(entries);
811+
return entries.toString();
812+
}
813+
785814
private static class GoldReaction {
786815
final String rxnBlock;
787816
GoldReaction(String rxnBlock) { this.rxnBlock = rxnBlock; }

0 commit comments

Comments
 (0)