Skip to content

Commit 3338565

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Token count estimation fallback for tail retention compaction
PiperOrigin-RevId: 868224060
1 parent be35b22 commit 3338565

2 files changed

Lines changed: 79 additions & 34 deletions

File tree

core/src/main/java/com/google/adk/summarizer/TailRetentionEventCompactor.java

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -73,37 +73,12 @@ public Completable compact(Session session, BaseSessionService sessionService) {
7373
logger.debug("Running tail retention event compaction for session {}", session.id());
7474

7575
return Maybe.just(session.events())
76-
.filter(this::shouldCompact)
77-
.flatMap(events -> getCompactionEvents(events))
76+
.flatMap(this::getCompactionEvents)
7877
.flatMap(summarizer::summarizeEvents)
7978
.flatMapSingle(e -> sessionService.appendEvent(session, e))
8079
.ignoreElement();
8180
}
8281

83-
private boolean shouldCompact(List<Event> events) {
84-
int count = getLatestPromptTokenCount(events).orElse(0);
85-
86-
// TODO b/480013930 - Add a way to estimate the prompt token if the usage metadata is not
87-
// available.
88-
if (count <= tokenThreshold) {
89-
logger.debug(
90-
"Skipping compaction. Prompt token count {} is within threshold {}",
91-
count,
92-
tokenThreshold);
93-
return false;
94-
}
95-
return true;
96-
}
97-
98-
private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
99-
return Lists.reverse(events).stream()
100-
.map(Event::usageMetadata)
101-
.flatMap(Optional::stream)
102-
.map(GenerateContentResponseUsageMetadata::promptTokenCount)
103-
.flatMap(Optional::stream)
104-
.findFirst();
105-
}
106-
10782
/**
10883
* Identifies events to be compacted based on the tail retention strategy.
10984
*
@@ -161,8 +136,19 @@ private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
161136
* together. The new compaction event will cover the range from the start of the included
162137
* compaction event (C2, T=1) to the end of the new events (E4, T=4).
163138
* </ol>
139+
*
140+
* @param events The list of events to process.
164141
*/
165142
private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
143+
Optional<Integer> count = getLatestPromptTokenCount(events);
144+
if (count.isPresent() && count.get() <= tokenThreshold) {
145+
logger.debug(
146+
"Skipping compaction. Prompt token count {} is within threshold {}",
147+
count.get(),
148+
tokenThreshold);
149+
return Maybe.empty();
150+
}
151+
166152
long compactionEndTimestamp = Long.MIN_VALUE;
167153
Event lastCompactionEvent = null;
168154
List<Event> eventsToSummarize = new ArrayList<>();
@@ -195,11 +181,6 @@ private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
195181
}
196182
}
197183

198-
// If there are not enough events to summarize, we can return early.
199-
if (eventsToSummarize.size() <= retentionSize) {
200-
return Maybe.empty();
201-
}
202-
203184
// Add the last compaction event to the list of events to summarize.
204185
// This is to ensure that the last compaction event is included in the summary.
205186
if (lastCompactionEvent != null) {
@@ -214,6 +195,22 @@ private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
214195

215196
Collections.reverse(eventsToSummarize);
216197

198+
if (count.isEmpty()) {
199+
int estimatedCount = estimateTokenCount(eventsToSummarize);
200+
if (estimatedCount <= tokenThreshold) {
201+
logger.debug(
202+
"Skipping compaction. Estimated prompt token count {} is within threshold {}",
203+
estimatedCount,
204+
tokenThreshold);
205+
return Maybe.empty();
206+
}
207+
}
208+
209+
// If there are not enough events to summarize, we can return early.
210+
if (eventsToSummarize.size() <= retentionSize) {
211+
return Maybe.empty();
212+
}
213+
217214
// Apply retention: keep the most recent 'retentionSize' events out of the summary.
218215
// We do this by removing them from the list of events to be summarized.
219216
eventsToSummarize
@@ -222,6 +219,22 @@ private Maybe<List<Event>> getCompactionEvents(List<Event> events) {
222219
return Maybe.just(eventsToSummarize);
223220
}
224221

222+
private int estimateTokenCount(List<Event> events) {
223+
// A common rule of thumb is that one token roughly corresponds to 4 characters of text for
224+
// common English text.
225+
// See https://platform.openai.com/tokenizer
226+
return events.stream().mapToInt(event -> event.stringifyContent().length()).sum() / 4;
227+
}
228+
229+
private Optional<Integer> getLatestPromptTokenCount(List<Event> events) {
230+
return Lists.reverse(events).stream()
231+
.map(Event::usageMetadata)
232+
.flatMap(Optional::stream)
233+
.map(GenerateContentResponseUsageMetadata::promptTokenCount)
234+
.flatMap(Optional::stream)
235+
.findFirst();
236+
}
237+
225238
private static boolean isCompactEvent(Event event) {
226239
return event.actions() != null && event.actions().compaction().isPresent();
227240
}

core/src/test/java/com/google/adk/summarizer/TailRetentionEventCompactorTest.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,13 @@ public void constructor_negativeRetentionSize_throwsException() {
7575
}
7676

7777
@Test
78-
// TODO: b/480013930 - Add a test case for estimating the prompt token if the usage metadata is
79-
// not available.
80-
public void compaction_skippedWhenTokenUsageMissing() {
78+
public void compaction_skippedWhenEstimatedTokenUsageBelowThreshold() {
79+
// Threshold is 100.
80+
// Event1: "Event1" -> length 6.
81+
// Retain1: "Retain1" -> length 7.
82+
// Retain2: "Retain2" -> length 7.
83+
// Total length = 20. Estimated tokens = 20 / 4 = 5.
84+
// 5 <= 100 -> Skip.
8185
EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 100);
8286
ImmutableList<Event> events =
8387
ImmutableList.of(
@@ -92,6 +96,34 @@ public void compaction_skippedWhenTokenUsageMissing() {
9296
verify(mockSessionService, never()).appendEvent(any(), any());
9397
}
9498

99+
@Test
100+
public void compaction_happensWhenEstimatedTokenUsageAboveThreshold() {
101+
// Threshold is 2.
102+
// Event1: "Event1" -> length 6.
103+
// Retain1: "Retain1" -> length 7.
104+
// Retain2: "Retain2" -> length 7.
105+
// Total eligible for estimation (including retained ones as per current logic):
106+
// Logic: getCompactionEvents returns [Event1, Retain1, Retain2] for estimation.
107+
// Total length = 20. Estimated tokens = 20 / 4 = 5.
108+
// 5 > 2 -> Compact.
109+
EventCompactor compactor = new TailRetentionEventCompactor(mockSummarizer, 2, 2);
110+
ImmutableList<Event> events =
111+
ImmutableList.of(
112+
createEvent(1, "Event1"),
113+
createEvent(2, "Retain1"),
114+
createEvent(3, "Retain2")); // No usage metadata
115+
Session session = Session.builder("id").events(events).build();
116+
Event summaryEvent = createEvent(4, "Summary");
117+
118+
when(mockSummarizer.summarizeEvents(any())).thenReturn(Maybe.just(summaryEvent));
119+
when(mockSessionService.appendEvent(any(), any())).thenReturn(Single.just(summaryEvent));
120+
121+
compactor.compact(session, mockSessionService).blockingSubscribe();
122+
123+
verify(mockSummarizer).summarizeEvents(any());
124+
verify(mockSessionService).appendEvent(eq(session), eq(summaryEvent));
125+
}
126+
95127
@Test
96128
public void compaction_skippedWhenTokenUsageBelowThreshold() {
97129
// Threshold is 300, usage is 200.

0 commit comments

Comments
 (0)