diff --git a/src/helpers.ts b/src/helpers.ts index 9642de3..0e45abe 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -68,6 +68,7 @@ export class MessageBridge { sessionId?: string; startedAt?: number; outputFormat?: OutputFormat; + activeMessageId?: string; } = {} ) { this._onDone = onDone; @@ -75,6 +76,7 @@ export class MessageBridge { sessionId: _options.sessionId, startedAt: _options.startedAt, hasOutputFormat: _options.outputFormat !== undefined, + activeMessageId: _options.activeMessageId, }); } diff --git a/src/schemas/server.ts b/src/schemas/server.ts index e4afabe..c3a2ae7 100644 --- a/src/schemas/server.ts +++ b/src/schemas/server.ts @@ -161,6 +161,8 @@ export const DroidWorkingStateChangedNotificationSchema = z .object({ type: z.literal(SessionNotificationType.DROID_WORKING_STATE_CHANGED), newState: z.nativeEnum(DroidWorkingState), + messageId: z.string().optional(), + requestId: z.string().optional(), }) .passthrough(); diff --git a/src/session.ts b/src/session.ts index a25b545..51dc4d2 100644 --- a/src/session.ts +++ b/src/session.ts @@ -1,3 +1,5 @@ +import { v4 as uuidv4 } from 'uuid'; + import { DroidClient } from './client.js'; import { ConnectionError } from './errors.js'; import { @@ -190,11 +192,13 @@ export class DroidSession { throwIfAborted(options?.abortSignal); const startedAt = Date.now(); + const activeMessageId = options?.messageId ?? uuidv4(); const bridge = new MessageBridge(undefined, { includePartialMessages: options?.includePartialMessages, sessionId: this._sessionId, startedAt, outputFormat: options?.outputFormat, + activeMessageId, }); const unsubscribe = this._client.onNotification(bridge.notificationHandler); let resolveAbort: () => void = () => {}; @@ -210,9 +214,7 @@ export class DroidSession { try { await Promise.race([ this._client.addUserMessage({ - ...(options?.messageId !== undefined && { - messageId: options.messageId, - }), + messageId: activeMessageId, text: prompt, images: options?.images, files: options?.files, diff --git a/src/stream.ts b/src/stream.ts index d1af7d7..6c57d1b 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -116,6 +116,7 @@ export interface DroidUserMessage { export interface ToolResult { readonly type: 'tool_result'; + readonly messageId?: string; readonly toolUseId: string; readonly toolName: string; readonly content: string | JsonValue[]; @@ -133,6 +134,8 @@ export interface ToolProgress { export interface WorkingStateChanged { readonly type: 'working_state_changed'; readonly state: DroidWorkingState; + readonly messageId?: string; + readonly requestId?: string; } export type TokenUsageUpdate = Readonly< @@ -260,6 +263,7 @@ export type DroidResultSubtype = interface DroidResultBase { readonly type: 'result'; + readonly messageId?: string; readonly sessionId: string; readonly durationMs: number; readonly numTurns: number; @@ -388,6 +392,7 @@ export function convertNotificationToStreamMessage( case SessionNotificationType.TOOL_RESULT: return { type: DroidMessageType.ToolResult, + messageId: notification.messageId, toolUseId: notification.toolUseId, toolName: '', content: normalizeToolResultContent(notification.content), @@ -410,6 +415,12 @@ export function convertNotificationToStreamMessage( return { type: DroidMessageType.WorkingStateChanged, state: notification.newState, + ...(notification.messageId !== undefined && { + messageId: notification.messageId, + }), + ...(notification.requestId !== undefined && { + requestId: notification.requestId, + }), }; case SessionNotificationType.SESSION_TOKEN_USAGE_CHANGED: { @@ -620,6 +631,7 @@ export class StreamStateTracker { sessionId?: string; startedAt?: number; hasOutputFormat?: boolean; + activeMessageId?: string; } = {} ) {} @@ -676,6 +688,10 @@ export class StreamStateTracker { } if (message.type === DroidMessageType.WorkingStateChanged) { + if (!this.shouldUseWorkingStateForCompletion(message)) { + return { message, additional }; + } + if (message.state !== DroidWorkingState.Idle) { this.hasBeenNonIdle = true; } else if (this.hasBeenNonIdle) { @@ -709,6 +725,9 @@ export class StreamStateTracker { const result = this.finalAssistantText || this.fullText; const base = { type: DroidMessageType.Result, + ...(this.options.activeMessageId !== undefined && { + messageId: this.options.activeMessageId, + }), sessionId: this.options.sessionId ?? '', durationMs: Date.now() - (this.options.startedAt ?? Date.now()), numTurns: this.numTurns, @@ -762,6 +781,15 @@ export class StreamStateTracker { } } + private shouldUseWorkingStateForCompletion( + message: WorkingStateChanged + ): boolean { + if (this.options.activeMessageId === undefined) { + return true; + } + return message.messageId === this.options.activeMessageId; + } + private resetTurnState(): void { this.lastTokenUsage = null; this.fullText = ''; diff --git a/tests/helpers.ts b/tests/helpers.ts index 3f04686..37bfc5a 100644 --- a/tests/helpers.ts +++ b/tests/helpers.ts @@ -284,7 +284,7 @@ export function sendDefaultStreamSequence( transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: initialState } + { newState: initialState, messageId } ) ); @@ -337,7 +337,7 @@ export function sendDefaultStreamSequence( transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: finalState } + { newState: finalState, messageId } ) ); } diff --git a/tests/run.test.ts b/tests/run.test.ts index a0557c1..51c7b4f 100644 --- a/tests/run.test.ts +++ b/tests/run.test.ts @@ -22,7 +22,7 @@ function setupRunResponder( transport: InMemoryTransport, sessionId: string ): void { - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -42,6 +42,7 @@ function setupRunResponder( transport.injectMessage(makeSuccessResponse(id, {})); sendDefaultStreamSequence(transport, { deltas: ['Run ', 'result'], + messageId: String(params['messageId']), tokenUsageSessionId: sessionId, }); }); @@ -62,6 +63,7 @@ describe('run()', () => { const result = await run('Say hello', { transport }); expect(result.text).toBe('Run result'); + expect(result.messageId).toEqual(expect.any(String)); expect(result.messages.length).toBeGreaterThan(0); expect(result.tokenUsage).not.toBeNull(); expect(transport.isConnected).toBe(false); @@ -159,7 +161,7 @@ describe('run()', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -176,7 +178,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); transport.injectMessage( @@ -189,7 +194,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); @@ -217,7 +225,7 @@ describe('run()', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -234,7 +242,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); transport.injectMessage( @@ -256,7 +267,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); @@ -304,7 +318,7 @@ describe('run()', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -321,7 +335,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); transport.injectMessage( @@ -350,7 +367,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); @@ -383,7 +403,7 @@ describe('run()', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -400,7 +420,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); for (const textDelta of ['Hello ', 'beautiful ', 'world!']) { @@ -418,7 +441,10 @@ describe('run()', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); diff --git a/tests/session.test.ts b/tests/session.test.ts index f2d0973..eced128 100644 --- a/tests/session.test.ts +++ b/tests/session.test.ts @@ -105,11 +105,14 @@ function setupLoadResponder( function setupFullResponder( transport: InMemoryTransport, sessionId: string, - responseMethods?: Record void> + responseMethods?: Record< + string, + (id: string, params: Record) => void + > ): void { - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (responseMethods && responseMethods[method]) { - responseMethods[method](id); + responseMethods[method](id, params); return; } @@ -137,6 +140,7 @@ function setupFullResponder( transport.injectMessage(makeSuccessResponse(id, {})); sendDefaultStreamSequence(transport, { deltas: ['Hello world'], + messageId: String(params['messageId']), tokenUsageSessionId: sessionId, }); }); @@ -388,7 +392,7 @@ describe('DroidSession', () => { await session.close(); }); - it('omits messageId from addUserMessage RPC params by default', async () => { + it('generates messageId for addUserMessage RPC params by default', async () => { const transport = new InMemoryTransport(); await transport.connect(); @@ -406,7 +410,7 @@ describe('DroidSession', () => { DroidServerMethod.ADD_USER_MESSAGE ) as Record; const addParams = addMsg['params'] as Record; - expect(addParams).not.toHaveProperty('messageId'); + expect(addParams['messageId']).toEqual(expect.any(String)); expect(addParams['text']).toBe('Hello'); await session.close(); @@ -453,7 +457,7 @@ describe('DroidSession', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -471,7 +475,9 @@ describe('DroidSession', () => { } else if (method === DroidServerMethod.ADD_USER_MESSAGE) { queueMicrotask(() => { transport.injectMessage(makeSuccessResponse(id, {})); - sendDefaultStreamSequence(transport); + sendDefaultStreamSequence(transport, { + messageId: String(params['messageId']), + }); }); } else if (method === DroidServerMethod.CLOSE_SESSION) { queueMicrotask(() => { @@ -527,7 +533,7 @@ describe('DroidSession', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -544,6 +550,7 @@ describe('DroidSession', () => { transport.injectMessage(makeSuccessResponse(id, {})); sendDefaultStreamSequence(transport, { deltas: [], + messageId: String(params['messageId']), includeTokenUsage: false, structuredOutputMessageId: 'msg-structured', structuredOutput: { name: 'Ada' }, @@ -583,7 +590,7 @@ describe('DroidSession', () => { const transport = new InMemoryTransport(); await transport.connect(); - wireTransportSend(transport, ({ method, id }) => { + wireTransportSend(transport, ({ method, id, params }) => { if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { transport.injectMessage( @@ -600,6 +607,7 @@ describe('DroidSession', () => { transport.injectMessage(makeSuccessResponse(id, {})); sendDefaultStreamSequence(transport, { deltas: [], + messageId: String(params['messageId']), includeTokenUsage: false, structuredOutputMessageId: 'msg-structured', structuredOutputError: { @@ -1242,6 +1250,7 @@ describe('DroidSession', () => { const msg = message as Record; const method = msg['method'] as string; const id = msg['id'] as string; + const params = msg['params'] as Record; if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { @@ -1260,7 +1269,10 @@ describe('DroidSession', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); @@ -1278,7 +1290,10 @@ describe('DroidSession', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); @@ -1395,6 +1410,7 @@ describe('DroidSession', () => { const msg = message as Record; const method = msg['method'] as string; const id = msg['id'] as string; + const params = msg['params'] as Record; if (method === DroidServerMethod.INITIALIZE_SESSION) { queueMicrotask(() => { @@ -1416,7 +1432,10 @@ describe('DroidSession', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); @@ -1450,7 +1469,10 @@ describe('DroidSession', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); @@ -1598,7 +1620,7 @@ describe('DroidSession', () => { let addUserMessageCount = 0; setupFullResponder(transport, 'sess-recovery', { - [DroidServerMethod.ADD_USER_MESSAGE]: (id) => { + [DroidServerMethod.ADD_USER_MESSAGE]: (id, params) => { addUserMessageCount++; if (addUserMessageCount === 1) { queueMicrotask(() => { @@ -1613,7 +1635,10 @@ describe('DroidSession', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.StreamingAssistantMessage } + { + newState: DroidWorkingState.StreamingAssistantMessage, + messageId: params['messageId'], + } ) ); @@ -1631,7 +1656,10 @@ describe('DroidSession', () => { transport.injectMessage( makeSessionNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, - { newState: DroidWorkingState.Idle } + { + newState: DroidWorkingState.Idle, + messageId: params['messageId'], + } ) ); }); diff --git a/tests/stream.test.ts b/tests/stream.test.ts index bd4c5ee..7f35cb3 100644 --- a/tests/stream.test.ts +++ b/tests/stream.test.ts @@ -752,6 +752,26 @@ describe('convertNotificationToStreamMessage', () => { expect(result.state).toBe(DroidWorkingState.ExecutingTool); }); + it('preserves message and request correlation fields', () => { + const notification = makeNotification( + SessionNotificationType.DROID_WORKING_STATE_CHANGED, + { + newState: DroidWorkingState.Idle, + messageId: 'msg-active', + requestId: 'req-active', + } + ); + const result = convertNotificationToStreamMessage( + notification + ) as WorkingStateChanged; + expect(result).toMatchObject({ + type: 'working_state_changed', + state: DroidWorkingState.Idle, + messageId: 'msg-active', + requestId: 'req-active', + }); + }); + it('handles Idle state', () => { const notification = makeNotification( SessionNotificationType.DROID_WORKING_STATE_CHANGED, @@ -1501,6 +1521,66 @@ describe('StreamStateTracker', () => { expect(result.additional[0].type).toBe('result'); }); + it('does NOT emit Result for stale uncorrelated idle with active messageId', () => { + tracker = new StreamStateTracker({ activeMessageId: 'msg-active' }); + + tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.StreamingAssistantMessage, + messageId: 'msg-active', + }); + const result = tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.Idle, + }); + + expect(result.additional).toEqual([]); + }); + + it('does NOT emit Result for another messageId with active messageId', () => { + tracker = new StreamStateTracker({ activeMessageId: 'msg-active' }); + + tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.StreamingAssistantMessage, + messageId: 'msg-active', + }); + const result = tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.Idle, + messageId: 'msg-stale', + }); + + expect(result.additional).toEqual([]); + }); + + it('emits Result only for matching active messageId', () => { + tracker = new StreamStateTracker({ activeMessageId: 'msg-active' }); + + tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.StreamingAssistantMessage, + messageId: 'msg-active', + }); + const stale = tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.Idle, + messageId: 'msg-stale', + }); + const active = tracker.processMessage({ + type: 'working_state_changed', + state: DroidWorkingState.Idle, + messageId: 'msg-active', + }); + + expect(stale.additional).toEqual([]); + expect(active.additional).toHaveLength(1); + expect(active.additional[0]).toMatchObject({ + type: 'result', + messageId: 'msg-active', + }); + }); + it('can emit Result again after reset', () => { tracker.processMessage({ type: 'working_state_changed',