Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,15 @@ export class MessageBridge {
sessionId?: string;
startedAt?: number;
outputFormat?: OutputFormat;
activeMessageId?: string;
} = {}
) {
this._onDone = onDone;
this._stateTracker = new StreamStateTracker({
sessionId: _options.sessionId,
startedAt: _options.startedAt,
hasOutputFormat: _options.outputFormat !== undefined,
activeMessageId: _options.activeMessageId,
});
}

Expand Down
2 changes: 2 additions & 0 deletions src/schemas/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ export const DroidWorkingStateChangedNotificationSchema = z
.object({
type: z.literal(SessionNotificationType.DROID_WORKING_STATE_CHANGED),
newState: z.nativeEnum(DroidWorkingState),
messageId: z.string().optional(),
requestId: z.string().optional(),
})
.passthrough();

Expand Down
8 changes: 5 additions & 3 deletions src/session.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import { v4 as uuidv4 } from 'uuid';

import { DroidClient } from './client.js';
import { ConnectionError } from './errors.js';
import {
Expand Down Expand Up @@ -190,11 +192,13 @@ export class DroidSession {
throwIfAborted(options?.abortSignal);

const startedAt = Date.now();
const activeMessageId = options?.messageId ?? uuidv4();
const bridge = new MessageBridge(undefined, {
includePartialMessages: options?.includePartialMessages,
sessionId: this._sessionId,
startedAt,
outputFormat: options?.outputFormat,
activeMessageId,
});
const unsubscribe = this._client.onNotification(bridge.notificationHandler);
let resolveAbort: () => void = () => {};
Expand All @@ -210,9 +214,7 @@ export class DroidSession {
try {
await Promise.race([
this._client.addUserMessage({
...(options?.messageId !== undefined && {
messageId: options.messageId,
}),
messageId: activeMessageId,
text: prompt,
images: options?.images,
files: options?.files,
Expand Down
28 changes: 28 additions & 0 deletions src/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ export interface DroidUserMessage {

export interface ToolResult {
readonly type: 'tool_result';
readonly messageId?: string;
readonly toolUseId: string;
readonly toolName: string;
readonly content: string | JsonValue[];
Expand All @@ -133,6 +134,8 @@ export interface ToolProgress {
export interface WorkingStateChanged {
readonly type: 'working_state_changed';
readonly state: DroidWorkingState;
readonly messageId?: string;
readonly requestId?: string;
}

export type TokenUsageUpdate = Readonly<
Expand Down Expand Up @@ -260,6 +263,7 @@ export type DroidResultSubtype =

interface DroidResultBase {
readonly type: 'result';
readonly messageId?: string;
readonly sessionId: string;
readonly durationMs: number;
readonly numTurns: number;
Expand Down Expand Up @@ -388,6 +392,7 @@ export function convertNotificationToStreamMessage(
case SessionNotificationType.TOOL_RESULT:
return {
type: DroidMessageType.ToolResult,
messageId: notification.messageId,
toolUseId: notification.toolUseId,
toolName: '',
content: normalizeToolResultContent(notification.content),
Expand All @@ -410,6 +415,12 @@ export function convertNotificationToStreamMessage(
return {
type: DroidMessageType.WorkingStateChanged,
state: notification.newState,
...(notification.messageId !== undefined && {
messageId: notification.messageId,
}),
...(notification.requestId !== undefined && {
requestId: notification.requestId,
}),
};

case SessionNotificationType.SESSION_TOKEN_USAGE_CHANGED: {
Expand Down Expand Up @@ -620,6 +631,7 @@ export class StreamStateTracker {
sessionId?: string;
startedAt?: number;
hasOutputFormat?: boolean;
activeMessageId?: string;
} = {}
) {}

Expand Down Expand Up @@ -676,6 +688,10 @@ export class StreamStateTracker {
}

if (message.type === DroidMessageType.WorkingStateChanged) {
if (!this.shouldUseWorkingStateForCompletion(message)) {
return { message, additional };
}

if (message.state !== DroidWorkingState.Idle) {
this.hasBeenNonIdle = true;
} else if (this.hasBeenNonIdle) {
Expand Down Expand Up @@ -709,6 +725,9 @@ export class StreamStateTracker {
const result = this.finalAssistantText || this.fullText;
const base = {
type: DroidMessageType.Result,
...(this.options.activeMessageId !== undefined && {
messageId: this.options.activeMessageId,
}),
sessionId: this.options.sessionId ?? '',
durationMs: Date.now() - (this.options.startedAt ?? Date.now()),
numTurns: this.numTurns,
Expand Down Expand Up @@ -762,6 +781,15 @@ export class StreamStateTracker {
}
}

private shouldUseWorkingStateForCompletion(
message: WorkingStateChanged
): boolean {
if (this.options.activeMessageId === undefined) {
return true;
}
return message.messageId === this.options.activeMessageId;
}

private resetTurnState(): void {
this.lastTokenUsage = null;
this.fullText = '';
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ export function sendDefaultStreamSequence(
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: initialState }
{ newState: initialState, messageId }
)
);

Expand Down Expand Up @@ -337,7 +337,7 @@ export function sendDefaultStreamSequence(
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: finalState }
{ newState: finalState, messageId }
)
);
}
52 changes: 39 additions & 13 deletions tests/run.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function setupRunResponder(
transport: InMemoryTransport,
sessionId: string
): void {
wireTransportSend(transport, ({ method, id }) => {
wireTransportSend(transport, ({ method, id, params }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
Expand All @@ -42,6 +42,7 @@ function setupRunResponder(
transport.injectMessage(makeSuccessResponse(id, {}));
sendDefaultStreamSequence(transport, {
deltas: ['Run ', 'result'],
messageId: String(params['messageId']),
tokenUsageSessionId: sessionId,
});
});
Expand All @@ -62,6 +63,7 @@ describe('run()', () => {
const result = await run('Say hello', { transport });

expect(result.text).toBe('Run result');
expect(result.messageId).toEqual(expect.any(String));
expect(result.messages.length).toBeGreaterThan(0);
expect(result.tokenUsage).not.toBeNull();
expect(transport.isConnected).toBe(false);
Expand Down Expand Up @@ -159,7 +161,7 @@ describe('run()', () => {
const transport = new InMemoryTransport();
await transport.connect();

wireTransportSend(transport, ({ method, id }) => {
wireTransportSend(transport, ({ method, id, params }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
Expand All @@ -176,7 +178,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.StreamingAssistantMessage }
{
newState: DroidWorkingState.StreamingAssistantMessage,
messageId: params['messageId'],
}
)
);
transport.injectMessage(
Expand All @@ -189,7 +194,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.Idle }
{
newState: DroidWorkingState.Idle,
messageId: params['messageId'],
}
)
);
});
Expand Down Expand Up @@ -217,7 +225,7 @@ describe('run()', () => {
const transport = new InMemoryTransport();
await transport.connect();

wireTransportSend(transport, ({ method, id }) => {
wireTransportSend(transport, ({ method, id, params }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
Expand All @@ -234,7 +242,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.StreamingAssistantMessage }
{
newState: DroidWorkingState.StreamingAssistantMessage,
messageId: params['messageId'],
}
)
);
transport.injectMessage(
Expand All @@ -256,7 +267,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.Idle }
{
newState: DroidWorkingState.Idle,
messageId: params['messageId'],
}
)
);
});
Expand Down Expand Up @@ -304,7 +318,7 @@ describe('run()', () => {
const transport = new InMemoryTransport();
await transport.connect();

wireTransportSend(transport, ({ method, id }) => {
wireTransportSend(transport, ({ method, id, params }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
Expand All @@ -321,7 +335,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.StreamingAssistantMessage }
{
newState: DroidWorkingState.StreamingAssistantMessage,
messageId: params['messageId'],
}
)
);
transport.injectMessage(
Expand Down Expand Up @@ -350,7 +367,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.Idle }
{
newState: DroidWorkingState.Idle,
messageId: params['messageId'],
}
)
);
});
Expand Down Expand Up @@ -383,7 +403,7 @@ describe('run()', () => {
const transport = new InMemoryTransport();
await transport.connect();

wireTransportSend(transport, ({ method, id }) => {
wireTransportSend(transport, ({ method, id, params }) => {
if (method === DroidServerMethod.INITIALIZE_SESSION) {
queueMicrotask(() => {
transport.injectMessage(
Expand All @@ -400,7 +420,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.StreamingAssistantMessage }
{
newState: DroidWorkingState.StreamingAssistantMessage,
messageId: params['messageId'],
}
)
);
for (const textDelta of ['Hello ', 'beautiful ', 'world!']) {
Expand All @@ -418,7 +441,10 @@ describe('run()', () => {
transport.injectMessage(
makeSessionNotification(
SessionNotificationType.DROID_WORKING_STATE_CHANGED,
{ newState: DroidWorkingState.Idle }
{
newState: DroidWorkingState.Idle,
messageId: params['messageId'],
}
)
);
});
Expand Down
Loading
Loading